loki-ecmwf-0.3.6/0000775000175000017500000000000015167130250013741 5ustar alastairalastairloki-ecmwf-0.3.6/loki-post-import.cmake.in0000664000175000017500000000302015167130205020574 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # Basic path setup if( @PROJECT_NAME@_IS_BUILD_DIR_EXPORT ) set( loki_MACROS_DIR @CMAKE_CURRENT_SOURCE_DIR@/cmake ) set( loki_VENV_PATH @CMAKE_CURRENT_BINARY_DIR@/loki_env ) else() set( loki_MACROS_DIR ${loki_BASE_DIR}/share/loki/cmake ) set( loki_VENV_PATH ${loki_BASE_DIR}/var/loki_env ) endif() # Make Loki CMake macro scripts available on the search path if( NOT ${loki_MACROS_DIR} IN_LIST CMAKE_MODULE_PATH ) list( INSERT CMAKE_MODULE_PATH 0 ${loki_MACROS_DIR} ) endif() # Carry over variables from the build set( loki_HAVE_NO_INSTALL @loki_HAVE_NO_INSTALL@ ) set( loki_HAVE_EDITABLE @loki_HAVE_EDITABLE@ ) set( loki_HAVE_OMNI @loki_HAVE_OMNI@ ) set( LOKI_EXECUTABLES @LOKI_EXECUTABLES@ ) # Find Python environment if( NOT ${loki_HAVE_NO_INSTALL} ) # Detect the installed virtual environment include( loki_python_macros ) loki_find_python_venv( VENV_PATH ${loki_VENV_PATH} PYTHON_VERSION @PYTHON_VERSION@ ) endif() # Discover Loki executables and make available as CMake targets include( loki_find_executables ) loki_find_executables() # Make the Loki transformation functions available include( loki_transform ) loki-ecmwf-0.3.6/VERSION0000664000175000017500000000000615167130205015005 0ustar alastairalastair0.3.6 loki-ecmwf-0.3.6/requirements.txt0000664000175000017500000000000215167130205017215 0ustar alastairalastair. 
loki-ecmwf-0.3.6/pyproject.toml0000664000175000017500000000475715167130205016672 0ustar alastairalastair# Make sure we use setuptools and have all required dependencies for that [build-system] requires = [ "setuptools >= 75.0.0", "wheel", "setuptools_scm[toml] >= 6.2", ] build-backend = "setuptools.build_meta" [project] name = "loki" authors = [ {name = "ECMWF", email = "user_support_section@ecmwf.int"}, ] description = "Experimental Fortran IR to facilitate source-to-source transformations" requires-python = ">=3.8" license = {text = "Apache-2.0"} dynamic = ["version", "readme"] dependencies = [ "numpy >= 2.0", # essential for tests, loop transformations and other dependencies "pymbolic==2022.2", # essential for expression tree "PyYAML", # essential for loki-lint "pcpp", # essential for preprocessing "more-itertools", # essential for SCC transformation "click", # essential for CLI scripts "click-option-group", # essential for CLI scripts "tomli ; python_version < '3.11'", # essential for scheduler configuration "networkx", # essential for scheduler and build utilities "fparser>=0.0.15", # (almost) essential as frontend "graphviz", # optional for scheduler callgraph "tqdm", # optional for build utilities "coloredlogs", # optional for loki-build utility "junit_xml", # optional for JunitXML output in loki-lint "codetiming", # essential for scheduler and sourcefile timings "pydantic>=2.0,<2.10.0", # type checking for IR nodes ] [project.optional-dependencies] tests = [ "pytest", "pytest-cov", "pytest-xdist", "coverage2clover", "pylint!=2.11.0,!=2.11.1", "pandas", "f90wrap>=0.2.15,<0.3.0", "nbconvert", "tomli_w", ] dace = [ "dace>=1.0; python_version < '3.13'", ] docs = [ "sphinx", # to build documentation "recommonmark", # to allow parsing markdown "sphinx-rtd-theme", # ReadTheDocs theme "myst-parser", # Markdown parser for sphinx "nbsphinx", # Jupyter notebook parser for sphinx "sphinx-design", # Add panels, cards and dropdowns for sphinx ] examples = [ "jupyter", 
"ipyparams", ] [project.scripts] "loki-transform.py" = "loki.cli.loki_transform:cli" "loki-lint.py" = "loki.cli.loki_lint:cli" [tool.setuptools] license-files = ["LICENSE", "AUTHORS.md"] [tool.setuptools.dynamic] readme = {file = ["README.md", "INSTALL.md"], content-type = "text/markdown"} [tool.setuptools.packages.find] where = ["."] include = [ "loki", "loki.*", "scripts" ] exclude = [ "build*", "cmake*", "docs*", "example*", "lint_rules*", "loki_env*", ] namespaces = false # Enable SCM versioning [tool.setuptools_scm] loki-ecmwf-0.3.6/lint_rules/0000775000175000017500000000000015167130205016121 5ustar alastairalastairloki-ecmwf-0.3.6/lint_rules/tests/0000775000175000017500000000000015167130205017263 5ustar alastairalastairloki-ecmwf-0.3.6/lint_rules/tests/test_ifs_coding_standards_2011.py0000664000175000017500000005521515167130205025516 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
import importlib from pathlib import Path import pytest from conftest import run_linter, available_frontends from loki import Sourcefile from loki.lint import DefaultHandler pytestmark = pytest.mark.skipif(not available_frontends(), reason='Suitable frontend not available') @pytest.fixture(scope='module', name='rules') def fixture_rules(): rules = importlib.import_module('lint_rules.ifs_coding_standards_2011') return rules @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('nesting_depth, lines', [ (3, []), (2, [6, 12, 16, 22, 28, 35]), (1, [5, 6, 10, 12, 16, 22, 27, 28, 34, 35])]) def test_code_body_messages(rules, frontend, nesting_depth, lines): ''' Test the number and content of messages generated by CodeBodyRule for different nesting depths. ''' fcode = """ subroutine routine_nesting(a, b, c, d, e) integer, intent(in) :: a, b, c, d, e if (a > 3) then if (b > 2) then if (c > 1) then print *, 'if-if-if' end if end if select case (d) case (0) if (e == 0) then print *, 'if-case-if' endif case (1:3) if (e == 0) then print *, 'if-range-if' else print *, 'if-range-else' endif case default if (e == 0) then print *, 'if-default-if' endif end select elseif (a == 3) then if (b > 2) then if (c > 1) then print *, 'elseif-if-if' end if end if else if (e == 0) print *, 'else-inlineif' if (b > 2) then if (c > 1) then print *, 'else-if-if' end if end if end if end subroutine routine_nesting """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) config = {'CodeBodyRule': {'max_nesting_depth': nesting_depth}} _ = run_linter(source, [rules.CodeBodyRule], config=config, handlers=[handler]) assert len(messages) == len(lines) keywords = ('CodeBodyRule', '[1.3]') assert all(all(keyword in msg for keyword in keywords) for msg in messages) for msg, ref_line in zip(messages, lines): assert f'limit of {nesting_depth}' in msg assert f'l. 
{ref_line}' in msg @pytest.mark.parametrize('frontend', available_frontends()) def test_module_naming(rules, frontend): '''Test file and modules for checking that naming is correct and matches each other.''' fcode = """ ! This is ok module module_naming_mod integer foo contains subroutine bar integer foobar end subroutine bar end module module_naming_mod ! This should complain about wrong file name module MODULE_NAMING_UPPERCASE_MOD integer foo contains subroutine bar integer foobar end subroutine bar end module MODULE_NAMING_UPPERCASE_MOD ! This should complain about wrong module and file name module module_naming integer baz end module module_naming """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) # We don't actually write the file but simply set the filename to something sensible for m in source.modules: m.source.file = str(Path(__file__).parent / 'module_naming_mod.f90') messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.ModuleNamingRule], handlers=[handler]) assert len(messages) == 3 keywords = ('ModuleNamingRule', '[1.5]') assert all(all(keyword in msg for keyword in keywords) for msg in messages) assert all('"module_naming' in msg.lower() for msg in messages) assert all(keyword in messages[0] for keyword in ('module_naming_mod.f90', 'filename')) assert all(keyword in messages[1] for keyword in ('"_mod"', 'Name of module')) assert all(keyword in messages[2] for keyword in ('module_naming_mod.f90', 'filename')) @pytest.mark.parametrize('frontend', available_frontends()) def test_dr_hook_okay(rules, frontend): fcode = """ subroutine routine_okay use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle ! Comments are non-executable statements if (lhook) then #define foobar call dr_hook('routine_okay', 0, zhook_handle) end if print *, "Foo bar" if (lhook) call dr_hook('routine_okay', 1, zhook_handle) ! 
Comments are non-executable statements contains subroutine routine_contained_okay real(kind=jprb) :: zhook_handle ! CPP directives should be ignored #ifndef _some_macro if (lhook) call dr_hook('routine_okay%routine_contained_okay', 0, zhook_handle) print *, "Foo bar" if (lhook) call dr_hook('routine_okay%routine_contained_okay', 1, zhook_handle) ! CPP directives should be ignored #endif end subroutine routine_contained_okay end subroutine routine_okay """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.DrHookRule], handlers=[handler]) assert len(messages) == 0 @pytest.mark.parametrize('frontend', available_frontends()) def test_dr_hook_routine(rules, frontend): fcode = """ subroutine routine_not_okay_a use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle ! Error: no conditional IF(LHOOK) ! Error: no zhook_handle (Not detected because call not found) call dr_hook('routine_not_okay_a', 0) print *, "Foo bar" ! Error: subroutine name not in string argument if (lhook) call dr_hook('foobar', 1, zhook_handle) end subroutine routine_not_okay_a subroutine routine_not_okay_b use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle ! Error: second argument is not 0 or 1 if (lhook) call dr_hook('routine_not_okay_b', 2, zhook_handle) print *, "Foo bar" ! Error: third argument is not zhook_handle if (lhook) call dr_hook('routine_not_okay_b', 1) end subroutine routine_not_okay_b subroutine routine_not_okay_c use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle real(kind=jprb) :: red_herring red_herring = 1.0 ! Error: Executable statement before call to dr_hook if (lhook) call dr_hook('routine_not_okay_c', 2, zhook_handle) print *, "Foo bar" ! 
Error: Executable statement after call to dr_hook if (lhook) then call dr_hook('routine_not_okay_c', 1, zhook_handle) red_herring = 2.0 end if end subroutine routine_not_okay_c subroutine routine_not_okay_d use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle real(kind=jprb) :: red_herring ! Error: First call to dr_hook is missing red_herring = 1.0 print *, "Foo bar" if (lhook) call dr_hook('routine_not_okay_d', 1, zhook_handle) end subroutine routine_not_okay_d subroutine routine_not_okay_e use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle real(kind=jprb) :: red_herring if (lhook) call dr_hook('routine_not_okay_e', 0, zhook_handle) red_herring = 1.0 print *, "Foo bar" ! Error: Last call to dr_hook is missing contains subroutine routine_contained_not_okay use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle real(kind=jprb) :: red_herring if (lhook) call dr_hook('routine_not_okay_e%routine_contained_not_okay', 0, zhook_handle) red_herring = 1.0 print *, "Foo bar" ! 
Error: String argument is not "%" if (lhook) call dr_hook('routine_contained_not_okay', 1, zhook_handle) end subroutine routine_contained_not_okay end subroutine routine_not_okay_e """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.DrHookRule], handlers=[handler]) assert len(messages) == 9 keywords = ('DrHookRule', 'DR_HOOK', '[1.9]') assert all(all(keyword in msg for keyword in keywords) for msg in messages) assert all('First executable statement must be call to DR_HOOK' in messages[i] for i in [0, 4, 6]) assert all('Last executable statement must be call to DR_HOOK' in messages[i] for i in [5, 7]) assert all('String argument to DR_HOOK call should be "' in messages[i] for i in [1, 8]) assert 'Second argument to DR_HOOK call should be "0"' in messages[2] assert 'Third argument to DR_HOOK call should be "ZHOOK_HANDLE"' in messages[3] # Later lines come first as modules are checked before subroutines assert '(l. 12)' in messages[1] assert '(l. 21)' in messages[2] assert '(l. 26)' in messages[3] assert '(l. 
91)' in messages[8] assert all(f'routine_not_okay_{letter}' in messages[i] for letter, i in (('a', 0), ('c', 4), ('c', 5), ('d', 6), ('e', 7))) @pytest.mark.parametrize('frontend', available_frontends()) def test_dr_hook_module(rules, frontend): fcode = """ module some_mod contains subroutine mod_routine_okay use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle if (lhook) call dr_hook('some_mod:mod_routine_okay', 0, zhook_handle) print *, "Foo bar" if (lhook) call dr_hook('some_mod:mod_routine_okay', 1, zhook_handle) contains subroutine mod_contained_routine_okay use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle if (lhook) call dr_hook('some_mod:mod_routine_okay%mod_contained_routine_okay', 0, zhook_handle) print *, "Foo bar" if (lhook) call dr_hook('some_mod:mod_routine_okay%mod_contained_routine_okay', 1, zhook_handle) end subroutine mod_contained_routine_okay end subroutine mod_routine_okay subroutine mod_routine_not_okay use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle ! Error: String argument does not contain module name if (lhook) call dr_hook('mod_routine_okay', 0, zhook_handle) print *, "Foo bar" if (lhook) call dr_hook('some_mod:mod_routine_not_okay', 1, zhook_handle) contains subroutine mod_contained_routine_not_okay use yomhook, only: lhook, dr_hook real(kind=jprb) :: zhook_handle ! Error: String argument does not contain module name if (lhook) call dr_hook('mod_routine_not_okay%mod_contained_routine_not_okay', 0, zhook_handle) print *, "Foo bar" ! Error: String argument does not contain parent routine name ! 
Error: Second argument is not 0 or 1 if (lhook) call dr_hook('some_mod:mod_contained_routine_not_okay', 8, zhook_handle) end subroutine mod_contained_routine_not_okay end subroutine mod_routine_not_okay end module some_mod """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.DrHookRule], handlers=[handler]) assert len(messages) == 4 keywords = ('DrHookRule', 'DR_HOOK', '[1.9]') assert all(all(keyword in msg for keyword in keywords) for msg in messages) assert all('String argument to DR_HOOK call should be "' in messages[i] for i in [0, 1, 2]) assert 'Second argument to DR_HOOK call should be "1"' in messages[3] # Later lines come first as modules are checked before subroutines assert '(l. 30)' in messages[0] assert '(l. 41)' in messages[1] assert '(l. 45)' in messages[2] assert '(l. 45)' in messages[3] @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('max_num_statements, passes', [ (10, True), (4, True), (3, False)]) def test_limit_subroutine_stmts(rules, frontend, max_num_statements, passes): '''Test for different maximum allowed number of executable statements and content of messages generated by LimitSubroutineStatementsRule.''' fcode = """ subroutine routine_limit_statements() integer :: a, b, c, d, e ! Non-exec statements #define some_macro print *, 'Hello world!' 
associate (aa=>a) aa = 1 b = 2 call some_routine(c, e) d = 4 end associate end subroutine routine_limit_statements """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) config = {'LimitSubroutineStatementsRule': {'max_num_statements': max_num_statements}} _ = run_linter(source, [rules.LimitSubroutineStatementsRule], config=config, handlers=[handler]) assert len(messages) == (0 if passes else 1) keywords = ('LimitSubroutineStatementsRule', '[2.2]', '4', str(max_num_statements), 'routine_limit_statements') assert all(all(keyword in msg for keyword in keywords) for msg in messages) @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('max_num_arguments, passes', [ (10, True), (8, True), (7, False), (1, False)]) def test_max_dummy_args(rules, frontend, max_num_arguments, passes): '''Test for different maximum allowed number of dummy arguments and content of messages generated by MaxDummyArgsRule.''' fcode = """ subroutine routine_max_dummy_args(a, b, c, d, e, f, g, h) integer, intent(in) :: a, b, c, d, e, f, g, h print *, a, b, c, d, e, f, g, h end subroutine routine_max_dummy_args """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) config = {'MaxDummyArgsRule': {'max_num_arguments': max_num_arguments}} _ = run_linter(source, [rules.MaxDummyArgsRule], config=config, handlers=[handler]) assert len(messages) == (0 if passes else 1) keywords = ('MaxDummyArgsRule', '[3.6]', '8', str(max_num_arguments), 'routine_max_dummy_args') assert all(all(keyword in msg for keyword in keywords) for msg in messages) @pytest.mark.parametrize('frontend', available_frontends()) def test_mpl_cdstring(rules, frontend): fcode = """ subroutine routine_okay use mpl_module call mpl_init(cdstring='routine_okay') end subroutine routine_okay subroutine routine_also_okay use MPL_MODULE call MPL_INIT(KPROCS=5, 
CDSTRING='routine_also_okay') end subroutine routine_also_okay subroutine routine_not_okay use mpl_module call mpl_init end subroutine routine_not_okay subroutine routine_also_not_okay use MPL_INIT call MPL_INIT(kprocs=5) end subroutine routine_also_not_okay """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.MplCdstringRule], handlers=[handler]) assert len(messages) == 2 assert all('[3.12]' in msg for msg in messages) assert all('MplCdstringRule' in msg for msg in messages) assert all('"CDSTRING"' in msg for msg in messages) assert all('MPL_INIT' in msg.upper() for msg in messages) assert sum('(l. 13)' in msg for msg in messages) == 1 assert sum('(l. 18)' in msg for msg in messages) == 1 @pytest.mark.parametrize('frontend', available_frontends()) def test_implicit_none(rules, frontend): fcode = """ subroutine routine_okay implicit none integer :: a a = 5 contains subroutine contained_routine_okay integer :: b b = 5 end subroutine contained_routine_okay end subroutine routine_okay module mod_okay implicit none contains subroutine contained_mod_routine_okay integer :: a a = 5 end subroutine contained_mod_routine_okay end module mod_okay subroutine routine_not_okay ! This should report integer :: a a = 5 contains subroutine contained_not_okay_routine_okay implicit none integer :: b b = 5 end subroutine contained_not_okay_routine_okay end subroutine routine_not_okay module mod_not_okay contains subroutine contained_mod_not_okay_routine_okay implicit none integer :: a a = 5 end subroutine contained_mod_not_okay_routine_okay end module mod_not_okay subroutine routine_also_not_okay ! This should report integer :: a a = 5 contains subroutine contained_routine_not_okay ! This should report integer :: b b = 5 end subroutine contained_routine_not_okay end subroutine routine_also_not_okay module mod_also_not_okay contains subroutine contained_mod_routine_not_okay ! 
This should report integer :: a a = 5 contains subroutine contained_contained_routine_not_okay ! This should report integer :: b b = 5 end subroutine contained_contained_routine_not_okay end subroutine contained_mod_routine_not_okay end module mod_also_not_okay """ source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.ImplicitNoneRule], handlers=[handler]) assert len(messages) == 5 assert all('"IMPLICIT NONE"' in msg for msg in messages) assert all('[4.4]' in msg for msg in messages) assert sum('"routine_not_okay"' in msg for msg in messages) == 1 assert sum('"routine_also_not_okay"' in msg for msg in messages) == 1 assert sum('"contained_routine_not_okay"' in msg for msg in messages) == 1 assert sum('"contained_mod_routine_not_okay"' in msg for msg in messages) == 1 assert sum('"contained_contained_routine_not_okay"' in msg for msg in messages) == 1 @pytest.mark.parametrize('frontend', available_frontends()) def test_explicit_kind(rules, frontend): fcode = """ subroutine routine_okay use some_type_module, only : jpim, jprb integer(kind=jpim) :: i, j real(kind=jprb) :: a(3), b i = 1_JPIM + 7_JPIM j = 2_JPIM a(1:3) = 3._JPRB b = 4.0_JPRB do j=1,3 a(j) = real(j) end do end subroutine routine_okay subroutine routine_not_okay integer :: i integer(kind=1) :: j real :: a(3) real(kind=8) :: b i = 1 + 7 j = 2 a(1:3) = 3e0 b = 4.0 + 5d0 + 6._4 end subroutine routine_not_okay """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) # Need to include INTEGER constants in config as (temporarily) removed from defaults config = {'ExplicitKindRule': {'constant_types': ['REAL', 'INTEGER']}} _ = run_linter(source, [rules.ExplicitKindRule], config=config, handlers=[handler]) # Note: This creates one message too many, namely the literal '4' in the constant # 6._4. 
This is because we represent the kind parameter as an expression (which can be # an imported name, for example). Since '4' (or any other literals) are not allowed kind # values in IFS this should not be a problem in practice: it will simply create an # additional spurious error in that case assert len(messages) == 12 assert all('[4.7]' in msg for msg in messages) assert all('ExplicitKindRule' in msg for msg in messages) # Keywords to search for in the messages as tuples: # ('var name' or 'literal', 'line number', 'invalid kind value' or None) keywords = ( # Declarations ('i', '16', None), ('j', '17', '1'), ('a(3)', '18', None), ('b', '19', '8'), # Literals ('1', '21', None), ('7', '21', None), ('2', '22', None), ('3e0', '23', None), ('4.0', '24', None), ('5d0', '24', None), ('4', '24', None), ('6._4', '24', '4') ) for keys, msg in zip(keywords, messages): assert all(kw in msg for kw in keys if kw is not None) @pytest.mark.parametrize('frontend', available_frontends()) def test_banned_statements_default(rules, frontend): '''Test for banned statements with default.''' fcode = """ subroutine banned_statements() integer :: dummy dummy = 5 call foobar(dummy) go to 100 print *, dummy 100 continue end subroutine banned_statements """ source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.BannedStatementsRule], handlers=[handler]) assert len(messages) == 3 keywords = ('BannedStatementsRule', '[4.11]') assert all(all(keyword in msg for keyword in keywords) for msg in messages) banned_statements = ('GO TO', 'PRINT', 'CONTINUE') assert all(any(keyword in msg for keyword in banned_statements) for msg in messages) @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('banned_statements, passes', [ ([], True), (['GO TO'], False), (['GO TO', 'RETURN'], False), (['RETURN'], True)]) def test_banned_statements_config(rules, frontend, banned_statements, 
passes): '''Test for banned statements with custom config.''' fcode = """ subroutine banned_statements() integer :: dummy dummy = 5 call foobar(dummy) go to 100 print *, dummy 100 continue end subroutine banned_statements """ source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) config = {'BannedStatementsRule': {'banned': banned_statements}} _ = run_linter(source, [rules.BannedStatementsRule], config=config, handlers=[handler]) assert len(messages) == (0 if passes else 1) keywords = ('BannedStatementsRule', 'GO TO', '[4.11]') assert all(all(keyword in msg for keyword in keywords) for msg in messages) @pytest.mark.parametrize('frontend', available_frontends()) def test_fortran_90_operators(rules, frontend): '''Test for existence of non Fortran 90 comparison operators.''' fcode = """ subroutine test_routine(ia, ib, ic) integer, intent(in) :: ia, ib, ic ! This should produce 6 problems (one for each operator) do while (ia .ge. 3 .or. ia .le. -7) if (ib .gt. 5 .or. ib .lt. -1) then if (ic .eq. 4 .and. ib .ne. -2) then print *, 'Foo' end if end if end do ! This should produce no problems do while (ia >= 3 .or. ia <= -7) if (ib > 5 .or. ib < -1) then if (ic == 4 .and. ib /= -2) then print *, 'Foo' end if end if end do ! This should report 5 problems do while (ia >= 3 .or. & ! This <= should not cause confusion ia .le. -7) if (ib .gt. 5 .or. ib <= -1) then if (ic .gt. 4 .and. ib == -2) then print *, 'Foo' end if elseif (ib .eq. 5) then print *, 'Bar' else if (ic .gt. 
2) print *, 'Baz' end if end do end subroutine test_routine """.strip() source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(source, [rules.Fortran90OperatorsRule], handlers=[handler]) assert len(messages) == 11 keywords = ('Fortran90OperatorsRule', '[4.15]', 'Use Fortran 90 comparison operator') assert all(all(keyword in msg for keyword in keywords) for msg in messages) # Check that violations are reported in the right order f77_f90_line = ( ('.le.', '<=', '5'), ('.ge.', '>=', '5'), ('.lt.', '<', '6'), ('.gt.', '>', '6'), ('.ne.', '/=', '7'), ('.eq.', '==', '7'), ('.le.', '<=', '23'), ('.gt.', '>', '25'), ('.gt.', '>', '26'), ('.eq.', '==', '29'), ('.gt.', '>', '32'), ) for keywords, message in zip(f77_f90_line, messages): assert all(str(keyword) in message for keyword in keywords) loki-ecmwf-0.3.6/lint_rules/tests/test_ifs_arpege_coding_standards.py0000664000175000017500000002176215167130205026376 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
# Tests for lint_rules.ifs_arpege_coding_standards. Each test lints an inline Fortran
# snippet with a single rule and matches the collected messages, in order (via zip),
# against expected keywords (rule id, rule name, offending names, '(l. N)' line tags):
#  - test_implicit_none: MissingImplicitNoneRule on nested modules/subroutines
#  - test_only_param_global_var_rule: OnlyParameterGlobalVarRule on module variables
#  - test_missing_intfb_rule_subroutine / test_missing_intfb_rule_module:
#    MissingIntfbRule; these two are not frontend-parametrised and call
#    Sourcefile.from_source without an explicit frontend.
# NOTE(review): the '(l. N)' expectations depend on the exact line layout of the
# embedded Fortran strings.
# The conftest.py helpers that follow: available_frontends() returns [FP] only when
# HAVE_FP is set (otherwise []), and run_linter() checks the sourcefile with a Linter
# built from the given rules/config and calls linter.fix() when config['fix'] is truthy.
import importlib import pytest from conftest import run_linter, available_frontends from loki import Sourcefile from loki.lint import DefaultHandler pytestmark = pytest.mark.skipif(not available_frontends(), reason='Suitable frontend not available') @pytest.fixture(scope='module', name='rules') def fixture_rules(): rules = importlib.import_module('lint_rules.ifs_arpege_coding_standards') return rules @pytest.mark.parametrize('frontend', available_frontends()) def test_implicit_none(rules, frontend): fcode = """ subroutine routine_okay implicit none integer :: a a = 5 contains subroutine contained_routine_not_okay ! This should report integer :: b b = 5 end subroutine contained_routine_not_okay end subroutine routine_okay module mod_okay implicit none contains subroutine contained_mod_routine_okay integer :: a a = 5 contains subroutine contained_mod_routine_contained_routine_okay integer :: b b = 2 end subroutine contained_mod_routine_contained_routine_okay end subroutine contained_mod_routine_okay end module mod_okay subroutine routine_not_okay ! This should report integer :: a a = 5 contains subroutine contained_not_okay_routine_okay implicit none integer :: b b = 5 end subroutine contained_not_okay_routine_okay end subroutine routine_not_okay module mod_not_okay contains subroutine contained_mod_not_okay_routine_okay implicit none integer :: a a = 5 end subroutine contained_mod_not_okay_routine_okay end module mod_not_okay subroutine routine_also_not_okay ! This should report integer :: a a = 5 contains subroutine contained_routine_not_okay ! This should report integer :: b b = 5 end subroutine contained_routine_not_okay end subroutine routine_also_not_okay module mod_also_not_okay contains subroutine contained_mod_routine_not_okay ! This should report integer :: a a = 5 contains subroutine contained_contained_routine_not_okay !
This should report integer :: b b = 5 end subroutine contained_contained_routine_not_okay end subroutine contained_mod_routine_not_okay end module mod_also_not_okay """ source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) run_linter(source, [rules.MissingImplicitNoneRule], handlers=[handler]) expected_messages = ( (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'mod_not_okay', '(l. 40)']), (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'mod_also_not_okay', '(l. 61)']), (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'contained_mod_routine_not_okay', '(l. 63)']), (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'contained_contained_routine_not_okay', '(l. 68)']), (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'contained_routine_not_okay', '(l. 7)']), (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'routine_not_okay', '(l. 28)']), (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'routine_also_not_okay', '(l. 49)']), (['[L1]', 'MissingImplicitNoneRule', '`IMPLICIT NONE`', 'contained_routine_not_okay', '(l. 54)']), ) assert len(messages) == len(expected_messages) for msg, keywords in zip(messages, expected_messages): for keyword in keywords: assert keyword in msg @pytest.mark.parametrize('frontend', available_frontends()) def test_only_param_global_var_rule(rules, frontend): fcode = """ module some_mod use other_mod, only: some_type implicit none integer, parameter :: param_ok = 123 integer, parameter :: arr_param_ok(:) = (/ 1, 2, 3 /) integer :: var_not_ok integer, allocatable :: arr_not_ok(:), other_arr_not_ok(:,:) integer, pointer :: ptr_not_ok real, parameter :: rparam_ok = -42.
type(some_type) :: dt_var_not_ok type(some_type) :: dt_arr_not_ok(2) end module some_mod """ source = Sourcefile.from_source(fcode, frontend=frontend) messages = [] handler = DefaultHandler(target=messages.append) run_linter(source, [rules.OnlyParameterGlobalVarRule], handlers=[handler]) expected_messages = ( (['L3', 'OnlyParameterGlobalVarRule', 'var_not_ok', '(l. 8)']), (['L3', 'OnlyParameterGlobalVarRule', 'arr_not_ok', 'other_arr_not_ok', '(l. 9)']), (['L3', 'OnlyParameterGlobalVarRule', 'ptr_not_ok', '(l. 10)']), (['L3', 'OnlyParameterGlobalVarRule', 'dt_var_not_ok', '(l. 12)']), (['L3', 'OnlyParameterGlobalVarRule', 'dt_arr_not_ok', '(l. 13)']), ) assert len(messages) == len(expected_messages) for msg, keywords in zip(messages, expected_messages): for keyword in keywords: assert keyword in msg def test_missing_intfb_rule_subroutine(rules): fcode = """ subroutine missing_intfb_rule(a, b, dt) use some_mod, only: imported_routine use other_mod, only: imported_func use type_mod, only: imported_type implicit none integer, intent(in) :: a, b type(imported_type), intent(in) :: dt #include "included_routine.intfb.h" integer :: local_var interface subroutine local_intf_routine(X) integer, intent(in) :: x end subroutine local_intf_routine end interface #include "included_func.intfb.h" #include "other_inc_func.func.h" CALL IMPORTED_ROUTINE(A) CALL INCLUDED_ROUTINE(B) CALL MISSING_ROUTINE(A, B) CALL LOCAL_INTF_ROUTINE(A) CALL DT%PROC(A+B) LOCAL_VAR = IMPORTED_FUNC(A) LOCAL_VAR = LOCAL_VAR + MIN(INCLUDED_FUNC(B), 1) LOCAL_VAR = LOCAL_VAR + MISSING_FUNC(A, B) LOCAL_VAR = LOCAL_VAR + DT%FUNC(A+B) LOCAL_VAR = LOCAL_VAR + OTHER_INC_FUNC(A, 'STR VAL') LOCAL_VAR = LOCAL_VAR + MISSING_INC_FUNC(A, 'STR VAL') end subroutine missing_intfb_rule """.strip() source = Sourcefile.from_source(fcode) messages = [] handler = DefaultHandler(target=messages.append) run_linter(source, [rules.MissingIntfbRule], handlers=[handler]) expected_messages = ( (['[L9]', 'MissingIntfbRule',
'`missing_routine`', '(l. 20)']), # (['[L9]', 'MissingIntfbRule', 'MISSING_FUNC', '(l. 25)']), (['[L9]', 'MissingIntfbRule', '`missing_inc_func`', '(l. 28)']) # NB: # - The `missing_func` is not discovered because it is syntactically # indistinguishable from an Array subscript # - The `missing_inc_func` has a string argument and can therefore be # identified as an inline call by fparser # - Calls to type-bound procedures are not reported ) assert len(messages) == len(expected_messages) for msg, keywords in zip(messages, expected_messages): for keyword in keywords: assert keyword in msg def test_missing_intfb_rule_module(rules): fcode = """ module missing_intfb_rule_mod use external_mod, only: module_imported_routine, module_imported_func implicit none interface function local_intf_func() integer local_intf_func end function local_intf_func end interface #include "included_parent.intfb.h" contains subroutine missing_intfb_rule(a, b) use some_mod, only: imported_routine use other_mod, only: imported_func implicit none integer, intent(in) :: a, b #include "included_routine.intfb.h" integer :: local_var #include "included_func.intfb.h" CALL IMPORTED_ROUTINE(A) CALL INCLUDED_ROUTINE(B) CALL MODULE_IMPORTED_ROUTINE(A, B) CALL MISSING_ROUTINE(A, B) CALL INCLUDED_PARENT(A) call missing_routine(a, b) LOCAL_VAR = IMPORTED_FUNC(A) LOCAL_VAR = LOCAL_VAR + INCLUDED_FUNC(B) LOCAL_VAR = LOCAL_VAR + MISSING_FUNC(A, KEY=B) LOCAL_VAR = LOCAL_VAR + MAX(MODULE_IMPORTED_FUNC(KEY=A), -1) LOCAL_VAR = LOCAL_VAR + LOCAL_INTF_FUNC() end subroutine missing_intfb_rule end module missing_intfb_rule_mod """.strip() source = Sourcefile.from_source(fcode) messages = [] handler = DefaultHandler(target=messages.append) run_linter(source, [rules.MissingIntfbRule], handlers=[handler]) expected_messages = ( (['[L9]', 'MissingIntfbRule', '`missing_routine`', '(l. 24)']), (['[L9]', 'MissingIntfbRule', '`missing_func`', '(l.
29)']), # NB: # - The missing function is discovered here because # the use of a keyword-argument makes it syntactically # clear to be an inline call rather than an Array subscript # - MISSING_ROUTINE is only reported once, for its first occurrence # - We are not reporting the intrinsic Fortran routine MAX ) assert len(messages) == len(expected_messages) for msg, keywords in zip(messages, expected_messages): for keyword in keywords: assert keyword in msg loki-ecmwf-0.3.6/lint_rules/tests/conftest.py0000664000175000017500000000211215167130205021456 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. from loki import HAVE_FP, FP from loki.lint import Reporter, Linter __all__ = ['available_frontends', 'run_linter'] def available_frontends(): """Choose frontend to use (Linter currently relies exclusively on Fparser)""" if HAVE_FP: return [FP,] return [] def run_linter(sourcefile, rule_list, config=None, handlers=None, targets=None): """ Run the linter for the given source file with the specified list of rules. """ reporter = Reporter(handlers) linter = Linter(reporter, rules=rule_list, config=config) report = linter.check(sourcefile, targets=targets) if config: if config.get('fix', None): linter.fix(sourcefile, report) return linter loki-ecmwf-0.3.6/lint_rules/tests/test_debug_rules.py0000664000175000017500000002130015167130205023170 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction.
# Tests for lint_rules.debug_rules:
#  - test_arg_size_array_slices / test_arg_size_array_sequence: ArgSizeMismatchRule on
#    a driver whose call to `kernel` is enriched via driver.enrich([kernel]), restricted
#    to targets=['kernel']. Mismatching actual arguments are expected in call order,
#    each message containing 'arg: X' and 'dummy_arg: X_d'.
#  - test_dynamic_ubound_checks: DynamicUboundCheckRule run with config={'fix': True};
#    checks the three reported arguments, that exactly one UBOUND call survives the fix,
#    the new explicit shapes (klon, klev, nblk) and that declarations were split rather
#    than duplicated.
# NOTE(review): test_dynamic_ubound_checks removes kernel.path (presumably written by
# the fix step) only as its very last statement, so any failing assertion leaves
# 'dynamic_ubound_test.F90' behind next to this test module; pytest's tmp_path
# fixture would avoid that.
import os import importlib from pathlib import Path import pytest from conftest import run_linter, available_frontends from loki import Sourcefile, FindInlineCalls, FindNodes, VariableDeclaration from loki.lint import DefaultHandler pytestmark = pytest.mark.skipif(not available_frontends(), reason='Supported frontend not available') @pytest.fixture(scope='module', name='rules') def fixture_rules(): rules = importlib.import_module('lint_rules.debug_rules') return rules @pytest.mark.parametrize('frontend', available_frontends()) def test_arg_size_array_slices(rules, frontend): """ Test for argument size mismatch when arguments are passed as array slices. """ fcode_driver = """ subroutine driver(klon, klev, nblk, var0, var1, var2, var3, var4, var5, & var6, var7) use yomhook, only : lhook, dr_hook, jphook implicit none integer, intent(in) :: klon, klev, nblk real, intent(in) :: var2(:,:), var4(:,:), var5(:,:), var3(klon, 137), var5(klon, 138) real, intent(in) :: var6(:,:), var7(:,:) real, intent(inout) :: var0(klon, nblk), var1(klon, 138, nblk) real(kind=jphook) :: zhook_handle integer :: klev, ibl, iproma, iend if(lhook) call dr_hook('driver', 0, zhook_handle) associate(nlev => klev) nlev = 137 do ibl = 1, nblk iproma = klon iend = iproma call kernel(klon, nlev, var0(:,ibl), var1(:,:,ibl), var2(1:iend, 1:nlev), & var3, var4(1:klon, 1:nlev+1), var5(:, 1:nlev+1), & var6_d=var6, var7_d=var7(:,1:nlev)) enddo end associate if(lhook) call dr_hook('driver', 1, zhook_handle) end subroutine driver """.strip() fcode_kernel = """ subroutine kernel(klon, klev, var0_d, var1_d, var2_d, var3_d, var4_d, var5_d, var6_d, var7_d) use yomhook, only : lhook, dr_hook, jphook implicit none integer, intent(in) :: klon, klev real, dimension(klon, klev), intent(inout) :: var0_d, var1_d real,
dimension(klon, klev), intent(in) :: var2_d, var3_d, var4_d real, dimension(klon, klev+1), intent(in) :: var5_d real, intent(in) :: var6_d(klon, klev), var7_d(klon, klev) real(kind=jphook) :: zhook_handle if(lhook) call dr_hook('kernel', 0, zhook_handle) if(lhook) call dr_hook('kernel', 1, zhook_handle) end subroutine kernel """.strip() driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend) kernel_source = Sourcefile.from_source(fcode_kernel, frontend=frontend) driver = driver_source['driver'] kernel = kernel_source['kernel'] driver.enrich([kernel,]) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(driver_source, [rules.ArgSizeMismatchRule], config={'ArgSizeMismatchRule': {'max_indirections': 3}}, handlers=[handler], targets=['kernel',]) assert len(messages) == 3 keyword = 'ArgSizeMismatchRule' assert all(keyword in msg for msg in messages) args = ('var0', 'var1', 'var4') for msg, ref_arg in zip(messages, args): assert f'arg: {ref_arg}' in msg assert f'dummy_arg: {ref_arg}_d' in msg @pytest.mark.parametrize('frontend', available_frontends()) def test_arg_size_array_sequence(rules, frontend): """ Test for argument size mismatch when arguments are passed as array sequences. """ fcode_driver = """ subroutine driver(klon, klev, nblk, var0, var1, var2, var3) use yomhook, only : lhook, dr_hook, jphook implicit none integer, intent(in) :: klon, klev, nblk real, intent(inout) :: var0(klon, nblk), var1(klon, 138, nblk) real, intent(in) :: var2(klon, 137), var3(klon*137) real(kind=jphook) :: zhook_handle real, dimension(klon, 137) :: var4, var5 real :: var6 integer :: klev, ibl, iproma, iend if(lhook) call dr_hook('driver', 0, zhook_handle) klev = 137 do ibl = 1, nblk iproma = klon iend = iproma call kernel(klon, klev, var0(1,ibl), var1(1,1,ibl), var2(1:iend, 1), var3(1), & var4(1, 1), var5, var6, 1, .true.)
enddo if(lhook) call dr_hook('driver', 1, zhook_handle) end subroutine driver """.strip() fcode_kernel = """ subroutine kernel(klon, klev, var0_d, var1_d, var2_d, var3_d, var4_d, var5_d, var6_d, & int_arg, log_arg) use yomhook, only : lhook, dr_hook, jphook implicit none integer, intent(in) :: klon, klev real, dimension(klon, klev), intent(inout) :: var0_d, var1_d real, dimension(klon, klev), intent(in) :: var2_d, var3_d real, intent(out) :: var4_d, var5_d, var6_d(klon, klev) integer, intent(out) :: int_arg logical, intent(out) :: log_arg real(kind=jphook) :: zhook_handle if(lhook) call dr_hook('kernel', 0, zhook_handle) if(lhook) call dr_hook('kernel', 1, zhook_handle) end subroutine kernel """.strip() driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend) kernel_source = Sourcefile.from_source(fcode_kernel, frontend=frontend) driver = driver_source['driver'] kernel = kernel_source['kernel'] driver.enrich([kernel,]) messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(driver_source, [rules.ArgSizeMismatchRule], handlers=[handler], targets=['kernel',]) assert len(messages) == 4 keyword = 'ArgSizeMismatchRule' assert all(keyword in msg for msg in messages) args = ('var0', 'var1', 'var5', 'var6') for msg, ref_arg in zip(messages, args): assert f'arg: {ref_arg}' in msg assert f'dummy_arg: {ref_arg}_d' in msg @pytest.mark.parametrize('frontend', available_frontends()) def test_dynamic_ubound_checks(rules, frontend): """ Test the run-time UBOUND checking linter rule """ fcode = """ subroutine kernel(klon, klev, nblk, var0, var1, var2, var3, var4) use abort_mod implicit none integer, intent(in) :: klon, klev, nblk real, dimension(:,:,:), intent(inout) :: var0, var1 real, dimension(:,:,:), intent(inout) :: var2 real, intent(inout) :: var3(:,:), var4(:,:,:) if(ubound(var0, 1) < klon)then call abort('kernel: first dimension of var0 too short') endif if(ubound(VAR0, 2) < klev)then call abort('kernel: second dimension of var0 too
short') endif if(nblk > UBoUND(vAr0, 3))then call abort('kernel: third dimension of var0 too short') endif if(nblk > UBOUND(var1, 3))then call abort('kernel: third dimension of var1 too short') endif if(ubound(var2, 1) < klon .and. ubound(var2, 2) < klev .and. ubound(var2, 3) < nblk)then call abort('kernel: dimensions of var2 too short') endif if(ubound(var4, 1) < klon .and. ubound(var4, 2) < klev .and. ubound(var4, 3) < nblk)then call abort('kernel: dimensions of var4 too short') endif call some_other_kernel(klon, klen, nblk, var0, var1, var2, var3, var4) end subroutine kernel """.strip() kernel = Sourcefile.from_source(fcode, frontend=frontend) kernel.path = Path(__file__).parent / 'dynamic_ubound_test.F90' messages = [] handler = DefaultHandler(target=messages.append) _ = run_linter(kernel, [rules.DynamicUboundCheckRule], config={'fix': True}, handlers=[handler]) # check rule violations assert len(messages) == 3 assert all('DynamicUboundCheckRule' in msg for msg in messages) assert 'var0' in messages[0] assert 'var2' in messages[1] assert 'var4' in messages[2] # check fixed subroutine routine = kernel['kernel'] icalls = [call for call in FindInlineCalls(unique=False).visit(routine.body) if call.function == 'ubound'] assert len(icalls) == 1 shape = ('klon', 'klev', 'nblk') assert all(s.name == d for s, d in zip(routine.variable_map['var0'].shape, shape)) assert all(s.name == d for s, d in zip(routine.variable_map['var2'].shape, shape)) assert all(s.name == d for s, d in zip(routine.variable_map['var4'].shape, shape)) arg_names = ['klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3', 'var4'] assert [arg.name.lower() for arg in routine.arguments] == arg_names # check that variable declarations have not been duplicated declarations = FindNodes(VariableDeclaration).visit(routine.spec) symbols = [s.name.lower() for decl in declarations for s in decl.symbols] assert len(symbols) == 8 assert set(symbols) == {'klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3',
'var4'} # check number of declarations and symbols per declaration assert len(declarations) == 5 assert len(declarations[0].symbols) == 3 for decl in declarations[1:4]: assert len(decl.symbols) == 1 assert len(declarations[4].symbols) == 2 os.remove(kernel.path) loki-ecmwf-0.3.6/lint_rules/pyproject.toml0000664000175000017500000000115215167130205021034 0ustar alastairalastair[build-system] requires = [ "setuptools >= 61", "setuptools_scm[toml] >= 6.2", ] build-backend = "setuptools.build_meta" [project] name = "lint_rules" authors = [ {name = "ECMWF", email = "user_support_section@ecmwf.int"}, ] description = "Linter rule implementations for loki-lint" requires-python = ">=3.8" license = {text = "Apache-2.0"} dynamic = ["version"] dependencies = ["loki"] [tool.setuptools] license-files = ["LICENSE", "AUTHORS.md"] packages = ["lint_rules"] # Enable SCM versioning [tool.setuptools_scm] root = ".." relative_to = "__file__" [tool.pytest.ini_options] testpaths = [ "tests" ] loki-ecmwf-0.3.6/lint_rules/lint_rules/0000775000175000017500000000000015167130205020301 5ustar alastairalastairloki-ecmwf-0.3.6/lint_rules/lint_rules/__init__.py0000664000175000017500000000105115167130205022407 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. from importlib.metadata import version, PackageNotFoundError try: __version__ = version("lint_rules") except PackageNotFoundError: # package is not installed pass loki-ecmwf-0.3.6/lint_rules/lint_rules/debug_rules.py0000664000175000017500000003176215167130205023164 0ustar alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. import operator as _op from loki import ( FindNodes, CallStatement, Assignment, Scalar, RangeIndex, do_resolve_associates, simplify, Sum, Product, IntLiteral, as_tuple, SubstituteExpressions, Array, symbolic_op, StringLiteral, is_constant, LogicLiteral, VariableDeclaration, flatten, FindInlineCalls, Conditional, FindExpressions, Comparison ) from loki.lint import GenericRule, RuleType class ArgSizeMismatchRule(GenericRule): """ Rule to check for argument size mismatch in subroutine/function calls """ type = RuleType.WARN config = { 'max_indirections': 2, } @staticmethod def range_to_sum(lower, upper): """ Method to convert lower and upper bounds of a :any:`RangeIndex` to a :any:`Sum` expression. """ return Sum((IntLiteral(1), upper, Product((IntLiteral(-1), lower)))) @staticmethod def compare_sizes(arg_size, alt_arg_size, dummy_arg_size): """ Compare all possible argument size candidates with dummy arg size. """ for i in range(len(arg_size) + 1): dims = tuple(alt_arg_size[:i]) dims += tuple(arg_size[i:]) dims = Product(dims) if symbolic_op(dims, _op.eq, dummy_arg_size): return True return False @classmethod def get_explicit_arg_size(cls, arg, dims): """ Method to return the size of a subroutine argument whose bounds are explicitly declared. 
""" if isinstance(arg, Scalar): size = as_tuple(IntLiteral(1)) else: size = () for dim in dims: if isinstance(dim, RangeIndex): size += as_tuple(simplify(cls.range_to_sum(dim.lower, dim.upper))) else: size += as_tuple(dim) return size @classmethod def get_implicit_arg_size(cls, arg, dims): """ Method to return the size of a subroutine argument whose bounds are potentially implicitly declared. """ size = () for count, dim in enumerate(dims): if isinstance(dim, RangeIndex): if not dim.upper: if isinstance(arg.shape[count], RangeIndex): upper = arg.shape[count].upper else: upper = arg.shape[count] else: upper = dim.upper if not dim.lower: if isinstance(arg.shape[count], RangeIndex): lower = arg.shape[count].lower else: lower = IntLiteral(1) else: lower = dim.lower size += as_tuple(cls.range_to_sum(lower, upper)) else: size += as_tuple(dim) return size @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): """ Method to check for argument size mismatches across subroutine calls. It requires all :any:`CallStatement` nodes to be enriched, and requires all subroutine arguments *to not be* of type :any:`DeferredTypeSymbol`. Therefore relevant modules should be parsed before parsing the current :any:`Subroutine`. 
""" max_indirections = config['max_indirections'] # first resolve associates do_resolve_associates(subroutine) assign_map = {a.lhs: a.rhs for a in FindNodes(Assignment).visit(subroutine.body)} decl_symbols = flatten([decl.symbols for decl in FindNodes(VariableDeclaration).visit(subroutine.spec)]) decl_symbols = [sym for sym in decl_symbols if sym.type.initial] assign_map.update({sym: sym.initial for sym in decl_symbols}) targets = as_tuple(kwargs.get('targets', None)) calls = [c for c in FindNodes(CallStatement).visit(subroutine.body) if c.name in targets] for call in calls: # check if calls are enriched if not call.routine: continue arg_map = {carg: rarg for rarg, carg in call.arg_iter()} for arg in arg_map: if isinstance(arg_map[arg], Scalar): dummy_arg_size = as_tuple(IntLiteral(1)) else: # we can't proceed if dummy arg has assumed shape component if any(None in (dim.lower, dim.upper) for dim in arg_map[arg].shape if isinstance(dim, RangeIndex)): continue dummy_arg_size = cls.get_explicit_arg_size(arg_map[arg], arg_map[arg].shape) dummy_arg_size = SubstituteExpressions(dict(call.arg_iter())).visit(dummy_arg_size) # TODO: skip string literal args if isinstance(arg, StringLiteral): continue arg_size = () alt_arg_size = () # check if argument is scalar if isinstance(arg, (Scalar, LogicLiteral)) or is_constant(arg): arg_size += as_tuple(IntLiteral(1)) alt_arg_size += as_tuple(IntLiteral(1)) else: # check if arg has assumed size component if any(None in (dim.lower, dim.upper) for dim in arg.shape if isinstance(dim, RangeIndex)): # each dim must have explicit range-index to be sure of arg size if not arg.dimensions: continue if not all(isinstance(dim, RangeIndex) for dim in arg.dimensions): continue if any(None in (dim.lower, dim.upper) for dim in arg.dimensions): continue arg_size = cls.get_explicit_arg_size(arg, arg.dimensions) alt_arg_size = arg_size else: # compute dim sizes assuming single element if arg.dimensions: arg_size = cls.get_implicit_arg_size(arg, 
arg.dimensions) arg_size = as_tuple([IntLiteral(1) if not isinstance(a, Sum) else simplify(a) for a in arg_size]) else: arg_size = cls.get_explicit_arg_size(arg, arg.shape) # compute dim sizes assuming array sequence alt_arg_size = cls.get_implicit_arg_size(arg, arg.dimensions) ubounds = [dim.upper if isinstance(dim, RangeIndex) else dim for dim in arg.shape] alt_arg_size = as_tuple([simplify(Sum((Product((IntLiteral(-1), a)), ubounds[i], IntLiteral(1)))) if not isinstance(a, Sum) else simplify(a) for i, a in enumerate(alt_arg_size)]) alt_arg_size += cls.get_explicit_arg_size(arg, arg.shape[len(arg.dimensions):]) # first check using unmodified dimension names dummy_size = Product(dummy_arg_size) stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size) # we check for a configurable number of indirections for the dummy and arg dimension names for _ in range(max_indirections): if stat: break # if necessary, update dummy arg dimension names and check dummy_arg_size = SubstituteExpressions(assign_map).visit(dummy_arg_size) dummy_size = Product(dummy_arg_size) stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size) if stat: break # if necessary, update arg dimension names and check arg_size = SubstituteExpressions(assign_map).visit(arg_size) alt_arg_size = SubstituteExpressions(assign_map).visit(alt_arg_size) stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size) if not stat: msg = f'Size mismatch:: arg: {arg}, dummy_arg: {arg_map[arg]} ' msg += f'in {call} in {subroutine}' rule_report.add(msg, call) class DynamicUboundCheckRule(GenericRule): """ Rule to check for run-time ubound checks for assumed shape dummy arguments """ type = RuleType.WARN fixable = True @staticmethod def is_assumed_shape(arg): """ Method to check if argument is an assumed shape array. 
""" if all(isinstance(dim, RangeIndex) for dim in arg.shape): return all(dim.upper is None and dim.lower is None for dim in arg.shape) return False @staticmethod def get_ubound_checks(subroutine): """ Method to return UBOUND checks nested within a :any:`Conditional`. """ cond_map = {cond: FindInlineCalls(unique=False).visit(cond.condition) for cond in FindNodes(Conditional).visit(subroutine.body)} return {call: cond for cond, calls in cond_map.items() for call in calls} @classmethod def get_assumed_shape_args(cls, subroutine): """ Method to return all assumed-shape dummy arguments in a :any:`Subroutine`. """ args = [arg for arg in subroutine.arguments if isinstance(arg, Array)] return [arg for arg in args if cls.is_assumed_shape(arg)] @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): """ Method to check for run-time ubound checks for assumed shape dummy arguments """ ubound_checks = cls.get_ubound_checks(subroutine) args = cls.get_assumed_shape_args(subroutine) for arg in args: checks = [c for c in ubound_checks if arg.name in c.arguments] params = flatten([p for c in checks for p in c.arguments if not p == arg]) if all(IntLiteral(d+1) in params for d in range(len(arg.shape))): msg = f'Run-time UBOUND checks for assumed-shape arg: {arg}' rule_report.add(msg, subroutine) @classmethod def fix_subroutine(cls, subroutine, rule_report, config): """ Method to fix run-time ubound checks for assumed shape dummy arguments """ ubound_checks = cls.get_ubound_checks(subroutine) args = cls.get_assumed_shape_args(subroutine) node_map = {} var_map = {} for arg in args: checks = [c for c in ubound_checks if arg.name in c.arguments] params = {p: c for c in checks for p in c.arguments if not p == arg} # check if ubounds of all dimensions are tested if all(IntLiteral(d+1) in params for d in range(len(arg.shape))): new_shape = () for d in range(len(arg.shape)): conditional = ubound_checks[params[IntLiteral(d+1)]] node_map[conditional] = None # extract 
comparison expressions in case they are nested in a logical operation conditions = [c for c in FindExpressions().visit(conditional.condition) if isinstance(c, Comparison)] conditions = [c for c in conditions if c.operator in ('<', '>')] cond = [c for c in conditions if arg.name in c and IntLiteral(d+1) in c][0] # build ordered tuple for declaration shape if 'ubound' in FindExpressions().visit(cond.left): new_shape += as_tuple(cond.right) else: new_shape += as_tuple(cond.left) vtype = arg.type.clone(shape=new_shape) var_map.update({arg: arg.clone(type=vtype, dimensions=new_shape)}) # update variable declarations subroutine.spec = SubstituteExpressions(var_map).visit(subroutine.spec) for decl in FindNodes(VariableDeclaration).visit(subroutine.spec): if decl.dimensions: if not all(sym.shape == decl.dimensions for sym in decl.symbols): new_decls = as_tuple(VariableDeclaration(as_tuple(sym)) for sym in decl.symbols) node_map.update({decl: new_decls}) return node_map # Create the __all__ property of the module to contain only the rule names __all__ = tuple(name for name in dir() if name.endswith('Rule') and name != 'GenericRule') loki-ecmwf-0.3.6/lint_rules/lint_rules/ifs_coding_standards_2011.py0000664000175000017500000005445615167130205025503 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # pylint: disable=invalid-all-format """ Implementation of rules in the IFS coding standards document (2011) for loki-lint. 
""" from pathlib import Path import re from pymbolic.primitives import Expression from loki import ( Visitor, FindNodes, ExpressionFinder, ExpressionRetriever, Node, flatten, as_tuple, strip_inline_comments, Module, Subroutine, BasicType, ir ) from loki.lint import GenericRule, RuleType from loki.expression import symbols as sym class CodeBodyRule(GenericRule): # Coding standards 1.3 type = RuleType.WARN docs = { 'id': '1.3', 'title': ('Rules for Code Body: ' 'Nesting of conditional blocks should not be more than {max_nesting_depth} ' 'levels deep;'), } config = { 'max_nesting_depth': 3, } class NestingDepthVisitor(Visitor): @classmethod def default_retval(cls): return [] def __init__(self, max_nesting_depth): super().__init__() self.max_nesting_depth = max_nesting_depth def visit(self, o, *args, **kwargs): return flatten(super().visit(o, *args, **kwargs)) def visit_Conditional(self, o, **kwargs): level = kwargs.pop('level', 0) too_deep = [] if level >= self.max_nesting_depth and not getattr(o, 'inline', False): too_deep = [o] too_deep += self.visit(o.body, level=level + 1, **kwargs) if o.has_elseif: too_deep += self.visit(o.else_body, level=level, **kwargs) else: too_deep += self.visit(o.else_body, level=level + 1, **kwargs) return too_deep def visit_MultiConditional(self, o, **kwargs): level = kwargs.pop('level', 0) too_deep = [] if level >= self.max_nesting_depth and not getattr(o, 'inline', False): too_deep = [o] too_deep += self.visit(o.bodies, level=level + 1, **kwargs) too_deep += self.visit(o.else_body, level=level + 1, **kwargs) return too_deep visit_TypeConditional = visit_MultiConditional @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Check the code body: Nesting of conditional blocks.''' too_deep = cls.NestingDepthVisitor(config['max_nesting_depth']).visit(subroutine.body) msg = f'Nesting of conditionals exceeds limit of {config["max_nesting_depth"]}' for node in too_deep: rule_report.add(msg, node) class 
ModuleNamingRule(GenericRule): # Coding standards 1.5 type = RuleType.WARN docs = { 'id': '1.5', 'title': ('Naming Schemes for Modules: All modules should end with "_mod". ' 'Module filename should match the name of the module it contains.'), } @classmethod def check_module(cls, module, rule_report, config): '''Check the module name and the name of the source file.''' if not module.name.lower().endswith('_mod'): msg = f'Name of module "{module.name}" should end with "_mod"' rule_report.add(msg, module) if module.source.file: path = Path(module.source.file) if module.name.lower() != path.stem.lower(): msg = f'Module filename "{path.name}" does not match module name "{module.name}"' rule_report.add(msg, module) class DrHookRule(GenericRule): # Coding standards 1.9 type = RuleType.SERIOUS docs = { 'id': '1.9', 'title': 'Rules for DR_HOOK', } non_exec_nodes = (ir.Comment, ir.CommentBlock, ir.Pragma, ir.PreprocessorDirective) @classmethod def _find_lhook_conditional(cls, ast, is_reversed=False): cond = None for node in reversed(ast) if is_reversed else ast: if isinstance(node, ir.Conditional): if node.condition == 'LHOOK': cond = node break elif not isinstance(node, cls.non_exec_nodes): # Break if executable statement encountered break return cond @classmethod def _find_lhook_call(cls, cond, is_reversed=False): call = None if cond: # We use as_tuple here because the conditional can be inline and then its body is not # iterable but a single node (e.g., CallStatement) body = reversed(as_tuple(cond.body)) if is_reversed else as_tuple(cond.body) for node in body: if isinstance(node, ir.CallStatement) and node.name == 'DR_HOOK': call = node elif not isinstance(node, cls.non_exec_nodes): # Break if executable statement encountered break return call @staticmethod def _get_string_argument(scope): string_arg = scope.name.upper() while hasattr(scope, 'parent') and scope.parent: scope = scope.parent if isinstance(scope, Subroutine): string_arg = scope.name.upper() + '%' + 
string_arg elif isinstance(scope, Module): string_arg = scope.name.upper() + ':' + string_arg return string_arg @classmethod def _check_lhook_call(cls, call, subroutine, rule_report, pos='First'): if call is None: msg = f'{pos} executable statement must be call to DR_HOOK' rule_report.add(msg, subroutine) elif call.arguments: string_arg = cls._get_string_argument(subroutine) if not isinstance(call.arguments[0], sym.StringLiteral) or \ call.arguments[0].value.upper() != string_arg: msg = f'String argument to DR_HOOK call should be "{string_arg}"' rule_report.add(msg, call) second_arg = {'First': '0', 'Last': '1'} if not (len(call.arguments) > 1 and isinstance(call.arguments[1], sym.IntLiteral) and str(call.arguments[1].value) == second_arg[pos]): msg = f'Second argument to DR_HOOK call should be "{second_arg[pos]}"' rule_report.add(msg, call) if not (len(call.arguments) > 2 and call.arguments[2] == 'ZHOOK_HANDLE'): msg = 'Third argument to DR_HOOK call should be "ZHOOK_HANDLE".' rule_report.add(msg, call) @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Check that first and last executable statements in the subroutine are conditionals with calls to DR_HOOK in their body and that the correct arguments are given to the call.''' # Extract the AST for the subroutine body ast = subroutine.body if isinstance(ast, ir.Section): ast = ast.body ast = flatten(ast) # Look for conditionals in subroutine body first_cond = cls._find_lhook_conditional(ast) last_cond = cls._find_lhook_conditional(ast, is_reversed=True) # Find calls to DR_HOOK first_call = cls._find_lhook_call(first_cond) last_call = cls._find_lhook_call(last_cond, is_reversed=True) cls._check_lhook_call(first_call, subroutine, rule_report) cls._check_lhook_call(last_call, subroutine, rule_report, pos='Last') class LimitSubroutineStatementsRule(GenericRule): # Coding standards 2.2 type = RuleType.WARN docs = { 'id': '2.2', 'title': 'Subroutines should have no more than 
{max_num_statements} executable statements.', } config = { 'max_num_statements': 300 } # List of nodes that are considered executable statements exec_nodes = ( ir.Assignment, ir.MaskedStatement, ir.Intrinsic, ir.Allocation, ir.Deallocation, ir.Nullify, ir.CallStatement ) # Pattern for intrinsic nodes that are allowed as non-executable statements match_non_exec_intrinsic_node = re.compile(r'\s*(?:PRINT|FORMAT)', re.I) @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Count the number of nodes in the subroutine and check if they exceed a given maximum number. ''' # Count total number of executable nodes nodes = FindNodes(cls.exec_nodes).visit(subroutine.ir) num_nodes = len(nodes) # Subtract number of non-exec intrinsic nodes intrinsic_nodes = filter(lambda node: isinstance(node, ir.Intrinsic), nodes) num_nodes -= sum(1 for _ in filter( lambda node: cls.match_non_exec_intrinsic_node.match(node.text), intrinsic_nodes)) if num_nodes > config['max_num_statements']: msg = (f'Subroutine has {num_nodes} executable statements ' f'(should not have more than {config["max_num_statements"]})') rule_report.add(msg, subroutine) class MaxDummyArgsRule(GenericRule): # Coding standards 3.6 type = RuleType.INFO docs = { 'id': '3.6', 'title': 'Routines should have no more than {max_num_arguments} dummy arguments.', } config = { 'max_num_arguments': 50 } @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): """ Count the number of dummy arguments and report if given maximum number exceeded. 
""" num_arguments = len(subroutine.arguments) if num_arguments > config['max_num_arguments']: msg = (f'Subroutine has {num_arguments} dummy arguments ' f'(should not have more than {config["max_num_arguments"]})') rule_report.add(msg, subroutine) class MplCdstringRule(GenericRule): # Coding standards 3.12 type = RuleType.SERIOUS docs = { 'id': '3.12', 'title': 'Calls to MPL subroutines should provide a "CDSTRING" identifying the caller.', } @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Check all calls to MPL subroutines for a CDSTRING.''' for call in FindNodes(ir.CallStatement).visit(subroutine.ir): if str(call.name).upper().startswith('MPL_'): for kw, _ in call.kwarguments: if kw.upper() == 'CDSTRING': break else: msg = f'No "CDSTRING" provided in call to {call.name}' rule_report.add(msg, call) class ImplicitNoneRule(GenericRule): # Coding standards 4.4 type = RuleType.SERIOUS docs = { 'id': '4.4', 'title': '"IMPLICIT NONE" is mandatory in all routines.', } _regex = re.compile(r'implicit\s+none\b', re.I) @staticmethod def check_for_implicit_none(ast): """ Check for intrinsic nodes that match the regex. """ for intr in FindNodes(ir.Intrinsic).visit(ast): if ImplicitNoneRule._regex.match(intr.text): break else: return False return True @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): """ Check for IMPLICIT NONE in the subroutine's spec or any enclosing scope. 
""" found_implicit_none = cls.check_for_implicit_none(subroutine.ir) # Check if enclosing scopes contain implicit none scope = subroutine.parent while scope and not found_implicit_none: if hasattr(scope, 'spec') and scope.spec: found_implicit_none = cls.check_for_implicit_none(scope.spec) scope = scope.parent if hasattr(scope, 'parent') else None if not found_implicit_none: # No 'IMPLICIT NONE' intrinsic node was found rule_report.add('No "IMPLICIT NONE" found', subroutine) class ExplicitKindRule(GenericRule): # Coding standards 4.7 type = RuleType.SERIOUS docs = { 'id': '4.7', 'title': ('Variables and constants must be declared with explicit kind, using the kinds ' 'defined in "PARKIND1" and "PARKIND2".'), } config = { 'declaration_types': ['INTEGER', 'REAL'], 'constant_types': ['REAL'], # Coding standards document includes INTEGERS here 'allowed_type_kinds': { 'INTEGER': ['JPIM', 'JPIT', 'JPIB', 'JPIA', 'JPIS', 'JPIH'], 'REAL': ['JPRB', 'JPRM', 'JPRS', 'JPRT', 'JPRH', 'JPRD', 'JPHOOK'] } } @staticmethod def check_kind_declarations(subroutine, types, allowed_type_kinds, rule_report): '''Helper function that carries out the check for explicit kind specification on all declarations. 
''' for decl in FindNodes(ir.VariableDeclaration).visit(subroutine.spec): decl_type = decl.symbols[0].type if decl_type.dtype in types: if not decl_type.kind: # Declared without any KIND specification msg = f'{", ".join(str(var) for var in decl.symbols)} without explicit KIND declared' rule_report.add(msg, decl) elif allowed_type_kinds.get(decl_type.dtype): if decl_type.kind not in allowed_type_kinds[decl_type.dtype]: # We have a KIND but it does not match any of the allowed kinds msg = (f'{decl_type.kind!s} is not an allowed KIND value for ' f'{", ".join(str(var) for var in decl.symbols)}') rule_report.add(msg, decl) @staticmethod def check_kind_literals(subroutine, types, allowed_type_kinds, rule_report): '''Helper function that carries out the check for explicit kind specification on all literals. ''' class FindLiteralsWithKind(ExpressionFinder): """ Custom expression finder that that yields all literals of the types specified in the config and stops recursion on loop ranges and array subscripts (to avoid warnings about integer constants in these cases) """ retriever = ExpressionRetriever( query=lambda e: isinstance(e, types), recurse_query=lambda e: not isinstance(e, (sym.Array, sym.Range)) ) for node, exprs in FindLiteralsWithKind(unique=False, with_ir_node=True).visit(subroutine.ir): for literal in exprs: if not literal.kind: rule_report.add(f'{literal} used without explicit KIND', node) elif allowed_type_kinds.get(literal.__class__): if str(literal.kind).upper() not in allowed_type_kinds[literal.__class__]: msg = f'{literal.kind} is not an allowed KIND value for {literal}' rule_report.add(msg, node) @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Check for explicit kind information in constants and variable declarations. ''' # 1. Check variable declarations for explicit KIND # # When we check variable type information, we have BasicType values to identify # whether a variable is REAL, INTEGER, ... 
Therefore, we create a map that uses # the corresponding BasicType values as keys to look up allowed kinds for each type. # Since the case does not matter, we convert all allowed type kinds to upper case. types = tuple(BasicType.from_str(name) for name in config['declaration_types']) allowed_type_kinds = {} if config.get('allowed_type_kinds'): allowed_type_kinds = {BasicType.from_str(name): [kind.upper() for kind in kinds] for name, kinds in config['allowed_type_kinds'].items()} cls.check_kind_declarations(subroutine, types, allowed_type_kinds, rule_report) # 2. Check constants for explicit KIND # # Constants are represented by an instance of some Literal class, which directly # gives us their type. Therefore, we create a map that uses the corresponding # Literal types as keys to look up allowed kinds for each type. Again, we # convert all allowed type kinds to upper case. type_map = {'INTEGER': sym.IntLiteral, 'REAL': sym.FloatLiteral, 'LOGICAL': sym.LogicLiteral, 'CHARACTER': sym.StringLiteral} types = tuple(type_map[name] for name in config['constant_types']) if config.get('allowed_type_kinds'): allowed_type_kinds = {type_map[name]: [kind.upper() for kind in kinds] for name, kinds in config['allowed_type_kinds'].items()} cls.check_kind_literals(subroutine, types, allowed_type_kinds, rule_report) class BannedStatementsRule(GenericRule): # Coding standards 4.11 type = RuleType.WARN docs = { 'id': '4.11', 'title': 'Banned statements.', } config = { 'banned': ['STOP', 'PRINT', 'RETURN', 'ENTRY', 'DIMENSION', 'DOUBLE PRECISION', 'COMPLEX', 'GO TO', 'CONTINUE', 'FORMAT', 'COMMON', 'EQUIVALENCE'], } @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Check for banned statements in intrinsic nodes.''' for intr in FindNodes(ir.Intrinsic).visit(subroutine.ir): for keyword in config['banned']: if keyword.lower() in intr.text.lower(): rule_report.add(f'Banned keyword "{keyword}"', intr) class Fortran90OperatorsRule(GenericRule): # Coding 
standards 4.15 type = RuleType.WARN docs = { 'id': '4.15', 'title': 'Use Fortran 90 comparison operators.' } fixable = True ''' Regex patterns for each operator that match F77 and F90 operators as named groups, thus allowing to easily find out which operator was used. ''' _op_patterns = { '==': re.compile(r'(?P\.eq\.)|(?P==)', re.I), '!=': re.compile(r'(?P\.ne\.)|(?P/=)', re.I), '>=': re.compile(r'(?P\.ge\.)|(?P>=)', re.I), '<=': re.compile(r'(?P\.le\.)|(?P<=)', re.I), '>': re.compile(r'(?P\.gt\.)|(?P>(?!=))', re.I), '<': re.compile(r'(?P\.lt\.)|(?P<(?!=))', re.I), } _op_map = { '==': '.eq.', '/=': '.ne.', '>=': '.ge.', '<=': '.le.', '>': '.gt.', '<': '.lt.' } class ComparisonRetriever(Visitor): """ Bespoke expression retriever that extracts 3-tuples containing ``(node, expression root, comparison)`` for all :any:`Comparison` nodes. """ retriever = ExpressionRetriever(lambda e: isinstance(e, sym.Comparison)) def visit_Node(self, o, **kwargs): """ Generic visitor method that will call the :any:`ExpressionRetriever` only on :class:`pymbolic.primitives.Expression` children, collecting ``(node, expression root, comparison)`` tuples for all matches. """ retval = () for ch in flatten(o.children): if isinstance(ch, Expression): comparisons = self.retriever.retrieve(ch) if comparisons: retval += ((o, ch, comparisons),) elif isinstance(ch, Node): retval += self.visit(ch, **kwargs) return retval def visit_tuple(self, o, **kwargs): """ Specialized handling of tuples to concatenate the nested tuples returned by :meth:`visit_Node`. 
""" retval = () for ch in o: if ch is not None: retval += self.visit(ch, **kwargs) return retval visit_list = visit_tuple @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Check for the use of Fortran 90 comparison operators.''' # Use the bespoke visitor to retrieve all comparison nodes alongside with their expression root # and the IR node they belong to for node, expr_root, expr_list in cls.ComparisonRetriever().visit(subroutine.ir): # Use the string representation of the expression to find the source line lstart, lend = node.source.find(str(expr_root)) lines = node.source.clone_lines((lstart, lend)) # For each comparison operator, use the original source code (because the frontends always # translate them to F90 operators) to check if F90 or F77 operators were used for op in sorted({op.operator for op in expr_list}): # find source line for operator op_str = op if op != '!=' else '/=' line = [line for line in lines if op_str in strip_inline_comments(line.string)] if not line: line = [line for line in lines if op_str in strip_inline_comments(line.string.replace(cls._op_map[op_str], op_str))] source_string = strip_inline_comments(line[0].string) matches = cls._op_patterns[op].findall(source_string) for f77, _ in matches: if f77: msg = f'Use Fortran 90 comparison operator "{op_str}" instead of "{f77}"' rule_report.add(msg, node) @classmethod def fix_subroutine(cls, subroutine, rule_report, config): '''Replace by Fortran 90 comparison operators.''' # We only have to invalidate the source string for the expression. 
This will cause the # backend to regenerate the source string for that node and use Fortran 90 operators # automatically mapper = {} for report in rule_report.problem_reports: new_expr = report.location new_expr.update_metadata({'source': None}) mapper[report.location] = new_expr return mapper # Create the __all__ property of the module to contain only the rule names __all__ = tuple(name for name in dir() if name.endswith('Rule') and name != 'GenericRule') loki-ecmwf-0.3.6/lint_rules/lint_rules/ifs_arpege_coding_standards.py0000664000175000017500000001531315167130205026350 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. """ Implementation of rules from the IFS Arpege coding standards as :any:`GenericRule` See https://sites.ecmwf.int/docs/ifs-arpege-coding-standards/fortran for the current version of the coding standards. """ from collections import defaultdict import re try: from fparser.two.Fortran2003 import Intrinsic_Name _intrinsic_fortran_names = Intrinsic_Name.function_names except ImportError: _intrinsic_fortran_names = () from loki import ( FindInlineCalls, FindNodes, GenericRule, Module, RuleType ) from loki import ir __all__ = [ 'MissingImplicitNoneRule', 'OnlyParameterGlobalVarRule', 'MissingIntfbRule', ] class MissingImplicitNoneRule(GenericRule): """ ``IMPLICIT NONE`` must be present in all scoping units but may be omitted in module procedures. """ type = RuleType.SERIOUS docs = { 'id': 'L1', 'title': ( 'IMPLICIT NONE must figure in all scoping units. ' 'Once per module is sufficient.' 
), } _regex = re.compile(r'implicit\s+none\b', re.I) @classmethod def check_for_implicit_none(cls, ir_): """ Check for intrinsic nodes that match the regex. """ for intr in FindNodes(ir.Intrinsic).visit(ir_): if cls._regex.match(intr.text): break else: return False return True @classmethod def check_module(cls, module, rule_report, config): """ Check for ``IMPLICIT NONE`` in the module's spec. """ found_implicit_none = cls.check_for_implicit_none(module.spec) if not found_implicit_none: # No 'IMPLICIT NONE' intrinsic node was found rule_report.add('No `IMPLICIT NONE` found', module) @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): """ Check for ``IMPLICIT NONE`` in the subroutine's spec or an enclosing :any:`Module` scope. """ found_implicit_none = cls.check_for_implicit_none(subroutine.ir) # Check if enclosing scopes contain implicit none scope = subroutine.parent while scope and not found_implicit_none: if isinstance(scope, Module) and hasattr(scope, 'spec') and scope.spec: found_implicit_none = cls.check_for_implicit_none(scope.spec) scope = scope.parent if hasattr(scope, 'parent') else None if not found_implicit_none: # No 'IMPLICIT NONE' intrinsic node was found rule_report.add('No `IMPLICIT NONE` found', subroutine) class OnlyParameterGlobalVarRule(GenericRule): """ Only parameters to be declared as global variables. """ type = RuleType.SERIOUS docs = { 'id': 'L3', 'title': 'Only parameters to be declared as global variables.' } @classmethod def check_module(cls, module, rule_report, config): for decl in module.declarations: if not decl.symbols[0].type.parameter: msg = f'Global variable(s) declared that are not parameters: {", ".join(s.name for s in decl.symbols)}' rule_report.add(msg, decl) class MissingIntfbRule(GenericRule): """ Calls to subroutines and functions that are provided neither by a module nor by a CONTAINS statement, must have a matching explicit interface block. 
""" type = RuleType.SERIOUS docs = { 'id': 'L9', 'title': ( 'Explicit interface blocks required for procedures that are not ' 'imported or internal subprograms' ) } @classmethod def _get_external_symbols(cls, program_unit): """ Collect all imported symbols in :data:`program_unit` and parent scopes and return as a set of lower-case names """ external_symbols = {name.lower() for name in _intrinsic_fortran_names} if program_unit.parent: external_symbols |= cls._get_external_symbols(program_unit.parent) # Get imported symbols external_symbols |= { s.name.lower() for import_ in program_unit.imports for s in import_.symbols or () } # Collect all symbols declared via intfb includees c_includes = [ include for include in FindNodes(ir.Import).visit(program_unit.ir) if include.c_import ] external_symbols |= { include.module[:-8].lower() for include in c_includes if include.module.endswith('.intfb.h') } external_symbols |= { include.module[:-7].lower() for include in c_includes if include.module.endswith('.func.h') } # Add locally declared interface symbols external_symbols |= {s.name.lower() for s in program_unit.interface_symbols} # Add internal subprograms and module procedures for routine in program_unit.routines: external_symbols.add(routine.name.lower()) return external_symbols @staticmethod def _add_report(rule_report, node, call_name): """ Register a missing interface block for a call to :data:`call_name` in the :any:`RuleReport` """ msg = f'Missing import or interface block for called procedure `{call_name}`' rule_report.add(msg, node) @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): """ Check all :any:`CallStatement` and :any:`InlineCall` for a matching import or interface block. 
""" external_symbols = cls._get_external_symbols(subroutine) # Collect all calls to routines without a corresponding symbol missing_calls = defaultdict(list) for call in FindNodes(ir.CallStatement).visit(subroutine.body): if not call.name.parent and str(call.name).lower() not in external_symbols: missing_calls[str(call.name).lower()] += [call] for node, calls in FindInlineCalls(with_ir_node=True).visit(subroutine.body): for call in calls: if not call.function.parent and call.name.lower() not in external_symbols: missing_calls[call.name.lower()] += [node] # Create reports for each missing routine only for the first occurence for name, calls in missing_calls.items(): cls._add_report(rule_report, calls[0], name) loki-ecmwf-0.3.6/lint_rules/AUTHORS.md0000777000175000017500000000000015167130205021445 2../AUTHORS.mdustar alastairalastairloki-ecmwf-0.3.6/lint_rules/LICENSE0000777000175000017500000000000015167130205020341 2../LICENSEustar alastairalastairloki-ecmwf-0.3.6/AUTHORS.md0000664000175000017500000000065015167130205015411 0ustar alastairalastair# Authors and Contributors - A. Beggs (ECMWF) - J. Ericsson (ECMWF) - R. Heilemann Myhre (Met Norway) - S. Karppinen (FMI) - R. Kazeroni (CNRS/IPSL) - P. Kiepas (École polytechnique/IPSL) - M. Lange (ECMWF) - J. Legaux (CERFACS) - O. Marsden (ECMWF) - A. Nawab (ECMWF) - B. Reuter (ECMWF) - J. Schmalfuß - M. Staneker (ECMWF) If you have contributed to this project, please add your name in the above alphabetical list. loki-ecmwf-0.3.6/loki/0000775000175000017500000000000015167130205014677 5ustar alastairalastairloki-ecmwf-0.3.6/loki/__init__.py0000664000175000017500000000721415167130205017014 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
# In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. """ The Loki source-to-source translation package for Fortran codes. """ from importlib.metadata import version, PackageNotFoundError # Import the global configuration map from loki.config import * # noqa from loki.frontend import * # noqa from loki.sourcefile import * # noqa from loki.subroutine import * # noqa from loki.program_unit import * # noqa from loki.module import * # noqa from loki.ir import * # noqa from loki.expression import * # noqa from loki.types import * # noqa from loki.tools import * # noqa from loki.logging import * # noqa from loki.backend import * # noqa from loki.jit_build import * # noqa # pylint: disable=redefined-builtin from loki.batch import * # noqa from loki.lint import * # noqa from loki.analyse import * # noqa from loki.dimension import * # noqa from loki.transformations import * # noqa from loki.function import * # noqa from loki.cli import * # noqa try: __version__ = version("loki") except PackageNotFoundError: # package is not installed pass # Add flag to trigger an initial print out of the global config config.register('print-config', False, env_variable='LOKI_PRINT_CONFIG', preprocess=lambda i: bool(i) if isinstance(i, int) else i) # Define Loki's global config options config.register('log-level', 'INFO', callback=set_log_level, preprocess=lambda i: log_levels[i]) config.register('debug', None, env_variable='LOKI_DEBUG', callback=set_excepthook, preprocess=lambda i: auto_post_mortem_debugger if i else None) # Define Loki's temporary directory for generating intermediate files config.register('tmp-dir', None, env_variable='LOKI_TMP_DIR') # Causes external frontend preprocessor to dump intermediate soruce files config.register('cpp-dump-files', False, env_variable='LOKI_CPP_DUMP_FILES', preprocess=lambda i: bool(i) if isinstance(i, 
int) else i) # Causes OMNI frontend to dump intermediate XML files to LOKI_TMP_DIR config.register('omni-dump-xml', False, env_variable='LOKI_OMNI_DUMP_XML', preprocess=lambda i: bool(i) if isinstance(i, int) else i) # Enable strict frontend behaviour (fail on unknown/unsupported language features) config.register('frontend-strict-mode', False, env_variable='LOKI_FRONTEND_STRICT_MODE', preprocess=lambda i: bool(i) if isinstance(i, int) else i) # Enable frontends to store original source reference with line count config.register('frontend-store-source', True, env_variable='LOKI_FRONTEND_STORE_SOURCE', preprocess=lambda i: bool(i) if isinstance(i, int) else i) # Force symbol comparison and object equality to be case sensitive config.register('case-sensitive', False, env_variable='LOKI_CASE_SENSITIVE', preprocess=lambda i: bool(i) if isinstance(i, int) else i) # Specify a timeout for the REGEX frontend to catch catastrophic backtracking config.register('regex-frontend-timeout', 30, env_variable='LOKI_REGEX_FRONTEND_TIMEOUT', preprocess=int) # The number of worker threads to use in the jit compilation package config.register('jit-build-workers', 3, env_variable='LOKI_JIT_BUILD_WORKERS', preprocess=int) # Trigger configuration initialisation, including # a scan of the current environment variables config.initialize() if config['print-config']: config.print_state() loki-ecmwf-0.3.6/loki/frontend/0000775000175000017500000000000015167130205016516 5ustar alastairalastairloki-ecmwf-0.3.6/loki/frontend/__init__.py0000664000175000017500000000160615167130205020632 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
""" Frontend parsers that create Loki IR from input Fortran code. This includes code sanitisation utilities and several frontend parser interfaces, including the REGEX-frontend that is used for fast source code exploration in large call and dependency trees. """ from loki.frontend.preprocessing import * # noqa from loki.frontend.source import * # noqa from loki.frontend.omni import * # noqa from loki.frontend.fparser import * # noqa from loki.frontend.util import * # noqa from loki.frontend.regex import * # noqa loki-ecmwf-0.3.6/loki/frontend/tests/0000775000175000017500000000000015167130205017660 5ustar alastairalastairloki-ecmwf-0.3.6/loki/frontend/tests/__init__.py0000664000175000017500000000057015167130205021773 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. loki-ecmwf-0.3.6/loki/frontend/tests/test_regex_frontend.py0000664000175000017500000013571115167130205024312 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
""" Verify correct parsing behaviour of the REGEX frontend """ from pathlib import Path import platform from time import perf_counter import pytest from loki import Function, Module, Subroutine, Sourcefile, RawSource, config from loki.frontend import ( available_frontends, OMNI, FP, REGEX, RegexParserClass ) from loki.ir import nodes as ir, FindNodes, PreprocessorDirective from loki.types import BasicType, ProcedureType, DerivedType @pytest.fixture(scope='module', name='here') def fixture_here(): return Path(__file__).parent @pytest.fixture(scope='module', name='testdir') def fixture_testdir(here): return here.parent.parent/'tests' @pytest.fixture(name='reset_regex_frontend_timeout') def fixture_reset_regex_frontend_timeout(): original_timeout = config['regex-frontend-timeout'] yield config['regex-frontend-timeout'] = original_timeout def test_regex_subroutine_from_source(): """ Verify that the regex frontend is able to parse subroutines """ fcode = """ subroutine routine_b( ! arg 1 i, ! arg2 j ) use parkind1, only : jpim implicit none integer, intent(in) :: i, j integer b b = 4 call contained_c(i) call routine_a() contains !abc ^$^** integer(kind=jpim) function contained_e(i) integer, intent(in) :: i contained_e = i end function subroutine contained_c(i) integer, intent(in) :: i integer c c = 5 end subroutine contained_c ! 
cc£$^£$^ subroutine contained_d(i) integer, intent(in) :: i integer c c = 8 end subroutine !add"£^£$ end subroutine routine_b """.strip() routine = Subroutine.from_source(fcode, frontend=REGEX) assert routine.name == 'routine_b' assert not routine.is_function assert routine.arguments == () assert routine.argnames == [] assert [r.name for r in routine.subroutines] == ['contained_e', 'contained_c', 'contained_d'] contained_c = routine['contained_c'] assert contained_c.name == 'contained_c' assert not contained_c.is_function assert contained_c.arguments == () assert contained_c.argnames == [] contained_e = routine['contained_e'] assert contained_e.name == 'contained_e' assert contained_e.is_function assert contained_e.arguments == () assert contained_e.argnames == [] contained_d = routine['contained_d'] assert contained_d.name == 'contained_d' assert not contained_d.is_function assert contained_d.arguments == () assert contained_d.argnames == [] code = routine.to_fortran() assert code.count('SUBROUTINE') == 6 assert code.count('FUNCTION') == 2 assert code.count('CONTAINS') == 1 def test_regex_module_from_source(): """ Verify that the regex frontend is able to parse modules """ fcode = """ module some_module use foobar implicit none integer, parameter :: k = selected_int_kind(5) contains subroutine module_routine integer m m = 2 call routine_b(m, 6) end subroutine module_routine integer(kind=k) function module_function(n) integer n module_function = n + 2 end function module_function end module some_module """.strip() module = Module.from_source(fcode, frontend=REGEX) assert module.name == 'some_module' assert [r.name for r in module.subroutines] == ['module_routine', 'module_function'] code = module.to_fortran() assert code.count('MODULE') == 2 assert code.count('SUBROUTINE') == 2 assert code.count('FUNCTION') == 2 assert code.count('CONTAINS') == 1 def test_regex_sourcefile_from_source(): """ Verify that the regex frontend is able to parse source files containing 
multiple modules and subroutines """ fcode = """ subroutine routine_a integer a, i a = 1 i = a + 1 call routine_b(a, i) end subroutine routine_a module some_module contains subroutine module_routine integer m m = 2 call routine_b(m, 6) end subroutine module_routine function module_function(n) integer n integer module_function module_function = n + 3 end function module_function end module some_module module other_module integer :: n end module subroutine routine_b( ! arg 1 i, ! arg2 j, k!arg3 ) integer, intent(in) :: i, j, k integer b b = 4 call contained_c(i) call routine_a() contains !abc ^$^** subroutine contained_c(i) integer, intent(in) :: i integer c c = 5 end subroutine contained_c ! cc£$^£$^ integer function contained_e(i) integer, intent(in) :: i contained_e = i end function subroutine contained_d(i) integer, intent(in) :: i integer c c = 8 end subroutine !add"£^£$ endsubroutine routine_b function function_d(d) integer d d = 6 end function function_d module last_module implicit none contains subroutine last_routine1 call contained() contains subroutine contained integer n n = 1 end subroutine contained end subroutine last_routine1 subroutine last_routine2 call contained2() contains subroutine contained2 integer m m = 1 end subroutine contained2 end subroutine last_routine2 end module last_module """.strip() sourcefile = Sourcefile.from_source(fcode, frontend=REGEX) assert [m.name for m in sourcefile.modules] == ['some_module', 'other_module', 'last_module'] assert [r.name for r in sourcefile.routines] == [ 'routine_a', 'routine_b', 'function_d' ] assert [r.name for r in sourcefile.all_subroutines] == [ 'routine_a', 'routine_b', 'function_d', 'module_routine', 'module_function', 'last_routine1', 'last_routine2' ] assert len(r := sourcefile['last_module']['last_routine1'].routines) == 1 and r[0].name == 'contained' assert len(r := sourcefile['last_module']['last_routine2'].routines) == 1 and r[0].name == 'contained2' code = sourcefile.to_fortran() assert 
code.count('SUBROUTINE') == 18 assert code.count('FUNCTION') == 6 assert code.count('CONTAINS') == 5 assert code.count('MODULE') == 6 def test_regex_sourcefile_from_file(testdir): """ Verify that the regex frontend is able to parse source files containing multiple modules and subroutines """ sourcefile = Sourcefile.from_file(testdir/'sources/sourcefile.f90', frontend=REGEX) assert [m.name for m in sourcefile.modules] == ['some_module'] assert [r.name for r in sourcefile.routines] == [ 'routine_a', 'routine_b', 'function_d' ] assert [r.name for r in sourcefile.all_subroutines] == [ 'routine_a', 'routine_b', 'function_d', 'module_routine', 'module_function' ] routine_b = sourcefile['ROUTINE_B'] assert routine_b.name == 'routine_b' assert not routine_b.is_function assert routine_b.arguments == () assert routine_b.argnames == [] assert [r.name for r in routine_b.subroutines] == ['contained_c'] function_d = sourcefile['function_d'] assert function_d.name == 'function_d' assert function_d.is_function assert function_d.arguments == () assert function_d.argnames == [] assert not function_d.contains code = sourcefile.to_fortran() assert code.count('SUBROUTINE') == 8 assert code.count('FUNCTION') == 4 assert code.count('CONTAINS') == 2 assert code.count('MODULE') == 2 def test_regex_sourcefile_from_file_parser_classes(testdir): filepath = testdir/'sources/Fortran-extract-interface-source.f90' module_names = {'bar', 'foo'} routine_names = { 'func_simple', 'func_simple_1', 'func_simple_2', 'func_simple_pure', 'func_simple_recursive_pure', 'func_simple_elemental', 'func_with_use_and_args', 'func_with_parameters', 'func_with_parameters_1', 'func_with_contains', 'func_mix_local_and_result', 'sub_simple', 'sub_simple_1', 'sub_simple_2', 'sub_simple_3', 'sub_with_contains', 'sub_with_renamed_import', 'sub_with_external', 'sub_with_end' } module_routine_names = {'foo_sub', 'foo_func'} # Empty parse (since we don't match typedef without having the enclosing module first) sourcefile = 
Sourcefile.from_file(filepath, frontend=REGEX, parser_classes=RegexParserClass.TypeDefClass) assert not sourcefile.subroutines assert not sourcefile.modules assert FindNodes(RawSource).visit(sourcefile.ir) assert sourcefile._incomplete assert sourcefile._parser_classes == RegexParserClass.TypeDefClass # Incremental addition of program unit objects sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) assert sourcefile._incomplete assert sourcefile._parser_classes == RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass # Note that the program unit objects don't include the TypeDefClass because it's lower in the hierarchy # and was not matched previously assert all( module._parser_classes == RegexParserClass.ProgramUnitClass for module in sourcefile.modules ) assert all( routine._parser_classes == RegexParserClass.ProgramUnitClass for routine in sourcefile.routines ) assert {module.name.lower() for module in sourcefile.modules} == module_names assert {routine.name.lower() for routine in sourcefile.routines} == routine_names assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' } for module in sourcefile.modules: assert not module.imports for routine in sourcefile.all_subroutines: assert not routine.imports assert not sourcefile['bar'].typedefs # Validate that a re-parse with same parser classes does not change anything sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) assert sourcefile._incomplete assert sourcefile._parser_classes == RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass for module in sourcefile.modules: assert 
not module.imports for routine in sourcefile.all_subroutines: assert not routine.imports assert not sourcefile['bar'].typedefs # Incremental addition of imports sourcefile.make_complete( frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass ) assert sourcefile._parser_classes == ( RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.ImportClass ) # Note that the program unit objects don't include the TypeDefClass because it's lower in the hierarchy # and was not matched previously assert all( module._parser_classes == ( RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass ) for module in sourcefile.modules ) assert all( routine._parser_classes == ( RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass ) for routine in sourcefile.routines ) assert {module.name.lower() for module in sourcefile.modules} == module_names assert {routine.name.lower() for routine in sourcefile.routines} == routine_names assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' } program_units_with_imports = { 'foo': ['bar'], 'func_with_use_and_args': ['foo', 'bar'], 'sub_with_contains': ['bar'], 'sub_with_renamed_import': ['bar'] } for unit in module_names | routine_names | module_routine_names: if unit in program_units_with_imports: assert [import_.module.lower() for import_ in sourcefile[unit].imports] == program_units_with_imports[unit] else: assert not sourcefile[unit].imports assert not sourcefile['bar'].typedefs # Parse the rest sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.AllClasses) assert sourcefile._parser_classes == 
RegexParserClass.AllClasses assert all( module._parser_classes == RegexParserClass.AllClasses for module in sourcefile.modules ) assert all( routine._parser_classes == RegexParserClass.AllClasses for routine in sourcefile.routines ) assert {module.name.lower() for module in sourcefile.modules} == module_names assert {routine.name.lower() for routine in sourcefile.routines} == routine_names assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' } program_units_with_imports = { 'foo': ['bar'], 'func_with_use_and_args': ['foo', 'bar'], 'sub_with_contains': ['bar'], 'sub_with_renamed_import': ['bar'] } for unit in module_names | routine_names | module_routine_names: if unit in program_units_with_imports: assert [import_.module.lower() for import_ in sourcefile[unit].imports] == program_units_with_imports[unit] else: assert not sourcefile[unit].imports # Check access via properties assert 'bar' in sourcefile assert 'food' in sourcefile['bar'] assert sorted(sourcefile['bar'].typedef_map) == ['food', 'organic'] assert sourcefile['bar'].definitions == sourcefile['bar'].typedefs + ('i_am_dim',) assert 'cooking_method' in sourcefile['bar']['food'] assert 'foobar' not in sourcefile['bar']['food'] assert sourcefile['bar']['food'].interface_symbols == () # Check that triggering a full parse works from nested scopes assert sourcefile['bar']._incomplete sourcefile['bar']['food'].make_complete() assert not sourcefile['bar']._incomplete def test_regex_raw_source(): """ Verify that unparsed source appears in-between matched objects """ fcode = """ ! Some comment before the module ! module some_mod ! Some docstring ! docstring ! 
docstring use some_mod ! Some comment ! comment ! comment end module some_mod ! Other comment at the end """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) assert len(source.ir.body) == 3 assert isinstance(source.ir.body[0], RawSource) assert source.ir.body[0].source.lines == (1, 2) assert source.ir.body[0].text == '! Some comment before the module\n!' assert source.ir.body[0].source.string == source.ir.body[0].text assert isinstance(source.ir.body[1], Module) assert source.ir.body[1].source.lines == (3, 11) assert source.ir.body[1].source.string.startswith('module') assert isinstance(source.ir.body[2], RawSource) assert source.ir.body[2].source.lines == (12, 13) assert source.ir.body[2].text == '\n! Other comment at the end' assert source.ir.body[2].source.string == source.ir.body[2].text module = source['some_mod'] assert len(module.spec.body) == 3 assert isinstance(module.spec.body[0], RawSource) assert isinstance(module.spec.body[1], ir.Import) assert isinstance(module.spec.body[2], RawSource) assert module.spec.body[0].text.count('docstring') == 3 assert module.spec.body[2].text.count('comment') == 3 def test_regex_raw_source_with_cpp(): """ Verify that unparsed source appears in-between matched objects and preprocessor statements are preserved """ fcode = """ ! Some comment before the subroutine #ifdef RS6K @PROCESS HOT(NOVECTOR) NOSTRICT #endif SUBROUTINE SOME_ROUTINE (KLON, KLEV) IMPLICIT NONE INTEGER, INTENT(IN) :: KLON, KLEV ! Comment inside routine END SUBROUTINE SOME_ROUTINE """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) assert len(source.ir.body) == 2 assert isinstance(source.ir.body[0], RawSource) assert source.ir.body[0].source.lines == (1, 4) assert source.ir.body[0].text.startswith('! 
Some comment before the subroutine\n#') assert source.ir.body[0].text.endswith('#endif') assert source.ir.body[0].source.string == source.ir.body[0].text assert isinstance(source.ir.body[1], Subroutine) assert source.ir.body[1].source.lines == (5, 9) assert source.ir.body[1].source.string.startswith('SUBROUTINE') def test_regex_raw_source_with_cpp_incomplete(): """ Verify that unparsed source appears inside matched objects if parser classes are used to restrict the matching """ fcode = """ SUBROUTINE driver(a, b, c) INTEGER, INTENT(INOUT) :: a, b, c #include "kernel.intfb.h" CALL kernel(a, b ,c) END SUBROUTINE driver """.strip() parser_classes = RegexParserClass.ProgramUnitClass source = Sourcefile.from_source(fcode, frontend=REGEX, parser_classes=parser_classes) assert len(source.ir.body) == 1 driver = source['driver'] assert isinstance(driver, Subroutine) assert not driver.docstring assert not driver.body assert not driver.contains assert driver.spec and len(driver.spec.body) == 1 assert isinstance(driver.spec.body[0], RawSource) assert 'INTEGER, INTENT' in driver.spec.body[0].text assert '#include' in driver.spec.body[0].text @pytest.mark.parametrize('frontend', available_frontends( xfail=[(OMNI, 'Non-standard notation needs full preprocessing')] )) def test_make_complete_sanitize(frontend): """ Test that attempts to first REGEX-parse and then complete source code with unsupported features that require "frontend sanitization". """ fcode = """ ! Some comment before the subroutine #ifdef RS6K @PROCESS HOT(NOVECTOR) NOSTRICT #endif SUBROUTINE SOME_ROUTINE (KLON, KLEV) IMPLICIT NONE INTEGER, INTENT(IN) :: KLON, KLEV ! Comment inside routine END SUBROUTINE SOME_ROUTINE """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) # Ensure completion handles the non-supported features (@PROCESS) source.make_complete(frontend=frontend) comments = FindNodes(ir.Comment).visit(source.ir) assert len(comments) == 2 if frontend == FP else 1 assert comments[0].text == '! 
Some comment before the subroutine' if frontend == FP: assert comments[1].text == '@PROCESS HOT(NOVECTOR) NOSTRICT' directives = FindNodes(PreprocessorDirective).visit(source.ir) assert len(directives) == 2 assert directives[0].text == '#ifdef RS6K' assert directives[1].text == '#endif' @pytest.mark.skipif(platform.system() == 'Darwin', reason='Timeout utility test sporadically fails on MacOS CI runners.' ) @pytest.mark.usefixtures('reset_regex_frontend_timeout') def test_regex_timeout(): """ This source fails to parse because of missing SUBROUTINE in END statement, and the test verifies that a timeout is encountered """ fcode = """ subroutine some_routine(a) real, intent(in) :: a end """.strip() # Test timeout config['regex-frontend-timeout'] = 1 start = perf_counter() with pytest.raises(RuntimeError) as exc: _ = Sourcefile.from_source(fcode, frontend=REGEX) stop = perf_counter() assert .9 < stop - start < 1.1 assert 'REGEX frontend timeout of 1 s exceeded' in str(exc.value) # Test it works fine with proper Fortran: fcode += ' subroutine' source = Sourcefile.from_source(fcode, frontend=REGEX) assert len(source.subroutines) == 1 assert source.subroutines[0].name == 'some_routine' def test_regex_module_imports(): """ Verify that the regex frontend is able to find and correctly parse Fortran imports """ fcode = """ module some_mod use no_symbols_mod use only_mod, only: my_var use test_rename_mod, first_var1 => var1, first_var3 => var3 use test_other_rename_mod, only: second_var1 => var1 use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3 implicit none end module some_mod """.strip() module = Module.from_source(fcode, frontend=REGEX) imports = FindNodes(ir.Import).visit(module.spec) assert len(imports) == 5 assert [import_.module for import_ in imports] == [ 'no_symbols_mod', 'only_mod', 'test_rename_mod', 'test_other_rename_mod', 'test_other_rename_mod' ] assert set(module.imported_symbols) == { 'my_var', 'first_var1', 'first_var3', 'second_var1', 
'other_var2', 'other_var3' } assert module.imported_symbol_map['first_var1'].type.use_name == 'var1' assert module.imported_symbol_map['first_var3'].type.use_name == 'var3' assert module.imported_symbol_map['second_var1'].type.use_name == 'var1' assert module.imported_symbol_map['other_var2'].type.use_name == 'var2' assert module.imported_symbol_map['other_var3'].type.use_name == 'var3' def test_regex_subroutine_imports(): """ Verify that the regex frontend is able to find and correctly parse Fortran imports """ fcode = """ subroutine some_routine use no_symbols_mod use only_mod, only: my_var use test_rename_mod, first_var1 => var1, first_var3 => var3 use test_other_rename_mod, only: second_var1 => var1 use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3 implicit none end subroutine some_routine """.strip() routine = Subroutine.from_source(fcode, frontend=REGEX) imports = FindNodes(ir.Import).visit(routine.spec) assert len(imports) == 5 assert [import_.module for import_ in imports] == [ 'no_symbols_mod', 'only_mod', 'test_rename_mod', 'test_other_rename_mod', 'test_other_rename_mod' ] assert set(routine.imported_symbols) == { 'my_var', 'first_var1', 'first_var3', 'second_var1', 'other_var2', 'other_var3' } assert routine.imported_symbol_map['first_var1'].type.use_name == 'var1' assert routine.imported_symbol_map['first_var3'].type.use_name == 'var3' assert routine.imported_symbol_map['second_var1'].type.use_name == 'var1' assert routine.imported_symbol_map['other_var2'].type.use_name == 'var2' assert routine.imported_symbol_map['other_var3'].type.use_name == 'var3' def test_regex_import_linebreaks(): """ Verify correct handling of line breaks in import statements """ fcode = """ module file_io_mod USE PARKIND1 , ONLY : JPIM, JPRB, JPRD #ifdef HAVE_SERIALBOX USE m_serialize, ONLY: & fs_create_savepoint, & fs_add_serializer_metainfo, & fs_get_serializer_metainfo, & fs_read_field, & fs_write_field USE utils_ppser, ONLY: & ppser_initialize, & 
ppser_finalize, & ppser_serializer, & ppser_serializer_ref, & ppser_set_mode, & ppser_savepoint #endif #ifdef HAVE_HDF5 USE hdf5_file_mod, only: hdf5_file #endif implicit none end module file_io_mod """.strip() module = Module.from_source(fcode, frontend=REGEX) imports = FindNodes(ir.Import).visit(module.spec) assert len(imports) == 4 assert [import_.module for import_ in imports] == ['PARKIND1', 'm_serialize', 'utils_ppser', 'hdf5_file_mod'] assert all( s in module.imported_symbols for s in [ 'JPIM', 'JPRB', 'JPRD', 'fs_create_savepoint', 'fs_add_serializer_metainfo', 'fs_get_serializer_metainfo', 'fs_read_field', 'fs_write_field', 'ppser_initialize', 'ppser_finalize', 'ppser_serializer', 'ppser_serializer_ref', 'ppser_set_mode', 'ppser_savepoint', 'hdf5_file' ] ) def test_regex_typedef(): """ Verify that the regex frontend is able to parse type definitions and correctly parse procedure bindings. """ fcode = """ module typebound_item implicit none type some_type contains procedure, nopass :: routine => module_routine procedure :: some_routine procedure, pass :: other_routine procedure :: routine1, & & routine2 => routine ! procedure :: routine1 ! 
procedure :: routine2 => routine end type some_type contains subroutine module_routine integer m m = 2 end subroutine module_routine subroutine some_routine(self) class(some_type) :: self call self%routine end subroutine some_routine subroutine other_routine(self, m) class(some_type), intent(inout) :: self integer, intent(in) :: m integer :: j j = m call self%routine1 call self%routine2 end subroutine other_routine subroutine routine(self) class(some_type) :: self call self%some_routine end subroutine routine subroutine routine1(self) class(some_type) :: self call module_routine end subroutine routine1 end module typebound_item """.strip() module = Module.from_source(fcode, frontend=REGEX) assert 'some_type' in module.typedef_map some_type = module.typedef_map['some_type'] proc_bindings = { 'routine': ('module_routine',), 'some_routine': None, 'other_routine': None, 'routine1': None, 'routine2': ('routine',) } assert len(proc_bindings) == len(some_type.variables) assert all(proc in some_type.variables for proc in proc_bindings) assert all( some_type.variable_map[proc].type.bind_names == bind for proc, bind in proc_bindings.items() ) def test_regex_typedef_generic(): fcode = """ module typebound_header implicit none type header_type contains procedure :: member_routine => header_member_routine procedure :: routine_real => header_routine_real procedure :: routine_integer generic :: routine => routine_real, routine_integer end type header_type contains subroutine header_member_routine(self, val) class(header_type) :: self integer, intent(in) :: val integer :: j j = val end subroutine header_member_routine subroutine header_routine_real(self, val) class(header_type) :: self real, intent(out) :: val val = 1.0 end subroutine header_routine_real subroutine routine_integer(self, val) class(header_type) :: self integer, intent(out) :: val val = 1 end subroutine routine_integer end module typebound_header """.strip() module = Module.from_source(fcode, frontend=REGEX) assert 
'header_type' in module.typedef_map header_type = module.typedef_map['header_type'] proc_bindings = { 'member_routine': ('header_member_routine',), 'routine_real': ('header_routine_real',), 'routine_integer': None, 'routine': ('routine_real', 'routine_integer') } assert len(proc_bindings) == len(header_type.variables) assert all(proc in header_type.variables for proc in proc_bindings) assert all( ( header_type.variable_map[proc].type.bind_names == bind and header_type.variable_map[proc].type.initial is None ) for proc, bind in proc_bindings.items() ) def test_regex_loki_69(): """ Test compliance of REGEX frontend with edge cases reported in LOKI-69. This should become a full-blown Scheduler test when REGEX frontend undeprins the scheduler. """ fcode = """ subroutine random_call_0(v_out,v_in,v_inout) implicit none real(kind=jprb),intent(in) :: v_in real(kind=jprb),intent(out) :: v_out real(kind=jprb),intent(inout) :: v_inout end subroutine random_call_0 !subroutine random_call_1(v_out,v_in,v_inout) !implicit none ! ! real(kind=jprb),intent(in) :: v_in ! real(kind=jprb),intent(out) :: v_out ! real(kind=jprb),intent(inout) :: v_inout ! ! 
!end subroutine random_call_1 subroutine random_call_2(v_out,v_in,v_inout) implicit none real(kind=jprb),intent(in) :: v_in real(kind=jprb),intent(out) :: v_out real(kind=jprb),intent(inout) :: v_inout end subroutine random_call_2 subroutine test(v_out,v_in,v_inout,some_logical) implicit none real(kind=jprb),intent(in ) :: v_in real(kind=jprb),intent(out ) :: v_out real(kind=jprb),intent(inout) :: v_inout logical,intent(in) :: some_logical v_inout = 0._jprb if(some_logical)then call random_call_0(v_out,v_in,v_inout) endif if(some_logical) call random_call_2 end subroutine test """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) assert [r.name for r in source.all_subroutines] == ['random_call_0', 'random_call_2', 'test'] calls = FindNodes(ir.CallStatement).visit(source['test'].ir) assert [call.name for call in calls] == ['RANDOM_CALL_0', 'random_call_2'] variable_map_test = source['test'].variable_map v_in_type = variable_map_test['v_in'].type assert v_in_type.dtype is BasicType.REAL assert v_in_type.kind == 'jprb' def test_regex_variable_declaration(testdir): """ Test correct parsing of derived type variable declarations """ filepath = testdir/'sources/projTypeBound/typebound_item.F90' source = Sourcefile.from_file(filepath, frontend=REGEX) driver = source['driver'] assert driver.variables == ('constant', 'obj', 'obj2', 'header', 'other_obj', 'derived', 'x', 'i') assert source['module_routine'].variables == ('m',) assert source['other_routine'].variables == ('self', 'm', 'j') assert source['routine'].variables == ('self',) assert source['routine1'].variables == ('self',) # Check this for REGEX and complete parse to make sure their behaviour is aligned for _ in range(2): var_map = driver.symbol_map assert isinstance(var_map['obj'].type.dtype, DerivedType) assert var_map['obj'].type.dtype.name == 'some_type' assert isinstance(var_map['obj2'].type.dtype, DerivedType) assert var_map['obj2'].type.dtype.name == 'some_type' assert 
isinstance(var_map['header'].type.dtype, DerivedType) assert var_map['header'].type.dtype.name == 'header_type' assert isinstance(var_map['other_obj'].type.dtype, DerivedType) assert var_map['other_obj'].type.dtype.name == 'other' assert isinstance(var_map['derived'].type.dtype, DerivedType) assert var_map['derived'].type.dtype.name == 'other' assert isinstance(var_map['x'].type.dtype, BasicType) assert var_map['x'].type.dtype is BasicType.REAL assert isinstance(var_map['i'].type.dtype, BasicType) assert var_map['i'].type.dtype is BasicType.INTEGER # While we're here: let's check the call statements, too calls = FindNodes(ir.CallStatement).visit(driver.ir) assert len(calls) == 7 assert all(isinstance(call.name.type.dtype, ProcedureType) for call in calls) # Note: we're explicitly accessing the string name here (instead of relying # on the StrCompareMixin) as some have dimensions that only show up in the full # parse assert calls[0].name.name == 'obj%other_routine' assert calls[0].name.parent.name == 'obj' assert calls[1].name.name == 'obj2%some_routine' assert calls[1].name.parent.name == 'obj2' assert calls[2].name.name == 'header%member_routine' assert calls[2].name.parent.name == 'header' assert calls[3].name.name == 'header%routine' assert calls[3].name.parent.name == 'header' assert calls[4].name.name == 'header%routine' assert calls[4].name.parent.name == 'header' assert calls[5].name.name == 'other_obj%member' assert calls[5].name.parent.name == 'other_obj' assert calls[6].name.name == 'derived%var%member_routine' assert calls[6].name.parent.name == 'derived%var' assert calls[6].name.parent.parent.name == 'derived' # Hack: Split the procedure binding into one-per-line until Fparser # supports this... 
module = source['typebound_item'] module.source.string = module.source.string.replace( 'procedure :: routine1,', 'procedure :: routine1\nprocedure ::' ) source.make_complete() def test_regex_variable_declaration_parentheses(): fcode = """ subroutine definitely_not_allfpos(ydfpdata) implicit none integer, parameter :: NMaxCloudTypes = 12 type(tfpdata), intent(in) :: ydfpdata type(tfpofn) :: ylofn(size(ydfpdata%yfpos%yfpgeometry%yfpusergeo)) real, dimension(nproma, max(nang, 1), max(nfre, 1)) :: not_an_annoying_ecwam_var character(len=511) :: cloud_type_name(NMaxCloudTypes) = ["","","","","","","","","","","",""], other_name = "", names(3) = (/ "", "", "" /) character(len=511) :: more_names(2) = (/ "What", " is" /), naaaames(2) = [ " going ", "on?" ] end subroutine definitely_not_allfpos """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) routine = source['definitely_not_allfpos'] assert routine.variables == ( 'nmaxcloudtypes', 'ydfpdata', 'ylofn', 'not_an_annoying_ecwam_var', 'cloud_type_name', 'other_name', 'names', 'more_names', 'naaaames' ) assert routine.symbol_map['not_an_annoying_ecwam_var'].type.dtype is BasicType.REAL assert routine.symbol_map['cloud_type_name'].type.dtype is BasicType.CHARACTER def test_regex_call_statement_parentheses(): """Correct handling of nested parentheses, reported in #585""" fcode = """ subroutine a_function(arg1, arg3) implicit none integer, intent(inout) :: arg1, arg3 call parse_me_wrong(arg1, arg2=[1,2,3], arg3) call parse_me_wrong2(arg1, arg2=(/1,2,3/), arg3) end subroutine a_function """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) routine = source['a_function'] calls = FindNodes(ir.CallStatement).visit(routine.ir) assert [call.name for call in calls] == ['parse_me_wrong', 'parse_me_wrong2'] def test_regex_preproc_in_contains(): fcode = """ module preproc_in_contains implicit none public :: routine1, routine2, func contains #include "some_include.h" subroutine routine1 end subroutine routine1 
module subroutine mod_routine call other_routine contains #define something subroutine other_routine end subroutine other_routine end subroutine mod_routine elemental function func real func end function func end module preproc_in_contains """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) expected_names = {'preproc_in_contains', 'routine1', 'mod_routine', 'func'} actual_names = {r.name for r in source.all_subroutines} | {m.name for m in source.modules} assert expected_names == actual_names assert isinstance(source['mod_routine']['other_routine'], Subroutine) def test_regex_interface_subroutine(): fcode = """ subroutine test(callback) implicit none interface subroutine some_kernel(a, b, c) integer, intent(in) :: a, b integer, intent(out) :: c end subroutine some_kernel SUBROUTINE other_kernel(a) integer, intent(inout) :: a end subroutine end interface INTERFACE function other_func(a) integer, intent(in) :: a integer, other_func end function other_func end interface abstract interface function callback_func(a) result(b) integer, intent(in) :: a integer :: b end FUNCTION callback_func end INTERFACE procedure(callback_func), pointer, intent(in) :: callback integer :: a, b, c a = callback(1) b = other_func(a) call some_kernel(a, b, c) call other_kernel(c) end subroutine test """.strip() # Make sure only the host subroutine is captured source = Sourcefile.from_source(fcode, frontend=REGEX) assert len(source.subroutines) == 1 assert source.subroutines[0].name == 'test' assert source.subroutines[0].source.lines == (1, 38) # Make sure this also works for module procedures fcode = f""" module my_mod implicit none contains {fcode} end module my_mod """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) assert not source.subroutines assert len(source.all_subroutines) == 1 assert source.all_subroutines[0].name == 'test' assert source.all_subroutines[0].source.lines == (4, 41) def test_regex_interface_module(): fcode = """ module my_mod implicit none 
interface subroutine ext1 (x, y, z) real, dimension(100, 100), intent(inout) :: x, y, z end subroutine ext1 subroutine ext2 (x, z) real, intent(in) :: x complex(kind = 4), intent(inout) :: z(2000) end subroutine ext2 function ext3 (p, q) logical ext3 integer, intent(in) :: p(1000) logical, intent(in) :: q(1000) end function ext3 end interface interface sub subroutine sub_int (a) integer, intent(in) :: a(:) end subroutine sub_int subroutine sub_real (a) real, intent(in) :: a(:) end subroutine sub_real end interface sub interface func module procedure func_int module procedure func_real end interface func contains subroutine sub_int (a) integer, intent(in) :: a(:) end subroutine sub_int subroutine sub_real (a) real, intent(in) :: a(:) end subroutine sub_real integer module function func_int (a) integer, intent(in) :: a(:) end function func_int real module function func_real (a) real, intent(in) :: a(:) end function func_real end module my_mod """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) assert len(source.modules) == 1 assert source['my_mod'] is not None assert not source['my_mod'].interfaces source.make_complete( frontend=REGEX, parser_class=RegexParserClass.ProgramUnitClass | RegexParserClass.InterfaceClass ) assert len(source['my_mod'].interfaces) == 3 assert source['my_mod'].symbols == ( 'ext1', 'ext2', 'ext3', 'sub', 'sub_int', 'sub_real', 'func', 'func_int', 'func_real', 'func_int', 'func_real', 'sub_int', 'sub_real', 'func_int', 'func_real' ) def test_regex_function_inline_return_type(): fcode = """ REAL(KIND=JPRB) FUNCTION DOT_PRODUCT_ECV() END FUNCTION DOT_PRODUCT_ECV SUBROUTINE DOT_PROD_SP_2D() END SUBROUTINE DOT_PROD_SP_2D """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) assert { routine.name.lower() for routine in source.subroutines } == {'dot_product_ecv', 'dot_prod_sp_2d'} assert isinstance(source['dot_product_ecv'], Function) assert 
isinstance(source['dot_prod_sp_2d'], Subroutine) source.make_complete() function = source['dot_product_ecv'] assert function.return_type.dtype == BasicType.REAL assert function.return_type.kind == 'JPRB' @pytest.mark.parametrize('frontend', available_frontends()) def test_regex_prefix(frontend, tmp_path): fcode = """ module some_mod implicit none contains pure elemental real function f_elem(a) real, intent(in) :: a f_elem = a end function f_elem pure recursive integer function fib(i) result(fib_i) integer, intent(in) :: i if (i <= 0) then fib_i = 0 else if (i == 1) then fib_i = 1 else fib_i = fib(i-1) + fib(i-2) end if end function fib end module some_mod """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) assert source['f_elem'].prefix == ('pure elemental real',) assert source['fib'].prefix == ('pure recursive integer',) source.make_complete(frontend=frontend, xmods=[tmp_path]) assert tuple(p.lower() for p in source['f_elem'].prefix) == ('pure', 'elemental') assert tuple(p.lower() for p in source['fib'].prefix) == ('pure', 'recursive') def test_regex_fypp(): """ Test that unexpanded fypp-annotations are handled gracefully in the REGEX frontend. """ fcode = """ module fypp_mod ! A pre-set array of pre-prcessor variables #:mute #:set foo = [2,3,4,5] #:endmute contains ! A non-templated routine subroutine first_routine(i, x) integer, intent(in) :: i real, intent(inout) :: x(3) end subroutine first_routine ! A fypp-loop with in-place directives for subroutine names #:for bar in foo #:set rname = 'routine_%s' % (bar,) subroutine ${rname}$ (i, x) integer, intent(in) :: i real, intent(inout) :: x(3) end subroutine ${rname}$ #:endfor ! 
Another non-templated routine subroutine last_routine(i, x) integer, intent(in) :: i real, intent(inout) :: x(3) end subroutine last_routine end module fypp_mod """ source = Sourcefile.from_source(fcode, frontend=REGEX) module = source['fypp_mod'] assert isinstance(module, Module) # Check that only non-templated routines are included assert len(module.routines) == 2 assert module.routines[0].name == 'first_routine' assert module.routines[1].name == 'last_routine' def test_declaration_whitespace_attributes(): """ Test correct behaviour with/without white space inside declaration attributes (reported in #318). """ fcode = """ subroutine my_whitespace_declaration_routine(kdim, state_t0, paux) use type_header, only: dimension_type, STATE_TYPE, aux_type, jprb implicit none TYPE( DIMENSION_TYPE) , INTENT (IN) :: KDIM type (state_type ) , intent ( in ) :: state_t0 TYPE (AUX_TYPE) , InteNT( In) :: PAUX CHARACTER ( LEN=10) :: STR REAL( KIND = JPRB ) :: VAR end subroutine """.strip() routine = Subroutine.from_source(fcode, frontend=REGEX) # Verify that variables and dtype information has been extracted correctly assert routine.variables == ('kdim', 'state_t0', 'paux', 'str', 'var') assert isinstance(routine.variable_map['kdim'].type.dtype, DerivedType) assert routine.variable_map['kdim'].type.dtype.name.lower() == 'dimension_type' assert isinstance(routine.variable_map['state_t0'].type.dtype, DerivedType) assert routine.variable_map['state_t0'].type.dtype.name.lower() == 'state_type' assert isinstance(routine.variable_map['paux'].type.dtype, DerivedType) assert routine.variable_map['paux'].type.dtype.name.lower() == 'aux_type' assert routine.variable_map['str'].type.dtype == BasicType.CHARACTER assert routine.variable_map['var'].type.dtype == BasicType.REAL routine.make_complete() # Verify that additional type attributes are correct after full parse assert routine.variables == ('kdim', 'state_t0', 'paux', 'str', 'var') assert 
isinstance(routine.variable_map['kdim'].type.dtype, DerivedType) assert routine.variable_map['kdim'].type.dtype.name.lower() == 'dimension_type' assert routine.variable_map['kdim'].type.intent == 'in' assert isinstance(routine.variable_map['state_t0'].type.dtype, DerivedType) assert routine.variable_map['state_t0'].type.dtype.name.lower() == 'state_type' assert routine.variable_map['state_t0'].type.intent == 'in' assert isinstance(routine.variable_map['paux'].type.dtype, DerivedType) assert routine.variable_map['paux'].type.dtype.name.lower() == 'aux_type' assert routine.variable_map['paux'].type.intent == 'in' assert routine.variable_map['str'].type.dtype == BasicType.CHARACTER assert routine.variable_map['str'].type.length == 10 assert routine.variable_map['var'].type.dtype == BasicType.REAL assert routine.variable_map['var'].type.kind == 'jprb' def test_regex_sanitize_fypp_line_annotations(): """ Test that fypp line number annotations are sanitized correctly. """ fcode = """ module some_templated_mod # 1 "/path-to-hypp-macro/macro.hypp" 1 # 2 "/path-to-hypp-macro/macro.hypp" # 3 "/path-to-hypp-macro/macro.hypp" # 5 "/path-to-fypp-template/template.fypp" 2 integer :: a0 integer :: a1 integer :: a2 integer :: a3 integer :: a4 end module some_templated_mod """ module = Module.from_source(fcode, frontend=REGEX) decls = FindNodes(ir.VariableDeclaration).visit(module.spec) assert len(decls) == 5 def test_regex_pragma(): """ Make sure the regex frontend can parse pragmas. """ fcode = """ SUBROUTINE FOO(A) INTEGER, INTENT(IN) :: A ! make sure this won't end up as VariableDeclaration ! INTEGER :: B ! make sure this won't end up as VariableDeclaration !$loki INTEGER :: C ! 
this is just a comment !$loki this-is-a-pragma !$acc this is another openacc pragma !$omp multiline & !$omp & pragma to be tested END SUBROUTINE FOO """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) routine = source['FOO'] pragmas = FindNodes(ir.Pragma).visit(routine.ir) var_decls = FindNodes(ir.VariableDeclaration).visit(routine.ir) assert len(pragmas) == 4 assert pragmas[0].keyword == 'loki' assert pragmas[0].content == 'INTEGER :: C' assert pragmas[1].keyword == 'loki' assert pragmas[1].content == 'this-is-a-pragma' assert pragmas[2].keyword == 'acc' assert pragmas[2].content == 'this is another openacc pragma' assert pragmas[3].keyword == 'omp' assert pragmas[3].content == 'multiline & pragma to be tested' assert len(var_decls) == 1 assert var_decls[0].symbols == ('A',) # compare with fully parsed source source.make_complete() compl_pragmas = FindNodes(ir.Pragma).visit(routine.ir) for compl_pragma, pragma in zip(compl_pragmas, pragmas): assert compl_pragma.keyword == pragma.keyword assert compl_pragma.content == pragma.content def test_regex_comments(): """ Make sure the REGEX frontend doesn't match any comments """ fcode = """ SUBROUTINE my_routine ! use my_mod use other_mod, only: foo use third_mod ! use fourth_mod use fifth_mod! , only: bar implicit none ! type my_type type other_type end type integer :: var !, val ! $acc not an acc pragma !$ acc also not an acc pragma var = 1 !$acc definitely not a pragma !!$acc not a pragma either !$$acc no pragma var = 1 & &+1!$acc again no pragma call some_routine(var) var = var ! + function(val) var = var + 1 ! 
call other_routine(val) !call third routine(val) END SUBROUTINE my_routine """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) routine = source['my_routine'] assert len(routine.imports) == 3 assert [imprt.module for imprt in routine.imports] == ['other_mod', 'third_mod', 'fifth_mod'] assert len(calls := FindNodes(ir.CallStatement).visit(routine.ir)) == 1 and calls[0].name == 'some_routine' assert not FindNodes(ir.Pragma).visit(routine.ir) loki-ecmwf-0.3.6/loki/frontend/tests/test_frontends.py0000664000175000017500000010515715167130205023304 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. """ Verify correct frontend behaviour and correct parsing of certain Fortran language features. 
""" import numpy as np import pytest from loki import ( Function, Module, Subroutine, Sourcefile, BasicType, config, config_override ) from loki.jit_build import jit_compile from loki.expression import symbols as sym from loki.frontend import available_frontends, OMNI, FP, HAVE_FP from loki.ir import nodes as ir, FindNodes, FindVariables @pytest.fixture(name='reset_frontend_mode') def fixture_reset_frontend_mode(): original_frontend_mode = config['frontend-strict-mode'] yield config['frontend-strict-mode'] = original_frontend_mode @pytest.mark.parametrize('frontend', available_frontends()) def test_check_alloc_opts(tmp_path, frontend): """ Test the use of SOURCE and STAT in allocate """ fcode = """ module alloc_mod integer, parameter :: jprb = selected_real_kind(13,300) type explicit real(kind=jprb) :: scalar, vector(3), matrix(3, 3) real(kind=jprb) :: red_herring end type explicit type deferred real(kind=jprb), allocatable :: scalar, vector(:), matrix(:, :) real(kind=jprb), allocatable :: red_herring end type deferred contains subroutine alloc_deferred(item) type(deferred), intent(inout) :: item integer :: stat allocate(item%vector(3), stat=stat) allocate(item%matrix(3, 3)) end subroutine alloc_deferred subroutine free_deferred(item) type(deferred), intent(inout) :: item integer :: stat deallocate(item%vector, stat=stat) deallocate(item%matrix) end subroutine free_deferred subroutine check_alloc_source(item, item2) type(explicit), intent(inout) :: item type(deferred), intent(inout) :: item2 real(kind=jprb), allocatable :: vector(:), vector2(:) allocate(vector, source=item%vector) vector(:) = vector(:) + item%scalar item%vector(:) = vector(:) allocate(vector2, source=item2%vector) ! 
Try mold here when supported by fparser vector2(:) = item2%scalar item2%vector(:) = vector2(:) end subroutine check_alloc_source end module alloc_mod """.strip() # Parse the source and validate the IR module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) allocations = FindNodes(ir.Allocation).visit(module['check_alloc_source'].body) assert len(allocations) == 2 assert all(alloc.data_source is not None for alloc in allocations) assert all(alloc.status_var is None for alloc in allocations) allocations = FindNodes(ir.Allocation).visit(module['alloc_deferred'].body) assert len(allocations) == 2 assert all(alloc.data_source is None for alloc in allocations) assert allocations[0].status_var is not None assert allocations[1].status_var is None deallocs = FindNodes(ir.Deallocation).visit(module['free_deferred'].body) assert len(deallocs) == 2 assert deallocs[0].status_var is not None assert deallocs[1].status_var is None # Sanity check for the backend assert module.to_fortran().lower().count(', stat=stat') == 2 # Generate Fortran and test it filepath = tmp_path/(f'frontends_check_alloc_{frontend}.f90') mod = jit_compile(module, filepath=filepath, objname='alloc_mod') item = mod.explicit() item.scalar = 1. item.vector[:] = 1. item2 = mod.deferred() mod.alloc_deferred(item2) item2.scalar = 2. item2.vector[:] = -1. 
mod.check_alloc_source(item, item2) assert (item.vector == 2.).all() assert (item2.vector == 2.).all() mod.free_deferred(item2) @pytest.mark.parametrize('frontend', available_frontends()) def test_associates(tmp_path, frontend): """Test the use of associate to access and modify other items""" fcode = """ module derived_types_mod integer, parameter :: jprb = selected_real_kind(13,300) type explicit real(kind=jprb) :: scalar, vector(3), matrix(3, 3) real(kind=jprb) :: red_herring end type explicit type deferred real(kind=jprb), allocatable :: scalar, vector(:), matrix(:, :) real(kind=jprb), allocatable :: red_herring end type deferred contains subroutine alloc_deferred(item) type(deferred), intent(inout) :: item allocate(item%vector(3)) allocate(item%matrix(3, 3)) end subroutine alloc_deferred subroutine free_deferred(item) type(deferred), intent(inout) :: item deallocate(item%vector) deallocate(item%matrix) end subroutine free_deferred subroutine associates(item) type(explicit), intent(inout) :: item type(deferred) :: item2 item%scalar = 17.0 associate(vector2=>item%matrix(:,1)) vector2(:) = 3. item%matrix(:,3) = vector2(:) end associate associate(vector=>item%vector) item%vector(2) = vector(1) vector(3) = item%vector(1) + vector(2) vector(1) = 1. end associate call alloc_deferred(item2) associate(vec=>item2%vector(2)) vec = 1. 
end associate call free_deferred(item2) end subroutine associates end module """ # Test the internals module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) routine = module['associates'] variables = FindVariables().visit(routine.body) assert all( v.shape == ('3',) for v in variables if v.name in ['vector', 'vector2'] ) for assoc in FindNodes(ir.Associate).visit(routine.body): for var in FindVariables().visit(assoc.body): if var.name in assoc.variables: assert var.scope is assoc assert var.type.parent is None else: assert var.scope is routine # Test the generated module filepath = tmp_path/(f'derived_types_associates_{frontend}.f90') mod = jit_compile(module, filepath=filepath, objname='derived_types_mod') item = mod.explicit() item.scalar = 0. item.vector[0] = 5. item.vector[1:2] = 0. mod.associates(item) assert item.scalar == 17.0 and (item.vector == [1., 5., 10.]).all() @pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI fails to read without full module')])) def test_associates_deferred(frontend): """ Verify that reading in subroutines with deferred external type definitions and associates working on that are supported. 
""" fcode = """ SUBROUTINE ASSOCIATES_DEFERRED(ITEM, IDX) USE SOME_MOD, ONLY: SOME_TYPE IMPLICIT NONE TYPE(SOME_TYPE), INTENT(IN) :: ITEM INTEGER, INTENT(IN) :: IDX ASSOCIATE(SOME_VAR=>ITEM%SOME_VAR(IDX), SOME_OTHER_VAR=>ITEM%SOME_VAR(ITEM%OFFSET)) SOME_VAR = 5 END ASSOCIATE END SUBROUTINE """ routine = Subroutine.from_source(fcode, frontend=frontend) variables = {v.name: v for v in FindVariables().visit(routine.body)} assert len(variables) == 6 some_var = variables['SOME_VAR'] assert isinstance(some_var, sym.DeferredTypeSymbol) assert some_var.name.upper() == 'SOME_VAR' assert some_var.type.dtype == BasicType.DEFERRED associate = FindNodes(ir.Associate).visit(routine.body)[0] assert some_var.scope is associate some_other_var = variables['SOME_OTHER_VAR'] assert isinstance(some_var, sym.DeferredTypeSymbol) assert some_other_var.name.upper() == 'SOME_OTHER_VAR' assert some_other_var.type.dtype == BasicType.DEFERRED assert some_other_var.type.shape == ('ITEM%OFFSET',) assert some_other_var.scope is associate @pytest.mark.parametrize('frontend', available_frontends()) def test_associates_expr(tmp_path, frontend): """Verify that associates with expressions are supported""" fcode = """ subroutine associates_expr(in, out) implicit none integer, intent(in) :: in(3) integer, intent(out) :: out(3) out(:) = 0 associate(a=>1+3) out(:) = out(:) + a end associate associate(b=>2*in(:) + in(:)) out(:) = out(:) + b(:) end associate end subroutine associates_expr """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) variables = {v.name: v for v in FindVariables().visit(routine.body)} assert len(variables) == 4 assert isinstance(variables['a'], sym.DeferredTypeSymbol) assert variables['a'].type.dtype is BasicType.DEFERRED # TODO: support type derivation for expressions assert isinstance(variables['b'], sym.Array) # Note: this is an array because we have a shape assert variables['b'].type.dtype is BasicType.DEFERRED # TODO: support type derivation for expressions 
assert variables['b'].type.shape == ('3',) filepath = tmp_path/(f'associates_expr_{frontend}.f90') function = jit_compile(routine, filepath=filepath, objname=routine.name) a = np.array([1, 2, 3], dtype='i') b = np.zeros(3, dtype='i') function(a, b) assert np.all(b == [7, 10, 13]) @pytest.mark.parametrize('frontend', available_frontends()) def test_enum(tmp_path, frontend): """Verify that enums are represented correctly""" # F2008, Note 4.67 fcode = """ subroutine test_enum (out) implicit none ! Comment 1 ENUM, BIND(C) ENUMERATOR :: RED = 4, BLUE = 9 ! Comment 2 ENUMERATOR YELLOW END ENUM ! Comment 3 integer, intent(out) :: out out = RED + BLUE + YELLOW end subroutine test_enum """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) # Check Enum exists enums = FindNodes(ir.Enumeration).visit(routine.spec) assert len(enums) == 1 # Check symbols are available assert enums[0].symbols == ('red', 'blue', 'yellow') assert all(name in routine.symbols for name in ('red', 'blue', 'yellow')) assert all(s.scope is routine for s in enums[0].symbols) # Check assigned values assert routine.symbol_map['red'].type.initial == '4' assert routine.symbol_map['blue'].type.initial == '9' assert routine.symbol_map['yellow'].type.initial is None # Verify comments are preserved (don't care about the actual place) code = routine.to_fortran() for i in range(1, 4): assert f'! 
Comment {i}' in code # Check fgen produces valid code and runs filepath = tmp_path/(f'{routine.name}_{frontend}.f90') function = jit_compile(routine, filepath=filepath, objname=routine.name) out = function() assert out == 23 @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.usefixtures('reset_frontend_mode') def test_frontend_strict_mode(frontend, tmp_path): """ Verify that frontends fail on unsupported features if strict mode is enabled """ # Parameterized derived types currently not implemented fcode = """ module frontend_strict_mode implicit none TYPE matrix ( k, b ) INTEGER, KIND :: k = 4 INTEGER (8), LEN :: b REAL (k) :: element (b,b) END TYPE matrix end module frontend_strict_mode """ config['frontend-strict-mode'] = True with pytest.raises(NotImplementedError): Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) config['frontend-strict-mode'] = False module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) assert 'matrix' in module.symbol_attrs assert 'matrix' in module.typedef_map @pytest.mark.parametrize('frontend', available_frontends()) def test_frontend_pragma_vs_comment(frontend, tmp_path): """ Make sure pragmas and comments are identified correctly """ fcode = """ module frontend_pragma_vs_comment implicit none !$some pragma integer :: var1 !!$some comment integer :: var2 !some comment integer :: var3 !$some pragma integer :: var4 ! 
!$some comment integer :: var5 end module frontend_pragma_vs_comment """.strip() module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) pragmas = FindNodes(ir.Pragma).visit(module.ir) comments = FindNodes(ir.Comment).visit(module.ir) assert len(pragmas) == 2 assert len(comments) == 3 assert all(pragma.keyword == 'some' for pragma in pragmas) assert all(pragma.content == 'pragma' for pragma in pragmas) assert all('some comment' in comment.text for comment in comments) @pytest.mark.parametrize('frontend', available_frontends()) def test_frontend_main_program(frontend): """ Loki can't handle PROGRAM blocks and the frontends should throw an exception """ fcode = """ program hello print *, "Hello World!" end program """.strip() with config_override({'frontend-strict-mode': True}): with pytest.raises(NotImplementedError): Sourcefile.from_source(fcode, frontend=frontend) source = Sourcefile.from_source(fcode, frontend=frontend) assert source.ir.body == () @pytest.mark.parametrize('frontend', available_frontends()) def test_frontend_source_lineno(frontend): """ ... 
""" fcode = """ subroutine driver call kernel() call kernel() call kernel() end subroutine driver """ source = Sourcefile.from_source(fcode, frontend=frontend) routine = source['driver'] calls = FindNodes(ir.CallStatement).visit(routine.body) assert calls[0] != calls[1] assert calls[1] != calls[2] assert calls[0].source.lines[0] < calls[1].source.lines[0] < calls[2].source.lines[0] @pytest.mark.parametrize( 'frontend', available_frontends(include_regex=True, xfail=[(OMNI, 'OMNI may segfault on empty files')]) ) @pytest.mark.parametrize('fcode', ['', '\n', '\n\n\n\n']) def test_frontend_empty_file(frontend, fcode): """Ensure that all frontends can handle empty source files correctly (#186)""" source = Sourcefile.from_source(fcode, frontend=frontend) assert isinstance(source.ir, ir.Section) assert not source.to_fortran().strip() @pytest.mark.parametrize('frontend', available_frontends()) def test_pragma_line_continuation(frontend): """ Test that multi-line pragmas are parsed and dealt with correctly. 
""" fcode = """ SUBROUTINE TOTO(A,B) IMPLICIT NONE REAL, INTENT(IN) :: A REAL, INTENT(INOUT) :: B !$ACC PARALLEL LOOP GANG & !$ACC& PRESENT(ZRDG_LCVQ,ZFLU_QSATS,ZRDG_CVGQ) & !$ACC& PRIVATE (JBLK) & !$ACC& VECTOR_LENGTH (YDCPG_OPTS%KLON) !$ACC SEQUENTIAL END SUBROUTINE TOTO """ routine = Subroutine.from_source(fcode, frontend=frontend) pragmas = FindNodes(ir.Pragma).visit(routine.body) assert len(pragmas) == 2 assert pragmas[0].keyword == 'ACC' assert 'PARALLEL' in pragmas[0].content assert 'PRESENT' in pragmas[0].content assert 'PRIVATE' in pragmas[0].content assert 'VECTOR_LENGTH' in pragmas[0].content assert pragmas[1].content == 'SEQUENTIAL' # Check that source object was generated right assert pragmas[0].source assert pragmas[0].source.lines == (8, 8) if frontend == OMNI else (8, 11) assert pragmas[1].source assert pragmas[1].source.lines == (12, 12) @pytest.mark.parametrize('frontend', available_frontends()) def test_comment_block_clustering(frontend): """ Test that multiple :any:`Comment` nodes into a :any:`CommentBlock`. """ fcode = """ subroutine test_comment_block(a, b) ! What is this? ! Ohhh, ... a docstring? real, intent(inout) :: a, b a = a + 1.0 ! Never gonna b = b + 2 ! give you ! up... a = a + b ! Shut up, ... ! Rick! end subroutine test_comment_block """ routine = Subroutine.from_source(fcode, frontend=frontend) comments = FindNodes(ir.Comment).visit(routine.spec) assert len(comments) == 0 blocks = FindNodes(ir.CommentBlock).visit(routine.spec) assert len(blocks) == 0 assert isinstance(routine.docstring[0], ir.CommentBlock) assert len(routine.docstring[0].comments) == 2 assert routine.docstring[0].comments[0].text == '! What is this?' assert routine.docstring[0].comments[1].text == '! Ohhh, ... a docstring?' comments = FindNodes(ir.Comment).visit(routine.body) assert len(comments) == 2 if frontend == FP else 1 assert comments[-1].text == '! 
Never gonna' blocks = FindNodes(ir.CommentBlock).visit(routine.body) assert len(blocks) == 2 assert len(blocks[0].comments) == 3 if frontend == FP else 2 assert blocks[0].comments[0].text == '! give you' assert blocks[0].comments[1].text == '! up...' assert len(blocks[1].comments) == 2 assert blocks[1].comments[0].text == '! Shut up, ...' assert blocks[1].comments[1].text == '! Rick!' @pytest.mark.parametrize('frontend', available_frontends( xfail=[(OMNI, 'OMNI strips comments during parse')] )) def test_inline_comments(frontend): """ Test that multiple :any:`Comment` nodes into a :any:`CommentBlock`. """ fcode = """ subroutine test_inline_comments(a, b) real, intent(inout) :: a, b ! We don't need no education real, external :: alien_func ! We don't need no thought control integer :: i a = a + 1.0 ! Who said that? b = b + 2 ! All in all it's just another do i=1, 10 b = b + 2 ! Brick in the ... enddo a = a + alien_func() ! wall ! end subroutine test_inline_comments """ routine = Subroutine.from_source(fcode, frontend=frontend) decls = FindNodes(ir.VariableDeclaration).visit(routine.spec) assert len(decls) == 2 assert decls[0].comment.text == "! We don't need no education" assert decls[1].comment is None proc_decls = FindNodes(ir.ProcedureDeclaration).visit(routine.spec) assert len(proc_decls) == 1 assert proc_decls[0].comment.text == "! We don't need no thought control" assigns = FindNodes(ir.Assignment).visit(routine.body) assert len(assigns) == 4 assert assigns[0].comment is None assert assigns[1].comment.text == "! All in all it's just another" assert assigns[2].comment.text == '! Brick in the ...' assert assigns[3].comment.text == '! wall !' comments = FindNodes(ir.Comment).visit(routine.body) assert len(comments) == 4 assert comments[1].text == '! Who said that?' 
assert comments[0].text == comments[2].text == comments[3].text == '' @pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI does not like Loki pragmas, yet!')])) def test_frontend_routine_variables_dimension_pragmas(frontend): """ Test that `!$loki dimension` pragmas can be used to override the conceptual `.shape` of local and argument variables. """ fcode = """ subroutine routine_variables_dimensions(x, y, v0, v1, v2, v3, v4) integer, parameter :: jprb = selected_real_kind(13,300) integer, intent(in) :: x, y !$loki dimension(10) real(kind=jprb), intent(inout) :: v0(:) !$loki dimension(x) real(kind=jprb), intent(inout) :: v1(:) !$loki dimension(x,y,:) real(kind=jprb), dimension(:,:,:), intent(inout) :: v2, v3 !$loki dimension(x,y) real(kind=jprb), pointer, intent(inout) :: v4(:,:) !$loki dimension(x+y,2*x) real(kind=jprb), allocatable :: v5(:,:) !$loki dimension(x/2, x**2, (x+y)/x) real(kind=jprb), dimension(:, :, :), pointer :: v6 end subroutine routine_variables_dimensions """ def to_str(expr): return str(expr).lower().replace(' ', '') routine = Subroutine.from_source(fcode, frontend=frontend) assert routine.variable_map['v0'].shape[0] == 10 assert isinstance(routine.variable_map['v0'].shape[0], sym.IntLiteral) assert isinstance(routine.variable_map['v1'].shape[0], sym.Scalar) assert routine.variable_map['v2'].shape[0] == 'x' assert routine.variable_map['v2'].shape[1] == 'y' assert routine.variable_map['v2'].shape[2] == ':' assert isinstance(routine.variable_map['v2'].shape[0], sym.Scalar) assert isinstance(routine.variable_map['v2'].shape[1], sym.Scalar) assert isinstance(routine.variable_map['v2'].shape[2], sym.RangeIndex) assert routine.variable_map['v3'].shape[0] == 'x' assert routine.variable_map['v3'].shape[1] == 'y' assert routine.variable_map['v3'].shape[2] == ':' assert isinstance(routine.variable_map['v3'].shape[0], sym.Scalar) assert isinstance(routine.variable_map['v3'].shape[1], sym.Scalar) assert 
isinstance(routine.variable_map['v3'].shape[2], sym.RangeIndex) assert routine.variable_map['v4'].shape[0] == 'x' assert routine.variable_map['v4'].shape[1] == 'y' assert isinstance(routine.variable_map['v4'].shape[0], sym.Scalar) assert isinstance(routine.variable_map['v4'].shape[1], sym.Scalar) assert to_str(routine.variable_map['v5'].shape[0]) == 'x+y' assert to_str(routine.variable_map['v5'].shape[1]) == '2*x' assert isinstance(routine.variable_map['v5'].shape[0], sym.Sum) assert isinstance(routine.variable_map['v5'].shape[1], sym.Product) assert to_str(routine.variable_map['v6'].shape[0]) == 'x/2' assert to_str(routine.variable_map['v6'].shape[1]) == 'x**2' assert to_str(routine.variable_map['v6'].shape[2]) == '(x+y)/x' assert isinstance(routine.variable_map['v6'].shape[0], sym.Quotient) assert isinstance(routine.variable_map['v6'].shape[1], sym.Power) assert isinstance(routine.variable_map['v6'].shape[2], sym.Quotient) @pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI does not like Loki pragmas, yet!')])) def test_frontend_module_variables_dimension_pragmas(frontend, tmp_path): """ Test that `!$loki dimension` pragmas can be used to override the conceptual `.shape` of module variables. 
""" code_mod = """ module mod_variable_dimensions integer, parameter :: jprb = selected_real_kind(13,300) integer :: x, y !$loki dimension(10) real(kind=jprb), intent(inout) :: v0(:) !$loki dimension(x) real(kind=jprb), intent(inout) :: v1(:) !$loki dimension(x,y,:) real(kind=jprb), dimension(:,:,:), intent(inout) :: v2, v3 !$loki dimension(x,y) real(kind=jprb), pointer, intent(inout) :: v4(:,:) !$loki dimension(x+y,2*x) real(kind=jprb), allocatable :: v5(:,:) !$loki dimension(x/2, x**2, (x+y)/x) real(kind=jprb), dimension(:, :, :), pointer :: v6 end module mod_variable_dimensions """ def to_str(expr): return str(expr).lower().replace(' ', '') mod = Module.from_source(code_mod, frontend=frontend, xmods=[tmp_path]) variable_map = mod.variable_map assert variable_map['v0'].shape[0] == 10 assert isinstance(variable_map['v0'].shape[0], sym.IntLiteral) assert isinstance(variable_map['v1'].shape[0], sym.Scalar) assert variable_map['v2'].shape[0] == 'x' assert variable_map['v2'].shape[1] == 'y' assert variable_map['v2'].shape[2] == ':' assert isinstance(variable_map['v2'].shape[0], sym.Scalar) assert isinstance(variable_map['v2'].shape[1], sym.Scalar) assert isinstance(variable_map['v2'].shape[2], sym.RangeIndex) assert variable_map['v3'].shape[0] == 'x' assert variable_map['v3'].shape[1] == 'y' assert variable_map['v3'].shape[2] == ':' assert isinstance(variable_map['v3'].shape[0], sym.Scalar) assert isinstance(variable_map['v3'].shape[1], sym.Scalar) assert isinstance(variable_map['v3'].shape[2], sym.RangeIndex) assert variable_map['v4'].shape[0] == 'x' assert variable_map['v4'].shape[1] == 'y' assert isinstance(variable_map['v4'].shape[0], sym.Scalar) assert isinstance(variable_map['v4'].shape[1], sym.Scalar) assert to_str(variable_map['v5'].shape[0]) == 'x+y' assert to_str(variable_map['v5'].shape[1]) == '2*x' assert isinstance(variable_map['v5'].shape[0], sym.Sum) assert isinstance(variable_map['v5'].shape[1], sym.Product) assert to_str(variable_map['v6'].shape[0]) 
== 'x/2' assert to_str(variable_map['v6'].shape[1]) == 'x**2' assert to_str(variable_map['v6'].shape[2]) == '(x+y)/x' assert isinstance(variable_map['v6'].shape[0], sym.Quotient) assert isinstance(variable_map['v6'].shape[1], sym.Power) assert isinstance(variable_map['v6'].shape[2], sym.Quotient) @pytest.mark.parametrize('frontend', available_frontends()) def test_import_of_private_symbols(tmp_path, frontend): """ Verify that only public symbols are imported from other modules. """ code_mod_private = """ module mod_private private integer :: var end module mod_private """ code_mod_public = """ module mod_public public integer:: var end module mod_public """ code_mod_main = """ module mod_main use mod_public use mod_private contains subroutine test_routine() integer :: result result = var end subroutine test_routine end module mod_main """ mod_private = Module.from_source(code_mod_private, frontend=frontend, xmods=[tmp_path]) mod_public = Module.from_source(code_mod_public, frontend=frontend, xmods=[tmp_path]) mod_main = Module.from_source( code_mod_main, frontend=frontend, definitions=[mod_private, mod_public], xmods=[tmp_path] ) var = mod_main.subroutines[0].body.body[0].rhs # Check if this is really our symbol assert var.name == "var" assert var.scope is mod_main # Check if the symbol is imported assert var.type.imported is True # Check if the symbol comes from the mod_public module assert var.type.module is mod_public @pytest.mark.parametrize('frontend', available_frontends()) def test_access_spec(tmp_path, frontend): """ Check that access-spec statements are dealt with correctly. 
""" code_mod_private_var_public = """ module mod_private_var_public private integer :: var public :: var end module mod_private_var_public """ code_mod_public_var_private = """ module mod_public_var_private public integer :: var private :: var end module mod_public_var_private """ code_mod_main = """ module mod_main use mod_private_var_public use mod_public_var_private contains subroutine test_routine() integer :: result result = var end subroutine test_routine end module mod_main """ mod_private_var_public = Module.from_source(code_mod_private_var_public, frontend=frontend, xmods=[tmp_path]) mod_public_var_private = Module.from_source(code_mod_public_var_private, frontend=frontend, xmods=[tmp_path]) mod_main = Module.from_source( code_mod_main, frontend=frontend, definitions=[mod_private_var_public, mod_public_var_private], xmods=[tmp_path] ) var = mod_main.subroutines[0].body.body[0].rhs # Check if this is really our symbol assert var.name == "var" assert var.scope is mod_main # Check if the symbol is imported assert var.type.imported is True # Check if the symbol comes from the mod_private_var_public module assert var.type.module is mod_private_var_public @pytest.mark.parametrize('frontend', available_frontends( xfail=[(OMNI, 'OMNI does not like intrinsic shading for member functions!')] )) def test_intrinsic_shadowing(tmp_path, frontend): """ Test that locally defined functions that shadow intrinsics are handled. 
""" fcode_algebra = """ module algebra_mod implicit none contains function dot_product(a, b) result(c) real(kind=8), intent(inout) :: a(:), b(:) real(kind=8) :: c end function dot_product function min(x, y) real(kind=8), intent(in) :: x, y real(kind=8) :: min min = y if (x < y) min = x end function min end module algebra_mod """ fcode = """ module test_intrinsics_mod use algebra_mod, only: dot_product implicit none contains subroutine test_intrinsics(a, b, c, d) use algebra_mod, only: min implicit none real(kind=8), intent(inout) :: a(:), b(:) real(kind=8) :: c, d, e c = dot_product(a, b) d = max(c, a(1)) e = min(c, a(1)) contains function max(x, y) real(kind=8), intent(in) :: x, y real(kind=8) :: max max = y if (x > y) max = x end function max end subroutine test_intrinsics end module test_intrinsics_mod """ algebra = Module.from_source(fcode_algebra, frontend=frontend, xmods=[tmp_path]) module = Module.from_source( fcode, definitions=algebra, frontend=frontend, xmods=[tmp_path] ) routine = module['test_intrinsics'] assigns = FindNodes(ir.Assignment).visit(routine.body) assert len(assigns) == 3 assert isinstance(assigns[0].rhs.function, sym.ProcedureSymbol) assert not assigns[0].rhs.function.type.is_intrinsic assert assigns[0].rhs.function.type.dtype.procedure == algebra['dot_product'] assert isinstance(assigns[1].rhs.function, sym.ProcedureSymbol) assert not assigns[1].rhs.function.type.is_intrinsic assert assigns[1].rhs.function.type.dtype.procedure == routine.members[0] assert isinstance(assigns[2].rhs.function, sym.ProcedureSymbol) assert not assigns[2].rhs.function.type.is_intrinsic assert assigns[2].rhs.function.type.dtype.procedure == algebra['min'] @pytest.mark.parametrize('frontend', available_frontends()) def test_function_symbol_scoping(frontend): """ Check that the return symbol of a function has the right scope """ fcode = """ real(kind=8) function double_real(i) implicit none integer, intent(in) :: i double_real = dble(i*2) end function double_real 
""" routine = Function.from_source(fcode, frontend=frontend) rtyp = routine.symbol_attrs['double_real'] assert rtyp.dtype == BasicType.REAL assert rtyp.kind == 8 assigns = FindNodes(ir.Assignment).visit(routine.body) assert len(assigns) == 1 assert assigns[0].lhs == 'double_real' assert isinstance(assigns[0].lhs, sym.Scalar) assert assigns[0].lhs.type.dtype == BasicType.REAL assert assigns[0].lhs.type.kind == 8 assert assigns[0].lhs.scope == routine @pytest.mark.parametrize('frontend', available_frontends()) def test_frontend_derived_type_imports(tmp_path, frontend): """ Checks that provided module and type info is attached during parse """ fcode_module = """ module my_type_mod type my_type real(kind=8) :: a, b(:) end type my_type end module my_type_mod """ fcode = """ subroutine test_derived_type_parse use my_type_mod, only: my_type implicit none type(my_type) :: obj obj%a = 42.0 obj%b = 66.6 end subroutine test_derived_type_parse """ module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path]) routine = Subroutine.from_source( fcode, definitions=module, frontend=frontend, xmods=[tmp_path] ) assert len(module.typedefs) == 1 assert module.typedefs[0].name == 'my_type' # Ensure that the imported type is recognised as such assert len(routine.imports) == 1 assert routine.imports[0].module == 'my_type_mod' assert len(routine.imports[0].symbols) == 1 assert routine.imports[0].symbols[0] == 'my_type' assert isinstance(routine.imports[0].symbols[0], sym.DerivedTypeSymbol) # Ensure that the declared variable and its components are recognised assigns = FindNodes(ir.Assignment).visit(routine.body) assert len(assigns) == 2 assert isinstance(assigns[0].lhs, sym.Scalar) assert assigns[0].lhs.type.dtype == BasicType.REAL assert isinstance(assigns[1].lhs, sym.Array) assert assigns[1].lhs.type.dtype == BasicType.REAL assert assigns[1].lhs.type.shape == (':',) @pytest.mark.skipif(not HAVE_FP, reason="Assumed size declarations only supported for FP") def 
test_assumed_size_declarations(): """ Test if assumed size declarations are correctly parsed. """ fcode = """ subroutine kernel(a, b, c) implicit none real, intent(in) :: a(*) real, intent(in) :: b(8,*) real, intent(in) :: c(8,0:*) end subroutine kernel """ kernel = Subroutine.from_source(fcode, frontend=FP) variable_map = kernel.variable_map a = variable_map['a'] b = variable_map['b'] c = variable_map['c'] assert len(a.shape) == 1 assert len(b.shape) == 2 assert b.shape[0] == 8 assert len(c.shape) == 2 assert c.shape[0] == 8 assert c.shape[1].lower == 0 assert all('*' in str(shape) for shape in [a.shape, b.shape, c.shape]) @pytest.mark.parametrize('frontend', available_frontends()) def test_empty_print_statement(frontend): """ Test if an empty print statement (PRINT *) is parsed correctly. """ fcode = """ SUBROUTINE test_routine() IMPLICIT NONE print * ! Using single quotes to simplify the test comparison (see below) print *, 'test_text' END SUBROUTINE test_routine """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) print_stmts = [ intr for intr in FindNodes(ir.Intrinsic).visit(routine.ir) if 'print' in intr.text.lower() ] assert print_stmts[0].text.lower() == "print *" # NOTE: OMNI always uses single quotes ('') to represent string data in PRINT statements # while fparser will mimic the quotes used in the parsed source code assert print_stmts[1].text.lower() == "print *, 'test_text'" @pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI fails to read without full module')])) def test_select_type(frontend, tmp_path): fcode = """ module select_type_mod use imported_type_mod, only: imported_type implicit none type, abstract :: base end type base type, extends(base) :: derived1 real :: val end type derived1 type, extends(base) :: derived2 integer :: val end type derived2 contains subroutine select_type_routine(arg, arg2) class(base), intent(inout) :: arg class(imported_type), intent(inout) :: arg2 select type( arg ) class 
is(derived1) arg%val = 1.0 class is(derived2) arg%val = 1 class default print *, 'error' end select ! Some comment before the second select select type( arg ) type is(base) write(*,*) 'default' end select select type( arg2 ) ! inline comment type is(imported_type) print *, 'imported type' end select end subroutine select_type_routine end module select_type_mod """.strip() module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) tconds = FindNodes(ir.TypeConditional).visit(module['select_type_routine'].body) assert len(tconds) == 3 assert tconds[0].expr == 'arg' assert tconds[0].values == ( ('derived1', True), ('derived2', True) ) assert len(tconds[0].bodies) == 2 assert len(tconds[0].else_body) == 1 assert tconds[1].expr == 'arg' assert tconds[1].values == (('base', False),) assert not tconds[1].else_body assert tconds[2].expr == 'arg2' assert tconds[2].values == (('imported_type', False),) assert not tconds[2].else_body comments = FindNodes(ir.Comment).visit(module['select_type_routine'].body) assert len(comments) == 2 assert 'Some comment' in comments[0].text assert 'inline comment' in comments[1].text loki-ecmwf-0.3.6/loki/frontend/tests/test_fparser_source.py0000664000175000017500000002103415167130205024313 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. """ Verify correct frontend behaviour with regards to source parsing and sanitisation. 
""" import pytest from loki import Module, Subroutine, Sourcefile, config_override from loki.frontend import FP from loki.ir import nodes as ir, FindNodes @pytest.mark.parametrize('from_file', (True, False)) @pytest.mark.parametrize('preprocess', (True, False)) def test_source_sanitize_fp_source(tmp_path, from_file, preprocess): """ Test that source sanitizing works as expected and postprocessing rules are correctly applied """ fcode = """ subroutine some_routine(input_path) implicit none character(len=255), intent(in) :: input_path integer :: ios, fu write(*,*) "we print CPP value ", MY_VAR ! In the following line the PP definition should be replace by '0' ! or the actual line number write(*,*) "We are in line ",__LINE__ open (action='read', file=TRIM(input_path), iostat=ios, newunit=fu) end subroutine some_routine """.strip() if from_file: filepath = tmp_path/'some_routine.F90' filepath.write_text(fcode) obj = Sourcefile.from_file(filepath, frontend=FP, preprocess=preprocess, defines=('MY_VAR=5',)) else: obj = Sourcefile.from_source(fcode, frontend=FP, preprocess=preprocess, defines=('MY_VAR=5',)) if preprocess: # CPP takes care of that assert '"We are in line ", 8' in obj.to_fortran() assert '"we print CPP value ", 5' in obj.to_fortran() else: # source sanitisation takes care of that assert '"We are in line ", 0' in obj.to_fortran() assert '"we print CPP value ", MY_VAR' in obj.to_fortran() assert 'newunit=fu' in obj.to_fortran() @pytest.mark.parametrize('preprocess', (True, False)) def test_source_sanitize_fp_subroutine(preprocess): """ Test that source sanitizing works as expected and postprocessing rules are correctly applied """ fcode = """ subroutine some_routine(input_path) implicit none character(len=255), intent(in) :: input_path integer :: ios, fu write(*,*) "we print CPP value ", MY_VAR ! In the following line the PP definition should be replace by '0' ! 
or the actual line number write(*,*) "We are in line ",__LINE__ open (action='read', file=TRIM(input_path), iostat=ios, newunit=fu) end subroutine some_routine """.strip() obj = Subroutine.from_source(fcode, frontend=FP, preprocess=preprocess, defines=('MY_VAR=5',)) if preprocess: # CPP takes care of that assert '"We are in line ", 8' in obj.to_fortran() assert '"we print CPP value ", 5' in obj.to_fortran() else: # source sanitisation takes care of that assert '"We are in line ", 0' in obj.to_fortran() assert '"we print CPP value ", MY_VAR' in obj.to_fortran() assert 'newunit=fu' in obj.to_fortran() @pytest.mark.parametrize('preprocess', (True, False)) def test_source_sanitize_fp_module(preprocess): """ Test that source sanitizing works as expected and postprocessing rules are correctly applied """ fcode = """ module some_mod implicit none integer line = __LINE__ + MY_VAR contains subroutine some_routine(input_path) implicit none character(len=255), intent(in) :: input_path integer :: ios, fu write(*,*) "we print CPP value ", MY_VAR ! In the following line the PP definition should be replace by '0' ! 
or the actual line number write(*,*) "We are in line ",__LINE__ open (action='read', file=TRIM(input_path), iostat=ios, newunit=fu) end subroutine some_routine end module some_mod """.strip() obj = Module.from_source(fcode, frontend=FP, preprocess=preprocess, defines=('MY_VAR=5',)) if preprocess: # CPP takes care of that assert 'line = 3 + 5' in obj.to_fortran() assert '"We are in line ", 12' in obj.to_fortran() assert '"we print CPP value ", 5' in obj.to_fortran() else: # source sanitisation takes care of that assert 'line = 0 + MY_VAR' in obj.to_fortran() assert '"We are in line ", 0' in obj.to_fortran() assert '"we print CPP value ", MY_VAR' in obj.to_fortran() assert 'newunit=fu' in obj.to_fortran() # TODO: Add tests for source sanitizer with other frontends @pytest.mark.parametrize('store_source', (True, False)) def test_fparser_source_parsing(store_source): fcode = """ module test_source_mod use my_kind_mod, only: akind implicit none type my_type real(kind=akind) :: scalar, vector(3) integer :: asize end type my_type contains subroutine my_test_routine(n, rick, dave) integer, intent(in) :: n real(kind=akind), intent(inout) :: rick, dave(n) integer :: i do i=1, n if (dave(i) > 0.5) then dave(i) = dave(i) + rick end if end do forall(i=1:n) dave(i) = dave(i) + 2.0 end forall end subroutine my_test_routine end module test_source_mod """ with config_override({'frontend-store-source': store_source}): source = Sourcefile.from_source(fcode, frontend=FP) module = source['test_source_mod'] routine = module['my_test_routine'] if store_source: assert module.spec.source and module.spec.source.lines == (3, 10) assert module.contains.source and module.contains.source.lines == (11, 27) assert routine.spec.source and routine.spec.source.lines == (14, 16) assert routine.body.source and routine.body.source.lines == (17, 26) else: assert not module.spec.source assert not module.contains.source assert not routine.spec.source assert not routine.body.source decls = 
FindNodes(ir.VariableDeclaration).visit(routine.spec) loops = FindNodes(ir.Loop).visit(routine.body) conds = FindNodes(ir.Conditional).visit(routine.body) assigns = FindNodes(ir.Assignment).visit(routine.body) foralls = FindNodes(ir.Forall).visit(routine.body) assert len(decls) == 3 and len(loops) == 1 and len(conds) == 1 assert len(assigns) == 2 and len(foralls) == 1 if store_source: assert decls[0].source and decls[0].source.lines == (14, 14) assert decls[1].source and decls[1].source.lines == (15, 15) assert decls[2].source and decls[2].source.lines == (16, 16) assert loops[0].source and loops[0].source.lines == (18, 22) assert conds[0].source and conds[0].source.lines == (19, 21) assert assigns[0].source and assigns[0].source.lines == (20, 20) assert assigns[1].source and assigns[1].source.lines == (25, 25) assert foralls[0].source and foralls[0].source.lines == (24, 26) else: assert not decls[0].source and not decls[1].source and not decls[2].source assert not loops[0].source assert not conds[0].source assert not assigns[0].source imprts = FindNodes(ir.Import).visit(module.spec) intrs = FindNodes(ir.Intrinsic).visit(module.spec) tdefs = FindNodes(ir.TypeDef).visit(module.spec) assert len(imprts) == 1 and len(tdefs) == 1 and len(intrs) == 1 tdecls = FindNodes(ir.VariableDeclaration).visit(tdefs[0].body) assert len(tdecls) == 2 if store_source: assert imprts[0].source and imprts[0].source.lines == (3, 3) assert intrs[0].source and intrs[0].source.lines == (4, 4) assert tdefs[0].source and tdefs[0].source.lines == (6, 9) assert tdecls[0].source and tdecls[0].source.lines == (7, 7) assert tdecls[1].source and tdecls[1].source.lines == (8, 8) else: assert not imprts[0].source assert not intrs[0].source assert not tdefs[0].source assert not tdecls[0].source assert not tdecls[1].source def test_fparser_sanitize_fypp_line_annotations(): """ Test that fypp line number annotations are sanitized correctly. 
""" fcode = """ module some_templated_mod # 1 "/path-to-hypp-macro/macro.hypp" 1 # 2 "/path-to-hypp-macro/macro.hypp" # 3 "/path-to-hypp-macro/macro.hypp" # 5 "/path-to-fypp-template/template.fypp" 2 integer :: a0 integer :: a1 integer :: a2 integer :: a3 integer :: a4 end module some_templated_mod """ module = Module.from_source(fcode, frontend=FP) decls = FindNodes(ir.VariableDeclaration).visit(module.spec) assert len(decls) == 5 loki-ecmwf-0.3.6/loki/frontend/tests/test_omni.py0000664000175000017500000000557015167130205022242 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. """ Specific test battery for the OMNI parser frontend. """ import pytest from loki import Module, Subroutine from loki.frontend import OMNI, HAVE_OMNI from loki.ir import nodes as ir, FindNodes, FindVariables @pytest.mark.skipif(not HAVE_OMNI, reason='Test tequires OMNI frontend.') def test_derived_type_definitions(tmp_path): """ Test correct parsing of derived type declarations. 
""" fcode = """ module omni_derived_type_mod type explicit real(kind=8) :: scalar, vector(3), matrix(3, 3) end type explicit type deferred real(kind=8), allocatable :: scalar, vector(:), matrix(:, :) end type deferred type ranged real(kind=8) :: scalar, vector(1:3), matrix(0:3, 0:3) end type ranged end module omni_derived_type_mod """ # Parse the source and validate the IR module = Module.from_source(fcode, frontend=OMNI, xmods=[tmp_path]) assert len(module.typedefs) == 3 explicit_symbols = FindVariables(unique=False).visit(module['explicit'].body) assert explicit_symbols == ('scalar', 'vector(3)', 'matrix(3, 3)') deferred_symbols = FindVariables(unique=False).visit(module['deferred'].body) assert deferred_symbols == ('scalar', 'vector(:)', 'matrix(:, :)') ranged_symbols = FindVariables(unique=False).visit(module['ranged'].body) assert ranged_symbols == ('scalar', 'vector(3)', 'matrix(0:3, 0:3)') @pytest.mark.skipif(not HAVE_OMNI, reason='Test tequires OMNI frontend.') def test_array_dimensions(tmp_path): """ Test correct parsing of derived type declarations. """ fcode = """ subroutine omni_array_indexing(n, a, b) integer, intent(in) :: n real(kind=8), intent(inout) :: a(3), b(n) real(kind=8) :: c(n, n) real(kind=8) :: d(1:n, 0:n) a(:) = 11. b(1:n) = 42. c(2:n, 0:n) = 66. d(:, 0:n) = 68. 
end subroutine omni_array_indexing """ # Parse the source and validate the IR routine = Subroutine.from_source(fcode, frontend=OMNI, xmods=[tmp_path]) # OMNI separate declarations per variable decls = FindNodes(ir.VariableDeclaration).visit(routine.spec) assert len(decls) == 5 assert decls[0].symbols == ('n',) assert decls[1].symbols == ('a(3)',) assert decls[2].symbols == ('b(n)',) assert decls[3].symbols == ('c(n, n)',) assert decls[4].symbols == ('d(n, 0:n)',) assigns = FindNodes(ir.Assignment).visit(routine.body) assert len(assigns) == 4 assert assigns[0].lhs == 'a(:)' assert assigns[1].lhs == 'b(1:n)' assert assigns[2].lhs == 'c(2:n, 0:n)' assert assigns[3].lhs == 'd(:, 0:n)' loki-ecmwf-0.3.6/loki/frontend/regex.py0000664000175000017500000012301615167130205020205 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. """ The implementation of a regex parser frontend This is intended to allow for fast, partial extraction of IR objects from Fortran source files without the need to generate a complete parse tree. 
""" from abc import abstractmethod from enum import Flag, auto import re from codetiming import Timer from loki import ir from loki.config import config from loki.expression import symbols as sym from loki.frontend.source import Source, FortranReader from loki.frontend.util import combine_multiline_pragmas from loki.logging import debug from loki.tools import as_tuple, timeout from loki.types import BasicType, ProcedureType, DerivedType, SymbolAttributes __all__ = ['RegexParserClass', 'parse_regex_source', 'HAVE_REGEX'] HAVE_REGEX = True """Indicate that the regex frontend is available.""" class RegexParserClass(Flag): """ Classes to configure active patterns in the :any:`REGEX` frontend Every :class:`Pattern` in the frontend is categorized as one of these classes. By specifying some (or all of them) as ``parser_classes`` to :any:`parse_regex_source`, pattern matching can be switched on and off for some pattern classes, and thus the overall parse time reduced. """ EmptyClass = 0 ProgramUnitClass = auto() InterfaceClass = auto() ImportClass = auto() TypeDefClass = auto() DeclarationClass = auto() CallClass = auto() PragmaClass = auto() AllClasses = ProgramUnitClass | InterfaceClass | ImportClass | TypeDefClass | \ DeclarationClass | CallClass | PragmaClass # pylint: disable=unsupported-binary-operation class Pattern: """ Base class for patterns used in the :any:`REGEX` frontend Parameters ---------- pattern : str The regex pattern used for matching flags : re.RegexFlag Regular expression flag(s) to use when compiling and matching the pattern """ def __init__(self, pattern, flags=None): self.pattern = re.compile(pattern, flags) @abstractmethod def match(self, reader, parser_classes, scope): """ Match the stored pattern against the source string in the reader object This method must be implemented by every child class to provide the matching logic. 
It is not necessary to check the selected :data:`parser_classes` here, as the relevant :meth:`match` method will only be called for :class:`Pattern` classes that are active. :data:`parser_classes` is only passed here to forward it to use it when matching recursively. If this match method matches against a single line, it should return a :any:`Node` if matched successfully, or otherwise `None`. If this match method matches a block, e.g. a :any:`Subroutine`, then this should return a 3-tupel ``(pre, node, new_reader)``, with each entry: - ``pre`` : A :any:`FortranReader` object representing any unmatched source code fragments prior to the matched object. Can be `None` if there are none or if there was no match. - ``node``: The object created as the result of a successful match, or `None`. - ``new_reader``: A :any:`FortranReader` object representing any unmatched source code fragments past the matched object. Can be `None` if there are none and should be the original :data:`reader` object if there was no match. Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ @classmethod def match_block_candidates(cls, reader, candidates, parser_classes=None, scope=None): """ Attempt to match block candidates It will automatically skip :data:`candidates` that are inactive due to the chosen :data:`parser_classes`. 
Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source candidates : list of str The list of candidate classes to match parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ if parser_classes is None: parser_classes = RegexParserClass.AllClasses ir_ = [] # Extract source bits that would be swept under the rag when sanitizing head = reader.source_from_head() tail = reader.source_from_tail() for idx, candidate_name in enumerate(candidates): candidate = PATTERN_REGISTRY[candidate_name] if not candidate.parser_class & parser_classes: continue while reader: pre, match, reader = candidate.match(reader, parser_classes=parser_classes, scope=scope) if not match: assert pre is None break if pre: # See if any of the other candidates match before this match ir_ += cls.match_block_candidates( pre, candidates[idx+1:], parser_classes=parser_classes, scope=scope ) ir_ += [match] if reader: source = reader.to_source(include_padding=True) ir_ += [ir.RawSource(text=source.string, source=source)] if head is not None and (not ir_ or ir_[0].source.lines[0] > head.lines[1]): # Insert the header bit only if the recursion hasn't already taken care of it ir_ = [ir.RawSource(text=head.string, source=head)] + ir_ if tail is not None: ir_ += [ir.RawSource(text=tail.string, source=tail)] return ir_ @classmethod def match_statement_candidates(cls, reader, candidates, parser_classes=None, scope=None): """ Attempt to match single-line statement candidates It will automatically skip :data:`candidates` that are inactive due to the chosen :data:`parser_classes`. 
Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source candidates : list of str The list of candidate classes to match parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ if parser_classes is None: parser_classes = RegexParserClass.AllClasses # Extract source bits that would be swept under the rag when sanitizing head = reader.source_from_head() filtered_candidates = [PATTERN_REGISTRY[candidate_name] for candidate_name in candidates] filtered_candidates = [ candidate for candidate in filtered_candidates if candidate.parser_class & parser_classes ] ir_ = [] last_match = -1 if filtered_candidates: for idx, _ in enumerate(reader): for candidate in filtered_candidates: match = candidate.match(reader, parser_classes=parser_classes, scope=scope) if match: if last_match - idx > 1: span = (reader.sanitized_spans[last_match + 1], reader.sanitized_spans[idx]) source = reader.source_from_sanitized_span(span) ir_ += [ir.RawSource(source.string, source=source)] last_match = idx ir_ += [match] break if head is not None and ir_: ir_ = [ir.RawSource(text=head.string, source=head)] + ir_ tail_span = (reader.sanitized_spans[last_match + 1], None) source = reader.source_from_sanitized_span(tail_span, include_padding=True) if source: ir_ += [ir.RawSource(source.string, source=source)] return ir_ @classmethod def match_block_statement_candidates( cls, reader, block_candidates, statement_candidates, parser_classes=None, scope=None ): """ Attempt to match block candidates and subsequently attempt to match statement candidates on unmatched sections It will automatically skip :data:`candidates` that are inactive due to the chosen :data:`parser_classes`. 
This is essentially equivalent to :meth:`match_block_candidates` but applies :meth:`match_statement_candidates` to the unmatched tail source instead of returning it as a :any:`RawSource` object straight away. Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source block_candidates : list of str The list of block candidate classes to match statement_candidates : list of str The list of statement candidate classes to match parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ if parser_classes is None: parser_classes = RegexParserClass.AllClasses # Extract source bits that would be swept under the rag when sanitizing head = reader.source_from_head() ir_ = [] for idx, candidate_name in enumerate(block_candidates): candidate = PATTERN_REGISTRY[candidate_name] if not candidate.parser_class & parser_classes: continue while reader: pre, match, reader = candidate.match(reader, parser_classes=parser_classes, scope=scope) if not match: assert pre is None break if pre: # See if any of the other candidates match before this match ir_ += cls.match_block_statement_candidates( pre, block_candidates[idx+1:], statement_candidates, scope=scope ) ir_ += [match] if head is not None and ir_ and reader: # Insert the head source bits only if we have matched something, otherwise # the statement candidate matching will take care of this ir_ = [ir.RawSource(text=head.string, source=head)] + ir_ if reader: ir_ += cls.match_statement_candidates( reader, statement_candidates, parser_classes=parser_classes, scope=scope ) return ir_ _pattern_opening_parenthesis = re.compile(r'\(') _pattern_closing_parenthesis = re.compile(r'\)') _pattern_opening_bracket = re.compile(r'\[') _pattern_closing_bracket = re.compile(r'\]') _pattern_quoted_string = re.compile(r'(?:\'.*?\')|(?:".*?")') @classmethod def _remove_quoted_string_nested_parentheses(cls, string): """ 
Remove any quoted strings and parentheses with their content in the given string """ string = cls._pattern_quoted_string.sub('', string) p_open = [match.start() for match in cls._pattern_opening_parenthesis.finditer(string)] p_close = [match.start() for match in cls._pattern_closing_parenthesis.finditer(string)] b_open = [match.start() for match in cls._pattern_opening_bracket.finditer(string)] b_close = [match.start() for match in cls._pattern_closing_bracket.finditer(string)] if len(p_open) > len(p_close): # Note: fparser's reader has currently problems with opening # quotes in comments in combination with line continuation, thus # potentially failing to sanitize the string correctly. # In that case, we'll just discard everything after the first # opening parenthesis, well aware that we're potentially # loosing information... # See https://github.com/stfc/fparser/issues/264 return string[:p_open[0]] assert len(p_open) == len(p_close) assert len(b_open) == len(b_close) if not p_close and not b_close: return string def _match_spans(open_, close_): # We match pairs of parentheses starting at the end by pushing and popping from a stack. 
# Whenever the stack runs out, we have fully resolved a set of (nested) parenthesis and # record the corresponding span if not close_: return [] spans = [] stack = [close_.pop()] while open_: if not close_ or open_[-1] > close_[-1]: assert stack start = open_.pop() end = stack.pop() if not stack: spans.append((start, end)) else: stack.append(close_.pop()) assert not (stack or open_ or close_) return spans p_spans = _match_spans(p_open, p_close) b_spans = _match_spans(b_open, b_close) # Merge the span lists (and reverse the order into ascending in the process) spans = [] while p_spans and b_spans: if p_spans[-1][0] < b_spans[-1][0]: spans.append(p_spans.pop()) else: spans.append(b_spans.pop()) if p_spans: spans += p_spans[::-1] if b_spans: spans += b_spans[::-1] # We should now be left with no parentheses anymore and can build the new string # by using everything between these parenthesis "spans" starts, ends = zip(*spans) new_string = string[:min(starts)] for (_, start), (end, _) in zip(spans[:-1], spans[1:]): new_string += string[start+1:end] new_string += string[max(ends)+1:] return new_string @Timer(logger=debug, text=lambda s: f'[Loki::REGEX] Executed parse_regex_source in {s:.2f}s') def parse_regex_source(source, parser_classes=None, scope=None): """ Generate a reduced Loki IR from regex parsing of the given Fortran source The IR nodes that should be matched can be configured via :data:`parser_classes`. Any non-matched source code snippets are retained as :any:`RawSource` objects. 
Parameters ---------- source : str or :any:`Source` The raw source string parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope`, optional The enclosing parent scope """ if parser_classes is None: parser_classes = RegexParserClass.AllClasses candidates = ('ModulePattern', 'SubroutineFunctionPattern') if isinstance(source, Source): reader = FortranReader(source.string) else: reader = FortranReader(source) timeout_message = f'REGEX frontend timeout of {config["regex-frontend-timeout"]} s exceeded' with timeout(config['regex-frontend-timeout'], message=timeout_message): ir_ = Pattern.match_block_candidates(reader, candidates, parser_classes=parser_classes, scope=scope) lines = (1, source.count('\n') + 1) source = Source(lines, string=source) return ir.Section(body=as_tuple(ir_), source=source) class ModulePattern(Pattern): """ Pattern to match :any:`Module` objects """ parser_class = RegexParserClass.ProgramUnitClass def __init__(self): super().__init__( r'^module[ \t]+(?P\w+)\b.*?$' r'(?P.*?)' r'(?P^contains\n(?:' r'(?:[ \t\w()=]*?subroutine.*?^end[ \t]*subroutine\b(?:[ \t]*\w+)?\n)|' r'(?:[ \t\w()=]*?function.*?^end[ \t]*function\b(?:[ \t]*\w+)?\n)|' r'(?:^#\w+.*?\n)' r')*?)?' 
r'^end[ \t]*module\b(?:[ \t](?P=name))?', re.IGNORECASE | re.DOTALL | re.MULTILINE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`Module` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ from loki import Module # pylint: disable=import-outside-toplevel,cyclic-import match = self.pattern.search(reader.sanitized_string) if not match: return None, None, reader # Check if the Module node has been created before by looking it up in the scope module = None name = match['name'] if scope is not None and name in scope.symbol_attrs: module_type = scope.symbol_attrs[name] # Look-up only in current scope! if module_type and module_type.dtype.module != BasicType.DEFERRED: module = module_type.dtype.module if module is None: source = reader.source_from_sanitized_span(match.span()) module = Module(name=name, source=source, parent=scope) if match['spec'] and match['spec'].strip(): block_candidates = ('TypedefPattern', 'InterfacePattern') statement_candidates = ('ImportPattern', 'VariableDeclarationPattern') spec = self.match_block_statement_candidates( reader.reader_from_sanitized_span(match.span('spec'), include_padding=True), block_candidates, statement_candidates, parser_classes=parser_classes, scope=module ) else: spec = None if match['contains']: contains = [ir.Intrinsic(text='CONTAINS')] span = match.span('contains') span = (span[0] + 8, span[1]) # Skip the "contains" keyword as it has been added candidates = ['SubroutineFunctionPattern'] contains += self.match_block_candidates( reader.reader_from_sanitized_span(span, include_padding=True), candidates, parser_classes=parser_classes, scope=module ) else: contains = None module.__initialize__( # pylint: disable=unnecessary-dunder-call name=module.name, 
spec=spec, contains=contains, source=module.source, incomplete=True, parser_classes=parser_classes ) if match.span()[0] > 0: pre = reader.reader_from_sanitized_span((0, match.span()[0]), include_padding=True) else: pre = None return pre, module, reader.reader_from_sanitized_span((match.span()[1], None), include_padding=True) class SubroutineFunctionPattern(Pattern): """ Pattern to match :any:`Subroutine` objects """ parser_class = RegexParserClass.ProgramUnitClass def __init__(self): super().__init__( r'^(?P[ \t\w()=]*)?(?Psubroutine|function)[ \t]+(?P\w+)\b.*?$' r'(?P(?:.*?(?:^(?:abstract[ \t]+)?interface\b.*?^end[ \t]+interface)?)+)' r'(?P^contains\n(?:' r'(?:[ \t\w()=]*?subroutine.*?^end[ \t]*subroutine\b(?:[ \t]\w+)?\n)|' r'(?:[ \t\w()=]*?function.*?^end[ \t]*function\b(?:[ \t]\w+)?\n)|' r'(?:^#\w+.*?\n)' r')*?)?' r'^end[ \t]*(?P=keyword)\b(?:[ \t](?P=name))?', re.IGNORECASE | re.DOTALL | re.MULTILINE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`Subroutine` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ from loki import Subroutine # pylint: disable=import-outside-toplevel,cyclic-import from loki.function import Function # pylint: disable=import-outside-toplevel,cyclic-import match = self.pattern.search(reader.sanitized_string) if not match: return None, None, reader # Check if the Subroutine node has been created before by looking it up in the scope routine = None name = match['name'] if scope is not None and name in scope.symbol_attrs: proc_type = scope.symbol_attrs[name] # Look-up only in current scope! 
if proc_type and getattr(proc_type.dtype, 'procedure', BasicType.DEFERRED) != BasicType.DEFERRED: routine = proc_type.dtype.procedure if routine is None: is_function = match['keyword'].lower() == 'function' source = reader.source_from_sanitized_span(match.span()) if is_function: routine = Function(name=name, args=(), source=source, parent=scope) else: routine = Subroutine(name=name, args=(), source=source, parent=scope) if match['spec']: statement_candidates = ('ImportPattern', 'VariableDeclarationPattern', 'CallPattern', 'PragmaPattern') block_candidates = ('InterfacePattern',) spec = self.match_block_statement_candidates( reader.reader_from_sanitized_span(match.span('spec'), include_padding=True), block_candidates, statement_candidates, parser_classes=parser_classes, scope=routine ) spec = combine_multiline_pragmas(spec) else: spec = None if match['contains']: contains = [ir.Intrinsic(text='CONTAINS')] span = match.span('contains') span = (span[0] + 8, span[1]) # Skip the "contains" keyword as it has been added block_children = ['SubroutineFunctionPattern'] contains += self.match_block_candidates( reader.reader_from_sanitized_span(span), block_children, parser_classes=parser_classes, scope=routine ) else: contains = None if match['prefix'].strip(): prefix = match['prefix'].strip() else: prefix=None routine.__initialize__( # pylint: disable=unnecessary-dunder-call name=routine.name, args=routine._dummies, prefix=prefix, spec=spec, contains=contains, source=routine.source, incomplete=True, parser_classes=parser_classes ) if match.span()[0] > 0: pre = reader.reader_from_sanitized_span((0, match.span()[0]), include_padding=True) else: pre = None return pre, routine, reader.reader_from_sanitized_span((match.span()[1], None), include_padding=True) class InterfacePattern(Pattern): """ Pattern to match :any:`Interface` objects """ parser_class = RegexParserClass.InterfaceClass def __init__(self): super().__init__( r'^(?Pabstract[ \t]+)?' 
r'interface\b[ \t]*(?P\w+\b.*?$)?' r'(?P.*?)' r'^end[ \t]+interface\b[ \t]*(?P=spec)?', re.IGNORECASE | re.DOTALL | re.MULTILINE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`Interface` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ from loki import Interface # pylint: disable=import-outside-toplevel,cyclic-import match = self.pattern.search(reader.sanitized_string) if not match: return None, None, reader source = reader.source_from_sanitized_span(match.span()) is_abstract = match['is_abstract'] is not None block_candidates = ['SubroutineFunctionPattern'] statement_candidates = ('ProcedureStatementPattern',) body = self.match_block_statement_candidates( reader.reader_from_sanitized_span(match.span('body'), include_padding=True), block_candidates, statement_candidates, parser_classes=parser_classes, scope=scope ) if match['spec']: spec = match['spec'].replace(' ', '') type_ = SymbolAttributes(ProcedureType(name=spec, is_generic=True)) spec = sym.Variable(name=spec, type=type_, scope=scope) else: spec = None interface = Interface(body=body, abstract=is_abstract, spec=spec, source=source) if match.span()[0] > 0: pre = reader.reader_from_sanitized_span((0, match.span()[0]), include_padding=True) else: pre = None return pre, interface, reader.reader_from_sanitized_span((match.span()[1], None), include_padding=True) class ProcedureStatementPattern(Pattern): """ Pattern to match procedure statements in interfaces """ parser_class = RegexParserClass.InterfaceClass def __init__(self): super().__init__( r'^(?Pmodule[ \t]+)?procedure\b' # Match ``procedure`` keyword r'(?:[ \t]*::)?' 
# Optional `::` delimiter r'[ \t]*' # Some white space r'(?P' # Beginning of procedures group r'\w+(?:[ \t]*,[ \t]*\w+)*' # Procedure names, separated by ``,`` r')', # End of procedures group re.IGNORECASE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a procedure binding Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ line = reader.current_line match = self.pattern.search(line.line) if not match: return None is_module = match['module'] is not None procedures = match['procedures'].replace(' ', '').split(',') symbols = [ sym.Variable(name=s, type=SymbolAttributes(ProcedureType(name=s)), scope=scope) for s in procedures ] return ir.ProcedureDeclaration( symbols=symbols, module=is_module, source=reader.source_from_current_line() ) class TypedefPattern(Pattern): """ Pattern to match :any:`TypeDef` objects """ parser_class = RegexParserClass.TypeDefClass def __init__(self): super().__init__( r'type(?:[ \t]*,[ \t]*[\w\(\)]+)*?' # type keyword with optional parameters r'(?:[ \t]*::[ \t]*|[ \t]+)' # optional `::` separator or white space r'(?P\w+)\b.*?$' # Type name r'(?P.*?)' # Type spec r'(?P^contains\n.*?)?' 
# Optional procedure bindings part (after ``contains`` keyword) r'^end[ \t]*type\b(?:[ \t]+(?P=name))?', # End keyword with optionally type name repeated re.IGNORECASE | re.DOTALL | re.MULTILINE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`TypeDef` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ match = self.pattern.search(reader.sanitized_string) if not match: return None, None, reader source = reader.source_from_sanitized_span(match.span()) typedef = ir.TypeDef(name=match['name'], body=(), parent=scope, source=source) if match['spec'] and match['spec'].strip(): statement_candidates = ('VariableDeclarationPattern',) spec = self.match_statement_candidates( reader.reader_from_sanitized_span(match.span('spec'), include_padding=True), statement_candidates, parser_classes=parser_classes, scope=typedef ) else: spec = [] if match['contains']: contains = [ir.Intrinsic(text='CONTAINS')] span = match.span('contains') span = (span[0] + 8, span[1]) # Skip the "contains" keyword as it has been added statement_candidates = ('ProcedureBindingPattern', 'GenericBindingPattern') contains += self.match_statement_candidates( reader.reader_from_sanitized_span(span, include_padding=True), statement_candidates, parser_classes=parser_classes, scope=typedef ) else: contains = [] typedef._update(body=as_tuple(spec + contains)) if match.span()[0] > 0: pre = reader.reader_from_sanitized_span((0, match.span()[0]), include_padding=True) else: pre = None return pre, typedef, reader.reader_from_sanitized_span((match.span()[1], None), include_padding=True) class ProcedureBindingPattern(Pattern): """ Pattern to match procedure bindings """ parser_class = RegexParserClass.TypeDefClass def __init__(self): super().__init__( 
r'^procedure\b' # Match ``procedure`` keyword r'(?P(?:[ \t]*,[ \t]*\w+)*?)' # Optional attributes r'(?:[ \t]*::)?' # Optional `::` delimiter r'[ \t]*' # Some white space r'(?P' # Beginning of bindings group r'\w+(?:[ \t]*=>[ \t]*\w+)?' # Binding name with optional binding name specifier (via ``=>``) r'(?:[ \t]*,[ \t]*' # Optional group for additional bindings, separated by ``,`` r'\w+(?:[ \t]*=>[ \t]*\w+)?' # Additional binding name with optional binding name specifier r')*' # End of optional group for additional bindings r')', # End of bindings group re.IGNORECASE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a procedure binding Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ line = reader.current_line match = self.pattern.search(line.line) if not match: return None bindings = match['bindings'].replace(' ', '').split(',') bindings = [s.split('=>') for s in bindings] symbols = [] for s in bindings: if len(s) == 1: type_ = SymbolAttributes(ProcedureType(name=s[0])) symbols += [sym.Variable(name=s[0], type=type_, scope=scope)] else: type_ = SymbolAttributes(ProcedureType(name=s[1])) bind_name = sym.Variable(name=s[1], type=type_, scope=scope.parent) symbols += [sym.Variable(name=s[0], type=type_.clone(bind_names=(bind_name,)), scope=scope)] return ir.ProcedureDeclaration(symbols=symbols, source=reader.source_from_current_line()) class GenericBindingPattern(Pattern): """ Pattern to match generic bindings """ parser_class = RegexParserClass.TypeDefClass def __init__(self): super().__init__( r'^generic' # Match ``generic`` keyword r'(?P(?:[ \t]*,[ \t]*\w+)*?)' # Optional attributes r'(?:[ \t]*::)?' 
# Optional `::` delimiter r'[ \t]*' # Some white space r'(?P\w+)' # Binding name r'[ \t]*=>[ \t]*' # Separator ``=>`` r'(?P\w+(?:[ \t]*,[ \t]*\w+)*)*', # Match binding name list re.IGNORECASE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a generic procedure binding Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ line = reader.current_line match = self.pattern.search(line.line) if not match: return None bindings = match['bindings'].replace(' ', '').split(',') name = match['name'] type_ = SymbolAttributes(ProcedureType(name=name, is_generic=True), bind_names=as_tuple(bindings)) symbols = (sym.Variable(name=name, type=type_, scope=scope),) return ir.ProcedureDeclaration(symbols=symbols, generic=True, source=reader.source_from_current_line()) class ImportPattern(Pattern): """ Pattern to match :any:`Import` nodes """ parser_class = RegexParserClass.ImportClass def __init__(self): super().__init__( r'^use +(?P\w+)(?: *, *(?Ponly *:)?' # The use statement including an optional ``only`` r'(?P(?: *\w+\b *(?:=> *\w+|\(.*?\))? 
*,?)+))?', # The optional list of names (w/ renames, ops) re.IGNORECASE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`Import` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ line = reader.current_line match = self.pattern.search(line.line) if not match: return None module = match['module'] type_ = SymbolAttributes(BasicType.DEFERRED, imported=True) if match['imports']: imports = match['imports'].replace(' ', '').split(',') imports = [s.split('=>') for s in imports] imports = [s for s in imports if s and s[0]] if match['only']: rename_list = None symbols = [] for s in imports: if not s[0]: continue if len(s) == 1: symbols += [sym.Variable(name=s[0], type=type_, scope=scope)] else: symbols += [sym.Variable(name=s[0], type=type_.clone(use_name=s[1]), scope=scope)] else: rename_list = [ (s[1], sym.Variable(name=s[0], type=type_.clone(use_name=s[1]), scope=scope)) for s in imports ] symbols = None else: rename_list = None symbols = None return ir.Import( module, symbols=as_tuple(symbols), rename_list=as_tuple(rename_list), source=reader.source_from_current_line() ) class VariableDeclarationPattern(Pattern): """ Pattern to match :any:`VariableDeclaration` nodes. """ parser_class = RegexParserClass.DeclarationClass def __init__(self): super().__init__( r'^(((?:type|class)[ \t]*\([ \t]*(?P\w+)[ \t]*\))|' # TYPE or CLASS keyword with typename r'^([ \t]*(?P(logical|real|integer|complex|character))' r'[ \t]*(?P\([ \t]*(kind|len)[ \t]*=[ \t]*[a-z0-9_-]+[ \t]*\))?[ \t]*))' r'(?:[ \t]*,[ \t]*[a-z]+(?:[ \t]*\((.(\(.*\))?)*?\))?)*' # Optional attributes r'(?:[ \t]*::)?' 
# Optional `::` delimiter r'[ \t]*' # Some white space r'(?P\w+\b.*?)$', # Variable names re.IGNORECASE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`VariableDeclaration` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ line = reader.current_line match = self.pattern.search(line.line) if not match: return None if (_typename := match['typename']): type_ = SymbolAttributes(DerivedType(_typename)) else: type_ = SymbolAttributes(BasicType.from_str(match['basic_type'])) assert type_ if match['param']: param = match['param'].strip().strip('()').split('=') if len(param) == 1 or param[0].lower() == 'kind': type_ = type_.clone(kind=sym.Variable(name=param[-1], scope=scope)) variables = self._remove_quoted_string_nested_parentheses(match['variables']) # Remove dimensions variables = re.sub(r'=(?:>)?[^,]*(?=,|$)', r'', variables) # Remove initialization variables = variables.replace(' ', '').split(',') # Variable names without white space variables = tuple(sym.Variable(name=v, type=type_, scope=scope) for v in variables) return ir.VariableDeclaration(variables, source=reader.source_from_current_line()) class CallPattern(Pattern): """ Pattern to match :any:`CallStatement` nodes """ parser_class = RegexParserClass.CallClass def __init__(self): super().__init__( r'^(?Pif[ \t]*\(.*?\)[ \t]*)?' 
# Optional inline-conditional preceeding the call r'call', # Call keyword re.IGNORECASE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`CallStatement` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ line = reader.current_line match = self.pattern.search(line.line) if not match: return None # Extract the called routine name call = line.line[match.span()[1]:].strip() if not call: return None call = self._remove_quoted_string_nested_parentheses(call) # Remove arguments and dimension expressions call = call.replace(' ', '') # Remove any white space name_parts = call.split('%') name = sym.Variable(name=name_parts[0], scope=scope) for cname in name_parts[1:]: name = sym.Variable(name=name.name + '%' + cname, parent=name, scope=scope) # pylint:disable=no-member scope.symbol_attrs[call] = scope.symbol_attrs.lookup(call).clone( dtype=ProcedureType(name=call, is_function=False) ) source = reader.source_from_current_line() if match['conditional']: span = match.span('conditional') return [ ir.RawSource(text=match['conditional'], source=source.clone_with_span(span)), ir.CallStatement(name=name, arguments=(), source=source.clone_with_span((span[1], None))) ] return ir.CallStatement(name=name, arguments=(), source=source) class PragmaPattern(Pattern): """ Pattern to match :any:`VariableDeclaration` nodes. 
""" parser_class = RegexParserClass.PragmaClass def __init__(self): super().__init__( r'^!\$[a-z]+ ', re.IGNORECASE ) def match(self, reader, parser_classes, scope): """ Match the provided source string against the pattern for a :any:`Pragma` Parameters ---------- reader : :any:`FortranReader` The reader object containing a sanitized Fortran source parser_classes : RegexParserClass Active parser classes for matching scope : :any:`Scope` The parent scope for the current source fragment """ line = reader.current_line match = self.pattern.search(line.line) if not match: return None keyword = line.line[2:match.span()[1]].strip() content = line.line[match.span()[1]::].strip() source = reader.source_from_current_line() return ir.Pragma(keyword=keyword, content=content, source=source) PATTERN_REGISTRY = { name: globals()[name]() for name in dir() if name.endswith('Pattern') and name != 'Pattern' } """ A global registry of all available patterns This exists to ensure every :any:`Pattern` implementation is only instantiated once to ensure the corresponding regular expressions are not compiled multiple times. """ loki-ecmwf-0.3.6/loki/frontend/util.py0000664000175000017500000003353315167130205020054 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
from enum import IntEnum from pathlib import Path import codecs from codetiming import Timer from more_itertools import split_after from loki.expression import ( symbols as sym, SubstituteExpressionsMapper, ExpressionRetriever ) from loki.ir import ( NestedTransformer, FindNodes, PatternFinder, Transformer, Assignment, Comment, CommentBlock, VariableDeclaration, ProcedureDeclaration, Loop, Intrinsic, Pragma ) from loki.frontend.source import join_source_list from loki.logging import detail, warning, error from loki.tools import group_by_class, replace_windowed, as_tuple from loki.types import ProcedureType __all__ = [ 'Frontend', 'OMNI', 'FP', 'REGEX', 'available_frontends', 'read_file', 'InlineCommentTransformer', 'ClusterCommentTransformer', 'CombineMultilinePragmasTransformer', 'sanitize_ir', 'combine_multiline_pragmas' ] class Frontend(IntEnum): """ Enumeration to identify available frontends. """ #: The OMNI compiler frontend OMNI = 1 #: Fparser 2 from STFC FP = 3 #: Reduced functionality parsing using regular expressions REGEX = 4 def __str__(self): return self.name.lower() # pylint: disable=no-member OMNI = Frontend.OMNI FP = Frontend.FP REGEX = Frontend.REGEX def available_frontends(xfail=None, skip=None, include_regex=False): """ Provide list of available frontends to parametrize tests with To run tests for every frontend, an argument :attr:`frontend` can be added to a test with the return value of this function as parameter. For any unavailable frontends where ``HAVE_`` is `False` (e.g. because required dependencies are not installed), :attr:`test` is marked as skipped. Use as ..code-block:: @pytest.mark.parametrize('frontend', available_frontends(xfail=[OMNI, (FP, 'Because...')])) def my_test(frontend): source = Sourcefile.from_file('some.F90', frontend=frontend) # ... Parameters ---------- xfail : list, optional Provide frontends that are expected to fail, optionally as tuple with reason provided as string. 
By default `None` skip : list, optional Provide frontends that are always skipped, optionally as tuple with reason provided as string. By default `None` include_regex : bool, optional Include the :any:`REGEX` frontend in the list. By default `false`. """ if xfail: xfail = dict((as_tuple(f) + (None,))[:2] for f in xfail) else: xfail = {} if skip: skip = dict((as_tuple(f) + (None,))[:2] for f in skip) else: skip = {} try: import pytest # pylint: disable=import-outside-toplevel except ImportError as e: error('Pytest is not installed.') raise e from loki import frontend # pylint: disable=import-outside-toplevel,cyclic-import # Unavailable frontends unavailable_frontends = { f: f'{f} is not available' for f in Frontend if not getattr(frontend, f'HAVE_{str(f).upper()}') } skip.update(unavailable_frontends) # Build the list of parameters params = [] for f in Frontend: if f in skip: params += [pytest.param(f, marks=pytest.mark.skip(reason=skip[f]))] elif f in xfail: params += [pytest.param(f, marks=pytest.mark.xfail(reason=xfail[f]))] elif f != REGEX or include_regex: params += [f] return params def match_type_pattern(pattern, sequence): """ Match elements in a sequence according to a pattern of their types. 
Parameters ---------- pattern : list of type A list of types of the pattern to match sequence : list The list of items from which to match elements Returns ------- list Slices of :data:`sequence` whose element types are exactly those of :data:`pattern`, one per matching start position (matches may overlap) """ idx = [] types = tuple(map(type, sequence)) for i, elem in enumerate(types): if elem == pattern[0]: if tuple(types[i:i+len(pattern)]) == tuple(pattern): idx.append(i) # Return a list of element matches return [sequence[i:i+len(pattern)] for i in idx] class InlineCommentTransformer(Transformer): """ Identify inline comments and merge them onto statements. An :any:`Assignment`, :any:`VariableDeclaration` or :any:`ProcedureDeclaration` that is followed by a :any:`Comment` starting on the statement's last source line gets that comment attached via its ``comment`` attribute. """ def visit_tuple(self, o, **kwargs): pairs = match_type_pattern(pattern=(Assignment, Comment), sequence=o) pairs += match_type_pattern(pattern=(VariableDeclaration, Comment), sequence=o) pairs += match_type_pattern(pattern=(ProcedureDeclaration, Comment), sequence=o) for pair in pairs: # Comment is in-line and can be merged if pair[0].source and pair[1].source: if pair[1].source.lines[0] == pair[0].source.lines[1]: new = pair[0]._rebuild(comment=pair[1]) o = replace_windowed(o, pair, new) # Then recurse over the new nodes visited = tuple(self.visit(i, **kwargs) for i in o) # Strip empty sublists/subtuples or None entries return tuple(i for i in visited if i is not None and as_tuple(i)) visit_list = visit_tuple class ClusterCommentTransformer(Transformer): """ Combines consecutive sets of :any:`Comment` into a :any:`CommentBlock`. """ def visit_tuple(self, o, **kwargs): """ Find groups of :any:`Comment` and inject into the tuple.
""" cgroups = group_by_class(o, Comment) for group in cgroups: # Combine the group into a CommentBlock source = join_source_list(tuple(p.source for p in group)) block = CommentBlock(comments=group, label=group[0].label, source=source) o = replace_windowed(o, group, subs=(block,)) # Then recurse over the new nodes visited = tuple(self.visit(i, **kwargs) for i in o) # Strip empty sublists/subtuples or None entries return tuple(i for i in visited if i is not None and as_tuple(i)) visit_list = visit_tuple def inline_labels(ir): """ Find labels and merge them onto the following node. Note: This is currently only required for the OMNI frontend which has labels as nodes next to the corresponding statement without any connection between both. """ pairs = PatternFinder(pattern=(Comment, Assignment)).visit(ir) pairs += PatternFinder(pattern=(Comment, Intrinsic)).visit(ir) pairs += PatternFinder(pattern=(Comment, Loop)).visit(ir) mapper = {} for pair in pairs: if pair[0].source and pair[0].text == '__STATEMENT_LABEL__': if pair[1].source and pair[1].source.lines[0] == pair[0].source.lines[1]: mapper[pair[0]] = None # Mark for deletion mapper[pair[1]] = pair[1]._rebuild(label=pair[0].label.lstrip('0')) # Remove any stale labels for comment in FindNodes(Comment).visit(ir): if comment.text == '__STATEMENT_LABEL__': mapper[comment] = None return NestedTransformer(mapper, invalidate_source=False).visit(ir) def read_file(file_path): """ Reads a file and returns the content as string. This convenience function is provided to catch read errors due to bad character encodings in the file. It skips over these characters and prints a warning for the first occurence of such a character. 
""" filepath = Path(file_path) try: with filepath.open('r') as f: source = f.read() except UnicodeDecodeError as excinfo: warning('Skipping bad character in input file "%s": %s', str(filepath), str(excinfo)) kwargs = {'mode': 'r', 'encoding': 'utf-8', 'errors': 'ignore'} with codecs.open(filepath, **kwargs) as f: source = f.read() return source def combine_multiline_pragmas(nodes): """ Finds multi-line pragmas and combines them. """ pgroups = group_by_class(nodes, Pragma) for group in pgroups: # Separate sets of consecutive multi-line pragmas pred = lambda p: not p.content.rstrip().endswith('&') # pylint: disable=unnecessary-lambda-assignment for pragmaset in split_after(group, pred=pred): source = join_source_list(tuple(p.source for p in pragmaset)) content = ' '.join(p.content.rstrip(' &') for p in pragmaset) new_pragma = Pragma( keyword=pragmaset[0].keyword, content=content, source=source ) nodes = replace_windowed(nodes, pragmaset, subs=(new_pragma,)) return nodes class CombineMultilinePragmasTransformer(Transformer): """ Combine multiline :any:`Pragma` nodes into single ones. """ def visit_tuple(self, o, **kwargs): """ Finds multi-line pragmas and combines them in-place. """ o = combine_multiline_pragmas(o) visited = tuple(self.visit(i, **kwargs) for i in o) # Strip empty sublists/subtuples or None entries return tuple(i for i in visited if i is not None and as_tuple(i)) class RangeIndexTransformer(Transformer): """ :any:`Transformer` that replaces ``arr(1:n)`` notations with ``arr(n)`` in :any:`VariableDeclaration`. """ retriever = ExpressionRetriever(lambda e: isinstance(e, (sym.Array))) @staticmethod def is_one_index(dim): return isinstance(dim, sym.RangeIndex) and dim.lower == 1 and dim.step is None def visit_VariableDeclaration(self, o, **kwargs): # pylint: disable=unused-argument """ Gets all :any:`Array` symbols and adjusts dimension and shape. 
""" vmap = {} for v in self.retriever.retrieve(o.symbols): dimensions = tuple(d.upper if self.is_one_index(d) else d for d in v.dimensions) _type = v.type if _type.shape: shape = tuple(d.upper if self.is_one_index(d) else d for d in _type.shape) _type = _type.clone(shape=shape) vmap[v] = v.clone(dimensions=dimensions, type=_type) mapper = SubstituteExpressionsMapper(vmap) return o.clone(symbols=mapper(o.symbols, recurse_to_declaration_attributes=True)) class RemoveDuplicateVariableDeclarationsForExternalProcedures(Transformer): """ :any:`Transformer` that removes procedure symbols from :any:`VariableDeclarations` if they have the ``external`` attribute This is because Fortran's external-stmt allows to declare procedure symbols as external separate to their return type declaration. That makes it virtually impossible to determine that this return type declaration refers to a procedure rather than a local variable until the corresponding ``EXTERNAL`` statement has been encountered. Because Fortran allows to represent this also as an attribute in the same declaration, we choose this to represent external procedures in all cases. This means, we are replacing .. code-block:: REAL :: ext_func EXTERNAL ext_func by the equivalent representation .. code-block:: REAL, EXTERNAL :: ext_func The frontends will readily translate external statements to the procedure declaration with the ``EXTERNAL`` attribute, and therefore this transformer only has to remove the duplicate variable declarations. 
""" def visit_VariableDeclaration(self, o, **kwargs): # pylint: disable=unused-argument symbols = tuple( s for s in o.symbols if not (s.type.external and isinstance(s.type.dtype, ProcedureType)) ) if not symbols: return None if len(symbols) < len(o.symbols): return o._update(symbols=symbols) return o @Timer(logger=detail, text=lambda s: f'[Loki::Frontend] Executed sanitize_ir in {s:.2f}s') def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None): """ Utility function to sanitize internal representation after creating it from the parse tree of a frontend It carries out post-processing according to :data:`pp_info` and applies the following operations: * :any:`inline_comments` to attach inline-comments to IR nodes * :any:`ClusterCommentTransformer` to combine multi-line comments into :any:`CommentBlock` * :any:`CombineMultilinePragmasTransformer` to combine multi-line pragmas into a single node Parameters ---------- _ir : :any:`Node` The root node of the internal representation tree to be processed frontend : :any:`Frontend` The frontend from which the IR was created pp_registry: dict, optional Registry of pre-processing items to be applied pp_info : optional Information from internal preprocessing step that was applied to work around parser limitations and that should be re-inserted """ # Apply postprocessing rules to re-insert information lost during preprocessing if pp_info is not None and pp_registry is not None: for r_name, rule in pp_registry.items(): info = pp_info.get(r_name, None) _ir = rule.postprocess(_ir, info) # Perform some minor sanitation tasks _ir = InlineCommentTransformer(inplace=True, invalidate_source=False).visit(_ir) _ir = ClusterCommentTransformer(inplace=True, invalidate_source=False).visit(_ir) if frontend == OMNI: # Revert OMNI's array dimension expansion from `a(n)` => `arr(1:n)` _ir = RangeIndexTransformer(invalidate_source=False).visit(_ir) _ir = inline_labels(_ir) if frontend == FP: _ir = 
CombineMultilinePragmasTransformer(inplace=True, invalidate_source=False).visit(_ir) _ir = RemoveDuplicateVariableDeclarationsForExternalProcedures(inplace=True, invalidate_source=False).visit(_ir) return _ir loki-ecmwf-0.3.6/loki/frontend/fparser.py0000664000175000017500000043720315167130205020543 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # pylint: disable=too-many-lines import re from itertools import takewhile from codetiming import Timer try: from fparser.two.parser import ParserFactory from fparser.two.utils import get_child, walk from fparser.two import Fortran2003 from fparser.common.readfortran import FortranStringReader HAVE_FP = True """Indicate whether fparser frontend is available.""" except ImportError: HAVE_FP = False from loki.frontend.source import Source from loki.frontend.preprocessing import sanitize_registry from loki.frontend.util import read_file, FP, sanitize_ir from loki import ir from loki.ir import ( GenericVisitor, FindNodes, AttachScopes, attach_pragmas, detach_pragmas, pragmas_attached, process_dimension_pragmas ) import loki.expression.symbols as sym from loki.expression.operations import ( StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow ) from loki.expression import AttachScopesMapper from loki.logging import debug, detail, info, warning, error from loki.tools import ( as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup, dict_override ) from loki.types import BasicType, DerivedType, ProcedureType, SymbolAttributes, Scope from loki.config import config __all__ = ['HAVE_FP', 'FParser2IR', 'parse_fparser_file', 'parse_fparser_source', 
'parse_fparser_ast', 'parse_fparser_expression', 'get_fparser_node'] @Timer(logger=debug, text=lambda s: f'[Loki::FP] Executed parse_fparser_file in {s:.2f}s') def parse_fparser_file(filename): """ Generate a parse tree from file via fparser """ info(f'[Loki::FP] Parsing {filename}') fcode = read_file(filename) return parse_fparser_source(source=fcode) @Timer(logger=detail, text=lambda s: f'[Loki::FP] Executed parse_fparser_source in {s:.2f}s') def parse_fparser_source(source): """ Generate a parse tree from string """ if not HAVE_FP: error('Fparser is not available. Try "pip install fparser".') raise RuntimeError # Clear FParser's symbol tables if the FParser version is new enough to have them try: from fparser.two.symbol_table import SYMBOL_TABLES # pylint: disable=import-outside-toplevel SYMBOL_TABLES.clear() except ImportError: pass reader = FortranStringReader(source, ignore_comments=False) f2008_parser = ParserFactory().create(std='f2008') return f2008_parser(reader) @Timer(logger=detail, text=lambda s: f'[Loki::FP] Executed parse_fparser_ast in {s:.2f}s') def parse_fparser_ast(ast, raw_source, pp_info=None, definitions=None, scope=None): """ Generate an internal IR from fparser parse tree Parameters ---------- ast : The fparser parse tree as created by :any:`parse_fparser_source` or :any:`parse_fparser_file` raw_source : str The raw source string from which :attr:`ast` was generated pp_info : optional Information from internal preprocessing step that was applied to work around parser limitations and that should be re-inserted definitions : list of :any:`Module`, optional List of external module definitions to attach upon use scope : :any:`Scope` Scope object for which to parse the AST. 
Returns ------- :any:`Node` The control flow tree """ # Parse the raw FParser language AST into our internal IR _ir = FParser2IR(raw_source=raw_source, definitions=definitions, pp_info=pp_info, scope=scope).visit(ast) _ir = sanitize_ir(_ir, FP, pp_registry=sanitize_registry[FP], pp_info=pp_info) return _ir def parse_fparser_expression(source, scope): """ Parse an expression string into an expression tree. This exploits Fparser's internal parser structure that relies on recursively matching strings against a list of node types. Usually, this would start by matching against module, subroutine or program. Here, we shortcut this hierarchy by directly matching against a primary expression, thus this should be able to parse any syntactically correct Fortran expression. Parameters ---------- source : str The expression as a string scope : :any:`Scope` The scope to which symbol names inside the expression belong Returns ------- :any:`Expression` The expression tree corresponding to the expression """ if not HAVE_FP: error('Fparser is not installed') raise RuntimeError _ = ParserFactory().create(std='f2008') # Wrap source in brackets to make sure it appears like a valid expression # for fparser, and strip that Parenthesis node from the ast immediately after ast = Fortran2003.Primary('(' + source + ')').children[1] # We parse the standalone expression with a dummy scope, to avoid # overriding existing type info from the given scope, before # attaching it after the fact. _ir = parse_fparser_ast(ast, source, scope=Scope()) _ir = AttachScopes().visit(_ir, scope=scope) return _ir def get_fparser_node(ast, node_type_name, first_only=True, recurse=False): """ Extract child nodes with type given by :attr:`node_type_name` from an fparser parse tree Parameters ---------- ast : The fparser parse tree as created by :any:`parse_fparser_source` or :any:`parse_fparser_file` node_type_name : str or list of str The name of the node type to extract, e.g. `Module`, `Specification_Part` etc. 
first_only : bool, optional Return only first instance matching :attr:`node_type_name`. Defaults to `True`. recurse : bool, optional Walk the entire parse tree instead of looking only in the children of :attr:`ast`. Defaults to `False`. Returns ------- :class:`fparser.two.utils.Base` The first node of requested type, or ``None`` if no node matches (or a list of all matching nodes if :attr:`first_only` is `False`) """ node_types = tuple(getattr(Fortran2003, name) for name in as_tuple(node_type_name)) if recurse: nodes = walk(ast, node_types) else: nodes = [c for c in ast.children if isinstance(c, node_types)] if first_only: return nodes[0] if nodes else None return nodes def node_sublist(nodelist, starttype, endtype): """ Extract a subset of nodes from a list that sits between marked start and end nodes. The start and end marker nodes themselves are not included, nodes from several start/end regions are concatenated into one list, and a node of :data:`starttype` encountered while already inside a region is kept. """ sublist = [] active = False for node in nodelist: if isinstance(node, endtype): active = False if active: sublist += [node] if isinstance(node, starttype): active = True return sublist def rget_child(node, node_type): """ Searches for the last, immediate child of the supplied node that is of the specified type. Parameters ---------- node : :any:`fparser.two.utils.Base` the node whose children will be searched node_type : class name or tuple of class names the class(es) of child node to search for. Returns ------- :any:`fparser.two.utils.Base` the last child node of type node_type that is encountered or ``None``. """ for child in reversed(node.children): if isinstance(child, node_type): return child return None def _get_comments_from_section(sec, include_pragmas=False, reverse=False): """ Extract leading or trailing :any:`Comment` or :any:`CommentBlock` nodes from a :any:`Section`.
Parameters ---------- sec : :any:`Section` Code section from which to extract comment nodes include_pragmas : bool Flag to enable matching :any:`Pragma` nodes reverse : bool Flag to enable matching trailing comment nodes Returns ------- tuple of :any:`Node` Leading or trailing comment or pragma nodes """ _matches = (ir.Comment, ir.CommentBlock) if include_pragmas: _matches += (ir.Pragma,) def is_comment(n): return isinstance(n, _matches) # Pick out comments from the beginning of the section and update in-place nodes = reversed(sec.body) if reverse else sec.body comments = tuple(takewhile(is_comment, nodes)) sec._update(body=tuple(filter(lambda n: n not in comments, sec.body))) return reversed(comments) if reverse else comments class FParser2IR(GenericVisitor): # pylint: disable=unused-argument # Stop warnings about unused arguments def __init__(self, raw_source, definitions=None, pp_info=None, scope=None): super().__init__() self.raw_source = raw_source.splitlines(keepends=True) self.definitions = CaseInsensitiveDict((d.name, d) for d in as_tuple(definitions)) self.pp_info = pp_info self.default_scope = scope @staticmethod def warn_or_fail(msg): if config['frontend-strict-mode']: error(msg) raise NotImplementedError warning(msg) def get_source(self, node, end_node=None): """ Builds the source object for a given (pair of) AST node(s). """ # Only create Source object if configured and item is given if not config['frontend-store-source']: return None end_node = end_node if end_node else node if not (node.item and end_node.item): return None # Create source object that records lines and raw source string lines = (node.item.span[0], end_node.item.span[1]) string = ''.join(self.raw_source[lines[0] - 1:lines[1]]).strip('\n') return Source(lines=lines, string=string) def get_label(self, o): """ Helper method that returns the label of the node. 
""" if o is not None and not isinstance(o, str) and o.item is not None: return getattr(o.item, 'label', None) return None def visit(self, o, **kwargs): # pylint: disable=arguments-differ """ Generic dispatch method that tries to generate meta-data from source. """ if o and o.item: kwargs['source'] = self.get_source(o) kwargs['label'] = self.get_label(o) kwargs.setdefault('scope', self.default_scope) return super().visit(o, **kwargs) def visit_List(self, o, **kwargs): """ Universal routine for auto-generated ``*_List`` types in fparser ``*_List`` types have their items children """ return tuple(self.visit(i, **kwargs) for i in o.children) def visit_Intrinsic_Stmt(self, o, **kwargs): """ Universal routine to capture nodes as plain string in the IR """ label = kwargs.get('label') label = str(label) if label else label # Ensure srting labels return ir.Intrinsic(text=o.tostr(), label=label, source=kwargs.get('source')) # # Base blocks # def create_contained_procedures(self, o, **kwargs): """ Helper utility that creates :any:`Subroutine` objects before the full parse to ensure the scope hierarchy is in place. Notes ----- We first make sure the procedure objects for all internal procedures are instantiated before parsing the actual spec and body of the parent routine. This way, all procedure types should exist in the scope and any use of their symbol (e.g. in a :any:`CallStatement` or :any:`InlineCall`) can be matched against a type. """ if not o: return member_asts = tuple( c for c in o.children if isinstance(c, (Fortran2003.Subroutine_Subprogram, Fortran2003.Function_Subprogram)) ) # Instantiate the procedure objects from their initial "stmt # line" to fill the type cache of the scope. This is needed to # get `ProcedureType` objecst and identify `InlineCall` objects. 
for c in member_asts: self.visit(get_child(c, (Fortran2003.Subroutine_Stmt, Fortran2003.Function_Stmt)), **kwargs) def visit_Specification_Part(self, o, **kwargs): """ The specification part of a program-unit :class:`fparser.two.Fortran2003.Specification_Part` has variable number of children making up the body of the spec. """ children = as_tuple(flatten(self.visit(c, **kwargs) for c in o.children)) return ir.Section(body=children, source=kwargs.get('source')) visit_Implicit_Part = visit_List visit_Program = visit_Specification_Part visit_Execution_Part = visit_Specification_Part visit_Internal_Subprogram_Part = visit_Specification_Part visit_Module_Subprogram_Part = visit_Specification_Part # # Variable, procedure and type names # def visit_Name(self, o, **kwargs): """ A symbol name :class:`fparser.two.Fortran2003.Name` has no children. """ name = o.tostr() scope = kwargs.get('scope', None) parent = kwargs.get('parent') if parent: assert hasattr(parent, 'scope'), f'[Loki::Frontend] Parent "{parent}" for AST node "{o}" has no scope!' scope = parent.scope if scope: scope = scope.get_symbol_scope(name) return sym.Variable(name=name, parent=parent, scope=scope) def visit_Type_Name(self, o, **kwargs): """ A derived type name :class:`fparser.two.Fortran2003.Type_Name` has no children. 
""" return DerivedType(o.tostr()) def visit_Part_Ref(self, o, **kwargs): """ A part of a data ref (e.g., flat variable or array name, or name of a derived type variable or member) and, optionally, a subscript list :class:`fparser.two.Fortran2003.Part_Ref` has two children: * :class:`fparser.two.Fortran2003.Name`: the part name * :class:`fparser.two.Fortran2003.Section_Subscript_List`: the subscript (or `None`) """ name = self.visit(o.children[0], **kwargs) with dict_override(kwargs, {'parent': None}): # Don't pass any parent on to dimension symbols dimensions = self.visit(o.children[1], **kwargs) if dimensions: name = name.clone(dimensions=dimensions) # Fparser wrongfully interprets function calls as Part_Ref sometimes # This should go away once fparser has a basic symbol table, see # https://github.com/stfc/fparser/issues/201 for some details _type = kwargs['scope'].symbol_attrs.lookup(name.name) if _type is None and (definition := self.definitions.get(name.name)): # We don't have any type information for this, which means it has # not been declared locally. Check the definitions for enriched # type information: if isinstance(dtype := definition.procedure_type, ProcedureType): _type = name.type.clone(dtype=dtype) name = name.clone(type=_type) if _type and isinstance(_type.dtype, ProcedureType): name = name.clone(dimensions=None) call = sym.InlineCall(name, parameters=dimensions, kw_parameters=()) return call return name def visit_Data_Ref(self, o, **kwargs): """ A fully qualified name for accessing a derived type or class member, composed from individual :class:`fparser.two.Fortran2003.Part_Ref` as ``part-ref [% part-ref [% part-ref ...] ]`` :class:`fparser.two.Fortran2003.Data_Ref` has variable number of children, depending on the number of part-ref. 
""" var = self.visit(o.children[0], **kwargs) for c in o.children[1:]: parent = var kwargs['parent'] = parent var = self.visit(c, **kwargs) if isinstance(var, sym.InlineCall): # This is a function call with a type-bound procedure, so we need to # update the name slightly different function = var.function.clone(name=f'{parent.name}%{var.function.name}', parent=parent) var = var.clone(function=function) else: # Hack: Need to force re-evaluation of the type from parent here via `type=None` # We know there's a parent, but we cannot trust the auto-generation of the type, # since the type lookup via parents can create mismatched DeferredTypeSymbols. var = var.clone( name=f'{parent.name}%{var.name}', parent=parent, scope=parent.scope, type=None ) return var # # Imports of external names # def visit_Use_Stmt(self, o, **kwargs): """ An import of symbol names via ``USE`` :class:`fparser.two.Fortran2003.Use_Stmt` has five children: * module-nature (`str`: 'INTRINSIC' or 'NON_INTRINSIC' or `None` if absent) * '::' (`str`) if a double colon is used, otherwise `None` * module-name :class:`fparser.two.Fortran2003.Module_Name` followed by * ', ONLY:' (`str`) and :class:`fparser.two.Fortran2003.Only_List`, or * ',' (`str`) and :class:`fparser.two.Fortran2003.Rename_List`, or * '' (`str`) and no only-list or rename-list """ if o.children[0] is not None: # Module nature nature = str(o.children[0]) else: nature = None name = o.children[2].tostr() if nature and nature.lower() == 'intrinsic': # Do not use module ref if we refer to an intrinsic module module = None else: module = self.definitions.get(name) scope = kwargs['scope'] if o.children[3] == '' or o.children[3] == ',': # No ONLY list (import all) symbols = () # Rename list if o.children[4]: rename_list = dict(self.visit(o.children[4], **kwargs)) else: rename_list = {} if module is not None: # Import symbol attributes from module, if available for k, v in module.symbol_attrs.items(): # Don't import private module symbols if 
v.private: continue if module.default_access_spec == "private": if k not in module.public_access_spec and not v.public: continue else: if k in module.private_access_spec: continue if k in rename_list: local_name = rename_list[k].name scope.symbol_attrs[local_name] = v.clone(imported=True, module=module, use_name=k) else: # Need to explicitly reset use_name in case we are importing a symbol # that stems from an import with a rename-list scope.symbol_attrs[k] = v.clone(imported=True, module=module, use_name=None) elif rename_list: # Module not available but some information via rename-list scope.symbol_attrs.update({ v.name: v.type.clone(imported=True, use_name=k) for k, v in rename_list.items() }) rename_list = tuple(rename_list.items()) if rename_list else None elif o.children[3] == ', ONLY:': # ONLY list given (import only selected symbols) symbols = () if o.children[4] is None else self.visit(o.children[4], **kwargs) # No rename-list rename_list = None deferred_type = SymbolAttributes(BasicType.DEFERRED, imported=True) if module is None: # Initialize symbol attributes as DEFERRED for s in symbols: if isinstance(s, tuple): # Renamed symbol scope.symbol_attrs[s[1].name] = deferred_type.clone(use_name=s[0]) else: scope.symbol_attrs[s.name] = deferred_type else: # Import symbol attributes from module for s in symbols: if isinstance(s, tuple): # Renamed symbol _type = module.symbol_attrs.get(s[0], deferred_type) scope.symbol_attrs[s[1].name] = _type.clone( imported=True, module=module, use_name=s[0] ) else: # Need to explicitly reset use_name in case we are importing a symbol # that stems from an import with a rename-list _type = module.symbol_attrs.get(s.name, deferred_type) scope.symbol_attrs[s.name] = _type.clone( imported=True, module=module, use_name=None ) symbols = tuple( s[1].rescope(scope=scope) if isinstance(s, tuple) else s.rescope(scope=scope) for s in symbols ) else: raise ValueError(f'Unexpected only/rename-list value in USE statement: {o.children[3]}') 
return ir.Import(module=name, symbols=symbols, nature=nature, rename_list=rename_list, source=kwargs.get('source'), label=kwargs.get('label')) visit_Only_List = visit_List visit_Rename_List = visit_List def visit_Rename(self, o, **kwargs): """ A rename of an imported symbol :class:`fparser.two.Fortran2003.Rename` has three children: * 'OPERATOR' (`str`) or `None` * :class:`fparser.two.Fortran2003.Local_Name` or :class:`fparser.two.Fortran2003.Local_Defined_Operator` * :class:`fparser.two.Fortran2003.Use_Name` or :class:`fparser.two.Fortran2003.Use_Defined_Operator` """ if o.children[0] == 'OPERATOR': self.warn_or_fail('OPERATOR in rename-list not yet implemented') return () assert o.children[0] is None return (str(o.children[2]), self.visit(o.children[1], **kwargs)) # # Variable declarations # def visit_Type_Declaration_Stmt(self, o, **kwargs): """ Variable declaration statement :class:`fparser.two.Fortran2003.Type_Declaration_Stmt` has 3 children: * :class:`fparser.two.Fortran2003.Declaration_Type_Spec` (:class:`fparser.two.Fortran2003.Intrinsic_Type_Spec` or :class:`fparser.two.Fortran2003.Derived_Type_Spec`) * :class:`fparser.two.Fortran2003.Attr_Spec_List` * :class:`fparser.two.Fortran2003.Entity_Decl_List` """ # First, obtain data type and attributes _type = self.visit(o.children[0], **kwargs) attrs = self.visit(o.children[1], **kwargs) if o.children[1] else () attrs = dict(attrs) # Then, build the common symbol type for all variables _type = _type.clone(**attrs) # Last, instantiate declared variables variables = as_tuple(self.visit(o.children[2], **kwargs)) # DIMENSION is called shape for us if _type.dimension: _type = _type.clone(shape=_type.dimension, dimension=None) # Attach dimension attribute to variable declaration for uniform # representation of variables in declarations variables = as_tuple(v.clone(dimensions=_type.shape) for v in variables) # Make sure KIND and INITIAL (which can be a name) are in the right scope scope = kwargs['scope'] if _type.kind 
is not None: kind = AttachScopesMapper()(_type.kind, scope=scope) _type = _type.clone(kind=kind) if _type.initial is not None: initial = AttachScopesMapper()(_type.initial, scope=scope) _type = _type.clone(initial=initial) # EXTERNAL attribute means this is actually a function or subroutine # Since every symbol refers to a different function we have to update the # type definition for every symbol individually if _type.external: for var in variables: type_kwargs = _type.__dict__.copy() return_type = SymbolAttributes(_type.dtype) if _type.dtype is not None else None external_type = scope.symbol_attrs.lookup(var.name) if external_type is None: type_kwargs['dtype'] = ProcedureType( var.name, is_function=return_type is not None, return_type=return_type ) else: type_kwargs['dtype'] = external_type.dtype scope.symbol_attrs[var.name] = var.type.clone(**type_kwargs) variables = tuple(var.rescope(scope=scope) for var in variables) return ir.ProcedureDeclaration( symbols=variables, external=True, source=kwargs.get('source'), label=kwargs.get('label') ) # Update symbol table entries and rescope scope.symbol_attrs.update({var.name: var.type.clone(**_type.__dict__) for var in variables}) variables = tuple(var.rescope(scope=scope) for var in variables) return ir.VariableDeclaration( symbols=variables, dimensions=_type.shape, source=kwargs.get('source'), label=kwargs.get('label') ) def visit_Intrinsic_Type_Spec(self, o, **kwargs): """ An intrinsic type :class:`fparser.two.Fortran2003.Intrinsic_Type_Spec` has 2 children: * type name (str) * kind (:class:`fparser.two.Fortran2003.Kind_Selector`) or length (:class:`fparser.two.Fortran2003.Length_Selector`) """ dtype = BasicType.from_str(o.children[0]) if o.children[1]: if dtype not in ( BasicType.INTEGER, BasicType.REAL, BasicType.COMPLEX, BasicType.LOGICAL, BasicType.CHARACTER ): raise ValueError(f'Unknown kind for intrinsic type: {o.children[0]}') attr = self.visit(o.children[1], **kwargs) if attr: attr = dict(attr) return 
SymbolAttributes(dtype, **attr) return SymbolAttributes(dtype) def visit_Kind_Selector(self, o, **kwargs): """ A kind selector of an intrinsic type :class:`fparser.two.Fortran2003.Kind_Selector` has 2 or 3 children: * ``'*'`` (str) and :class:`fparser.two.Fortran2003.Char_Length`, or * ``'('`` (str), :class:`fparser.two.Fortran2003.Scalar_Int_Initialization_Expr`, and ``')'`` (str) """ if len(o.children) in (2, 3) and (o.children[0] == '*' or o.children[0] + str(o.children[-1]) == '()'): return (('kind', self.visit(o.children[1], **kwargs)),) self.warn_or_fail('Unknown kind selector') return None def visit_Length_Selector(self, o, **kwargs): """ A length selector for intrinsic character type :class:`fparser.two.Fortran2003.Length_Selector` has 3 children: * '(' (str) * :class:`fparser.two.Fortran2003.Char_Length` or :class:`fparser.two.Fortran2003.Type_Param_Value` * ')' (str) """ assert o.children[0] == '*' or (o.children[0] == '(' and o.children[2] == ')') return (('length', self.visit(o.children[1], **kwargs)),) def visit_Char_Length(self, o, **kwargs): """ Length specifier in the Length_Selector :class:`fparser.two.Fortran2003.Length_Selector` has one child: * length value (str) """ assert o.children[0] == '(' and o.children[2] == ')' return self.visit(o.children[1], **kwargs) def visit_Char_Selector(self, o, **kwargs): """ Length- and kind-selector for intrinsic character type :class:`fparser.two.Fortran2003.Char_Selector` has 2 children: * :class:`fparser.two.Fortran2003.Length_Selector` * some scalar expression for the kind """ length = None kind = None if o.children[0] is not None: length = self.visit(o.children[0], **kwargs) if o.children[1] is not None: kind = self.visit(o.children[1], **kwargs) return (('length', length), ('kind', kind)) def visit_Type_Param_Value(self, o, **kwargs): """ The value of a type parameter in a type spefication (such as length of a CHARACTER) :class:`fparser.two.Fortran2003.Type_Param_Value` has only 1 attribute: * 
:attr:`string` : the value of the parameter (str) """ if o.string in '*:': return o.string return self.visit(o.string, **kwargs) def visit_Declaration_Type_Spec(self, o, **kwargs): """ A derived type specifier in a declaration :class:`fparser.two.Fortran2003.Declaration_Type_Spec` has 2 children: * keyword 'TYPE' or 'CLASS' (str) * :class:`fparser.two.Fortran2003.Derived_Type_Spec` """ if o.children[0].upper() in ('TYPE', 'CLASS'): dtype = self.visit(o.children[1], **kwargs) # Look for a previous definition of this type _type = kwargs['scope'].symbol_attrs.lookup(dtype.name) if _type is None or _type.dtype is BasicType.DEFERRED: _type = SymbolAttributes(dtype) if o.children[0].upper() == 'CLASS': _type.polymorphic = True # Strip import annotations return _type.clone(imported=None, module=None) return self.visit_Base(o, **kwargs) def visit_Dimension_Attr_Spec(self, o, **kwargs): """ The dimension specification as attribute in a declaration :class:`fparser.two.Fortran2003.Dimensions_Attr_Spec` has 2 children: * attribute name (str) * :class:`fparser.two.Fortran2003.Array_Spec` """ return (o.children[0].lower(), self.visit(o.children[1], **kwargs)) def visit_Intent_Attr_Spec(self, o, **kwargs): """ The intent specification in a declaration :class:`fparser.two.Fortran2003.Intent_Attr_Spec` has 2 children: * 'INTENT' keyword * :class:`fparser.two.Fortran2003.Intent_Spec` """ return (o.children[0].lower(), o.children[1].tostr().lower()) visit_Attr_Spec_List = visit_List def visit_Attr_Spec(self, o, **kwargs): """ A declaration attribute :class:`fparser.two.Fortran2003.Attr_Spec` has no children. """ return (str(o).lower(), True) def visit_Access_Spec(self, o, **kwargs): """ A declaration attribute for access specification (PRIVATE, PUBLIC) :class:`fparser.two.Fortran2003.Access_Spec` has no children. 
""" return (o.string.lower(), True) visit_Entity_Decl_List = visit_List def visit_Entity_Decl(self, o, **kwargs): """ A variable entity in a declaration :class:`fparser.two.Fortran2003.Entity_Decl` has 4 children: * object name (:class:`fparser.two.Fortran2003.Name`) * array spec (:class:`fparser.two.Fortran2003.Array_Spec`) * char length (:class:`fparser.two.Fortran2003.Char_Length`) * init (:class:`fparser.two.Fortran2003.Initialization`) """ # Do not pass scope down, as it might alias with previously # created symbols. Instead, let the rescope in the Declaration # assign the right scope, always! with dict_override(kwargs, {'scope': None}): var = self.visit(o.children[0], **kwargs) if o.children[1]: dimensions = as_tuple(self.visit(o.children[1], **kwargs)) var = var.clone(dimensions=dimensions, type=var.type.clone(shape=dimensions)) if o.children[2]: char_length = self.visit(o.children[2], **kwargs) var = var.clone(type=var.type.clone(length=char_length)) if o.children[3]: init = self.visit(o.children[3], **kwargs) var = var.clone(type=var.type.clone(initial=init)) return var def visit_Explicit_Shape_Spec(self, o, **kwargs): """ Explicit shape specification for arrays :class:`fparser.two.Fortran2003.Explicit_Shape_Spec` has 2 children: * lower bound (if explicitly given) * upper bound """ lower_bound, upper_bound = None, None if o.children[1] is not None: upper_bound = self.visit(o.children[1], **kwargs) if o.children[0] is not None: lower_bound = self.visit(o.children[0], **kwargs) if upper_bound is not None and lower_bound is None: return upper_bound source = kwargs.get('source') if source: source = source.clone_with_string(o.string) return sym.RangeIndex((lower_bound, upper_bound)) def visit_Assumed_Size_Spec(self, o, **kwargs): """ Assumed size specification for arrays :class:`fparser.two.Fortran2003.Assumed_Size_Spec` has 2 children: * An explicit shape specification preceding the assumed size specifier * lower bound (if explicitly given) """ dims = [] 
lower_bound = None if isinstance(o.children[0], Fortran2003.Explicit_Shape_Spec_List): # pylint: disable=no-member dims += list(self.visit(child, **kwargs) for child in o.children[0].children) if o.children[1] is not None: # to workaround a 0 lbound lower_bound = self.visit(o.children[1], **kwargs) if lower_bound is not None: # to workaround a 0 lbound dims += [sym.RangeIndex((lower_bound, sym.IntrinsicLiteral('*'))),] else: dims += [sym.IntrinsicLiteral('*'),] return as_tuple(dims) visit_Explicit_Shape_Spec_List = visit_List visit_Assumed_Shape_Spec = visit_Explicit_Shape_Spec visit_Assumed_Shape_Spec_List = visit_List visit_Deferred_Shape_Spec = visit_Explicit_Shape_Spec visit_Deferred_Shape_Spec_List = visit_List def visit_Initialization(self, o, **kwargs): """ Variable initialization in declaration :class:`fparser.two.Fortran2003.Initialization` has 2 children: * '=' or '=>' (str) * init expr """ if o.children[0] == '=': return self.visit(o.items[1], **kwargs) if o.children[0] == '=>': return self.visit(o.items[1], **kwargs) raise ValueError(f'Invalid assignment operator {o.children[0]}') visit_Component_Initialization = visit_Initialization def visit_External_Stmt(self, o, **kwargs): """ An ``EXTERNAL`` statement to specify the external attribute for a list of names :class:`fparser.two.Fortran2003.External_Stmt` has 2 children: * keyword 'EXTERNAL (`str`) * the list of names :class:`fparser.two.Fortran2003.External_Name_List` """ assert o.children[0].upper() == 'EXTERNAL' # Compile the list of names... symbols = self.visit(o.children[1], **kwargs) # ...and update their symbol table entry... 
scope = kwargs['scope'] for var in symbols: _type = scope.symbol_attrs.lookup(var.name) or SymbolAttributes(dtype=BasicType.DEFERRED) if _type.dtype == BasicType.DEFERRED: dtype = ProcedureType(var.name, is_function=False) else: dtype = ProcedureType(var.name, is_function=True, return_type=_type) scope.symbol_attrs[var.name] = _type.clone(dtype=dtype, external=True) symbols = tuple(v.rescope(scope=scope) for v in symbols) declaration = ir.ProcedureDeclaration(symbols=symbols, external=True, source=kwargs.get('source'), label=kwargs.get('label')) return declaration visit_External_Name_List = visit_List def visit_Access_Stmt(self, o, **kwargs): """ An access-spec statement that specifies accessibility of symbols in a module :class:`faprser.two.Fortran2003.Access_Stmt` has 2 children: * keyword ``PRIVATE`` or ``PUBLIC`` (`str`) * optional list of names (:class:`fparser.two.Fortran2003.Access_Id_List`) or `None` """ from loki.module import Module # pylint: disable=import-outside-toplevel,cyclic-import assert isinstance(kwargs['scope'], Module) assert o.children[0] in ('PUBLIC', 'PRIVATE') if o.children[1] is None: assert kwargs['scope'].default_access_spec is None kwargs['scope'].default_access_spec = o.children[0].lower() else: access_id_list = [str(name).lower() for name in o.children[1].children] if o.children[0] == 'PUBLIC': kwargs['scope'].public_access_spec += as_tuple(access_id_list) else: kwargs['scope'].private_access_spec += as_tuple(access_id_list) # # Procedure declarations # def visit_Procedure_Declaration_Stmt(self, o, **kwargs): """ Procedure declaration statement :class:`fparser.two.Fortran2003.Procedure_Declaration_Stmt` has 3 children: * :class:`fparser.two.Fortran2003.Name`: the name of the procedure interface * :class:`fparser.two.Fortran2003.Proc_Attr_Spec_List` or `None`: the declared attributes (if any) * :class:`fparser.two.Fortran2003.Proc_Decl_List`: the local procedure names """ scope = kwargs['scope'] # Instantiate declared symbols symbols = 
as_tuple(self.visit(o.children[2], **kwargs)) # Any additional declared attributes attrs = self.visit(o.children[1], **kwargs) if o.children[1] else () attrs = dict(attrs) # Find out which procedure we are declaring (i.e., PROCEDURE()) assert o.children[0] is not None try: # This could be an implicit interface or dummy routine... return_type = SymbolAttributes(BasicType.from_str(o.children[0].tostr())) except ValueError: return_type = None if return_type is None: interface = self.visit(o.children[0], **kwargs) interface = AttachScopesMapper()(interface, scope=scope) if interface.type.dtype is BasicType.DEFERRED: # This is (presumably!) an external function with explicit interface that we # don't know because the type information is not available, e.g., because it's been # imported from another module or sits in an intfb.h header file. # So, we create a ProcedureType object with the interface name and use that dtype = ProcedureType(interface.name) interface = interface.clone(type=interface.type.clone(dtype=dtype)) _type = interface.type.clone(**attrs) else: interface = return_type.dtype _type = SymbolAttributes(BasicType.DEFERRED, **attrs) # Make sure any "bind_names" symbol (i.e. 
the procedure we're binding to) is in the right scope if _type.bind_names is not None: bind_names = AttachScopesMapper()(_type.bind_names, scope=scope) _type = _type.clone(bind_names=bind_names) # Update symbol table entries if return_type is None: scope.symbol_attrs.update({var.name: var.type.clone(**_type.__dict__) for var in symbols}) else: for var in symbols: dtype = ProcedureType(var.name, is_function=True, return_type=return_type) scope.symbol_attrs[var.name] = _type.clone(dtype=dtype) symbols = tuple(var.rescope(scope=scope) for var in symbols) return ir.ProcedureDeclaration( symbols=symbols, interface=interface, source=kwargs.get('source'), label=kwargs.get('label') ) visit_Proc_Attr_Spec_List = visit_List def visit_Proc_Attr_Spec(self, o, **kwargs): """ Procedure declaration attribute :class:`fparser.two.Fortran2003.Proc_Attr_Spec` has 2 children: * attribute name (`str`) * attribute value (such as ``IN``, ``OUT``, ``INOUT``) or `None` """ return (o.children[0].lower(), str(o.children[1]).lower() if o.children[1] is not None else True) visit_Proc_Decl_List = visit_List def visit_Proc_Decl(self, o, **kwargs): """ A symbol entity in a procedure declaration with initialization :class:`fparser.two.Fortran2003.Proc_Decl` has 3 children: * object name (:class:`fparser.two.Fortran2003.Name`) * operator ``=>`` (`str`) * initializer (:class:`fparser.two.Fortran2003.Function_Reference`) """ var = self.visit(o.children[0], **kwargs) assert o.children[1] == '=>' init = self.visit(o.children[2], **kwargs) return var.clone(type=var.type.clone(initial=init)) # # Array constructor # def visit_Array_Constructor(self, o, **kwargs): """ An array constructor expression :class:`fparser.two.Fortran2003.Array_Constructor` has three children: * left bracket (`str`): ``(/`` or ``[`` * the spec: :class:`fparser.two.Fortran2003.Ac_Spec` * right bracket (`str`): ``/)`` or ``]`` """ source = kwargs.get('source') if source: source = source.clone_with_string(o.string) if 
isinstance(o.children[1], Fortran2003.Ac_Spec): values, dtype = self.visit(o.children[1], **kwargs) else: values, dtype = self.visit(o.children[1], **kwargs), None return sym.LiteralList(values=values, dtype=dtype) def visit_Ac_Spec(self, o, **kwargs): """ The spec in an array constructor :class:`fparser.two.Fortran2003.Ac_Spec` has two children: * :class:`fparser.two.Fortran2003.Type_Spec` or None * :class:`fparser.two.Fortran2003.Ac_Value_List` """ if o.children[0] is not None: return self.visit(o.children[1], **kwargs), self.visit(o.children[0], **kwargs) return self.visit(o.children[1], **kwargs), None def visit_Ac_Value_List(self, o, **kwargs): """ The list of values in an array constructor """ return as_tuple(self.visit(c, **kwargs) for c in o.children) def visit_Ac_Implied_Do(self, o, **kwargs): """ An implied-do for array constructors :class:`fparser.two.Fortran2003.Ac_Implied_Do` has two children: * the expression as :class:`fparser.two.Fortran2003.Ac_Value_List` * the loop control as :class:`fparser.two.Fortran2003.Ac_Implied_Do_Control` """ values = self.visit(o.children[0], **kwargs) variable, bounds = self.visit(o.children[1], **kwargs) source = kwargs.get('source') if source: source = source.clone_with_string(o.string) return sym.InlineDo(values, variable, bounds) def visit_Ac_Implied_Do_Control(self, o, **kwargs): """ The "loop control" for an implied-do :class:`fparser.two.Fortran2003.Ac_Implied_Do_Control` has two children: * the variable name * the loop bounds """ variable = self.visit(o.children[0], **kwargs) bounds = tuple(self.visit(i, **kwargs) for i in o.children[1]) return (variable, sym.LoopRange(bounds)) # # DATA statements # def visit_Data_Stmt(self, o, **kwargs): """ A ``DATA`` statement :class:`fparser.two.Fortran2003.Data_Stmt` has variable number of children :class:`fparser.two.Fortran2003.Data_Stmt_Set`. 
""" data_statements = tuple(self.visit(data_set, **kwargs) for data_set in o.children) return data_statements def visit_Data_Stmt_Set(self, o, **kwargs): """ A data-stmt-set in a data-stmt :class:`fparser.two.Fortran2003.Data_Stmt_Set` has two children: * the object to initialize :class:`fparser.two.Fortran2003.Data_Stmt_Object` * the value list :class:`fparser.two.Fortran2003.Data_Stmt_Value_List` """ variable = self.visit(o.children[0], **kwargs) values = self.visit(o.children[1], **kwargs) return ir.DataDeclaration(variable=variable, values=values, label=kwargs.get('label'), source=kwargs.get('source')) def visit_Data_Implied_Do(self, o, **kwargs): """ An implied-do for data-stmt """ # TODO: Implement implied-do return self.visit_Base(o, **kwargs) visit_Data_Stmt_Object_List = visit_List visit_Data_Stmt_Value_List = visit_List def visit_Data_Stmt_Value(self, o, **kwargs): """ A value in a data-stmt-set :class:`fparser.two.Fortran2003.Data_Stmt_Value` has two children: * the repeat value :class:`fparser.two.Fortran2003.Data_Stmt_Repeat` * the constant :class:`fparser.two.Fortran2003.Data_Stmt_Constant` """ constant = self.visit(o.children[1], **kwargs) if o.children[0] is None: return constant repeat = self.visit(o.children[0], **kwargs) return self.create_operation('*', (repeat, constant)) # # Subscripts # visit_Section_Subscript_List = visit_List def visit_Subscript_Triplet(self, o, **kwargs): """ A subscript expression with ``[start] : [stop] [: stride]`` :class:`fparser.two.Fortran2003.Subscript_Triplet` has three children: * start :class:`fparser.two.Fortran2003.Subscript` or `None` * stop :class:`fparser.two.Fortran2003.Subscript` or `None` * stride :class:`fparser.two.Fortran2003.Stride` or `None` """ start = self.visit(o.children[0], **kwargs) if o.children[0] is not None else None stop = self.visit(o.children[1], **kwargs) if o.children[1] is not None else None stride = self.visit(o.children[2], **kwargs) if o.children[2] is not None else None source = 
kwargs.get('source') if source: source = source.clone_with_string(o.string) return sym.RangeIndex((start, stop, stride)) def visit_Array_Section(self, o, **kwargs): """ A subscript operation on a data-ref This includes dereferences such as ``a%b%c`` or extracting a substring. In practice, the first are typically flattened in the Fparser AST and directly returned as `Part_Ref`, so we should see only the substring operation here. :class:`fparser.two.Fortran2003.Array_Subscript` has two children: * the subscript data-ref :class:`fparser.two.Fortran2003.Data_Ref` * an optional substring range :class:`fparser.two.Fortran2003.Substring_Range` """ name = self.visit(o.children[0], **kwargs) if o.children[1] is None: return name substring = self.visit(o.children[1], **kwargs) return sym.StringSubscript(name, substring) def visit_Substring_Range(self, o, **kwargs): """ The range of a substring operation :class:`fparser.two.Fortran2003.Substring_Range` has two children: * start :class:`fparser.two.Fortran2003.Scalar_Int_Expr` or None * stop :class:`fparser.two.Fortran2003.Scalar_Int_Expr` or None """ start = self.visit(o.children[0], **kwargs) if o.children[0] is not None else None stop = self.visit(o.children[1], **kwargs) if o.children[1] is not None else None return sym.RangeIndex((start, stop)) def visit_Stride(self, o, **kwargs): # TODO: Implement Stride return self.visit_Base(o, **kwargs) # # Derived Type definition # def visit_Derived_Type_Def(self, o, **kwargs): """ A derived type definition :class:`fparser.two.Fortran2003.Derived_Type_Def` has variable number of children: * header stmt (:class:`fparser.two.Fortran2003.Derived_Type_Stmt`) * all of body (list of :class:`fparser.two.Fortran2003.Type_Param_Def_Stmt`, :class:`fparser.two.Fortran2003.Private_Or_Sequence`, :class:`fparser.two.Fortran2003.Component_Part`, :class:`fparser.two.Fortran2003.Type_Bound_Procedure_Part`) * end stmt (:class:`fparser.two.Fortran2003.End_Type_Stmt`) """ # Find start and end of 
construct derived_type_stmt = get_child(o, Fortran2003.Derived_Type_Stmt) derived_type_stmt_index = o.children.index(derived_type_stmt) end_type_stmt = get_child(o, Fortran2003.End_Type_Stmt) end_type_stmt_index = o.children.index(end_type_stmt) # Everything before the construct pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:derived_type_stmt_index]) # Instantiate the TypeDef without its body # Note: This creates the symbol table for the declarations and # the typedef object registers itself in the parent scope typedef = self.visit(derived_type_stmt, **kwargs) # Pass down the typedef scope when building the body kwargs['scope'] = typedef body = [self.visit(c, **kwargs) for c in o.children[derived_type_stmt_index+1:end_type_stmt_index]] body = as_tuple(flatten(body)) # Infer any additional shape information from `!$loki dimension` pragmas body = attach_pragmas(body, ir.VariableDeclaration) body = process_dimension_pragmas(body) body = detach_pragmas(body, ir.VariableDeclaration) # Finally: update the typedef with its body and make sure all symbols # are in the right scope source = self.get_source(derived_type_stmt, end_node=end_type_stmt) typedef._update(body=body, source=source) typedef.rescope_symbols() return (*pre, typedef) def visit_Derived_Type_Stmt(self, o, **kwargs): """ The block header for the derived type definition :class:`fparser.two.Fortran2003.Derived_Type_Stmt` has 3 children: * attribute spec list (:class:`fparser.two.Fortran2003.Type_Attr_Spec_List`) * type name (:class:`fparser.two.Fortran2003.Type_Name`) * parameter name list (:class:`fparser.two.Fortran2003.Type_Param_Name_List`) """ if o.children[0] is not None: attrs = dict(self.visit(o.children[0], **kwargs)) abstract = attrs.get('abstract', False) extends = attrs.get('extends') bind_c = attrs.get('bind') == 'c' private = attrs.get('private', False) public = attrs.get('public', False) else: abstract = False extends = None bind_c = False private = False public = False name = 
o.children[1].tostr() if o.children[2] is not None: self.warn_or_fail('parameter-name-list not implemented for derived types') return ir.TypeDef( name=name, body=(), abstract=abstract, extends=extends, bind_c=bind_c, private=private, public=public, label=kwargs['label'], parent=kwargs['scope'] ) visit_Type_Attr_Spec_List = visit_List def visit_Type_Attr_Spec(self, o, **kwargs): """ A component declaration attribute :class:`fparser.two.Fortran2003.Type_Attr_Spec` has 2 children: * keyword (`str`) * value (`str`) or `None` """ if o.children[1] is not None: return (str(o.children[0]).lower(), str(o.children[1]).lower()) return (str(o.children[0]).lower(), True) def visit_Type_Param_Def_Stmt(self,o , **kwargs): self.warn_or_fail('Parameterized types not implemented') visit_Binding_Attr_List = visit_List def visit_Binding_Attr(self, o, **kwargs): """ A binding attribute :class:`fparser.two.Fortran2003.Binding_Attr_Spec` has no children """ keyword = str(o).lower() if keyword == 'pass': return ('pass_attr', True) if keyword == 'nopass': return ('pass_attr', False) if keyword in ('non_overridable', 'deferred'): return (keyword, True) self.warn_or_fail(f'Unsupported binding attribute: {str(o)}') return None def visit_Binding_PASS_Arg_Name(self, o, **kwargs): """ Named PASS attribute :class:`fparser.two.Fortran2003.Binding_PASS_Arg_Name` has two children: * `str`: 'PASS' * `Name`: the argument name """ return ('pass_attr', str(o.children[1])) def visit_Component_Part(self, o, **kwargs): """ Derived type definition components :class:`fparser.two.Fortran2003.Component_Part` has a list of :class:`fparser.two.Fortran2003.Data_Component_Def_Stmt` or :class:`fparser.two.Fortran2003.Proc_Component_Def_Stmt` as children """ return tuple(self.visit(c, **kwargs) for c in o.children) # The definition stmts (= components of a derived type) look identical to regular # variable and procedure declarations in the parse tree and are represented by # the same IR nodes in Loki 
visit_Data_Component_Def_Stmt = visit_Type_Declaration_Stmt visit_Component_Attr_Spec_List = visit_List visit_Component_Attr_Spec = visit_Attr_Spec visit_Dimension_Component_Attr_Spec = visit_Dimension_Attr_Spec visit_Component_Decl_List = visit_List visit_Component_Decl = visit_Entity_Decl visit_Proc_Component_Def_Stmt = visit_Procedure_Declaration_Stmt visit_Proc_Component_Attr_Spec_List = visit_List visit_Proc_Component_Attr_Spec = visit_Attr_Spec def visit_Type_Bound_Procedure_Part(self, o, **kwargs): """ Procedure definitions part in a derived type definition :class:`fparser.two.Fortran2003.Type_Bound_Procedure_Part` starts with the contains-stmt (:class:`fparser.two.Fortran2003.Contains_Stmt`) followed by (optionally) :class:`fparser.two.Fortran2003.Binding_Private_Stmt` and a sequence of :class:`fparser.two.Fortran2003.Proc_Binding_Stmt` """ return tuple(self.visit(c, **kwargs) for c in o.children) def visit_Specific_Binding(self, o, **kwargs): """ A specific binding for a type-bound procedure in a derived type :class:`fparser.two.Fortran2003.Specific_Binding` has five children: * interface name :class:`fparser.two.Fortran2003.Interface_Name` * binding attr list :class:`fparser.two.Fortran2003.Binding_Attr_List` * '::' (`str`) or `None` * name :class:`fparser.two.Fortran2003.Binding_Name` * procedure name :class:`fparser.two.Fortran2003.Procedure_Name` """ scope = kwargs['scope'] # Instantiate declared symbols symbols = as_tuple(self.visit(o.children[3], **kwargs)) # Procedure we bind to this type interface = None if o.children[0]: # Procedure interface provided # (we pass the parent scope down for this) kwargs['scope'] = scope.parent interface = self.visit(o.children[0], **kwargs) bind_names = as_tuple(interface) func_names = [interface.name] * len(symbols) assert o.children[4] is None kwargs['scope'] = scope elif o.children[4]: # we pass the parent scope down for this kwargs['scope'] = scope.parent bind_names = as_tuple(self.visit(o.children[4], **kwargs)) 
assert len(bind_names) == len(symbols) func_names = [i.name for i in bind_names] kwargs['scope'] = scope else: bind_names = None func_names = [s.name for s in symbols] # Look up the type of the procedure types = [scope.symbol_attrs.lookup(name) for name in func_names] types = [ SymbolAttributes(dtype=ProcedureType(name)) if not t or t.dtype == BasicType.DEFERRED else t for t, name in zip(types, func_names) ] # Any declared attributes attrs = self.visit(o.children[1], **kwargs) if o.children[1] else () attrs = dict(attrs) types = [t.clone(**attrs) for t in types] # Store the bind_names if bind_names: types = [t.clone(bind_names=as_tuple(i)) for t, i in zip(types, bind_names)] # Update symbol table entries scope.symbol_attrs.update({s.name: s.type.clone(**t.__dict__) for s, t in zip(symbols, types)}) symbols = tuple(var.rescope(scope=scope) for var in symbols) return ir.ProcedureDeclaration(symbols=symbols, interface=interface, source=kwargs.get('source'), label=kwargs.get('label')) def visit_Generic_Binding(self, o, **kwargs): """ A generic binding for a type-bound procedure in a derived type :class:`fparser.two.Fortran2003.Generic_Binding` has three children: * :class:`fparser.two.Fortran2003.Access_Spec` or None (access specifier) * :class:`fparser.two.Fortran2003.Generic_Spec` (the local name of the binding) * :class:`fparser.two.Fortran2003.Binding_Name_List` (the names it binds to) """ scope = kwargs['scope'] name = self.visit(o.children[1], **kwargs) bind_names = self.visit(o.children[2], **kwargs) bind_names = AttachScopesMapper()(bind_names, scope=scope) _type = SymbolAttributes(ProcedureType(name=name.name, is_generic=True), bind_names=as_tuple(bind_names)) if o.children[0] is not None: access_spec = self.visit(o.children[0], **kwargs) attrs = {access_spec[0]: access_spec[1]} _type = _type.clone(**attrs) scope.symbol_attrs[name.name] = _type name = name.rescope(scope=scope) return ir.ProcedureDeclaration( symbols=(name,), generic=True, 
source=kwargs.get('source'), label=kwargs.get('label') ) def visit_Final_Binding(self, o, **kwargs): """ A final binding for type-bound procedures in a derived type :class:`fparser.two.Fortran2003.Final_Binding` has two children: * keyword ``'FINAL'`` (`str`) * :class:`fparser.two.Fortran2003.Final_Subroutine_Name_List` (the list of routines) """ scope = kwargs['scope'] symbols = self.visit(o.children[1], **kwargs) symbols = tuple(var.rescope(scope=scope) for var in symbols) return ir.ProcedureDeclaration( symbols=symbols, final=True, source=kwargs.get('source'), label=kwargs.get('label') ) visit_Binding_Name_List = visit_List visit_Final_Subroutine_Name_List = visit_List visit_Contains_Stmt = visit_Intrinsic_Stmt visit_Binding_Private_Stmt = visit_Intrinsic_Stmt visit_Private_Components_Stmt = visit_Intrinsic_Stmt visit_Sequence_Stmt = visit_Intrinsic_Stmt # # ASSOCIATE blocks # def visit_Associate_Construct(self, o, **kwargs): """ The entire ASSOCIATE construct :class:`fparser.two.Fortran2003.Associate_Construct` has a variable number of children: * Any preceeding comments :class:`fparser.two.Fortran2003.Comment` * :class:`fparser.two.Fortran2003.Associate_Stmt` (the actual statement with the definition of associates) * the body of the ASSOCIATE construct * :class:`fparser.two.Fortran2003.End_Associate_Stmt` """ # Find start and end of associate construct assoc_stmt = get_child(o, Fortran2003.Associate_Stmt) assoc_stmt_index = o.children.index(assoc_stmt) end_assoc_stmt = get_child(o, Fortran2003.End_Associate_Stmt) end_assoc_stmt_index = o.children.index(end_assoc_stmt) # Everything before the associate statement pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:assoc_stmt_index]) # Extract source object for construct source = self.get_source(assoc_stmt, end_node=end_assoc_stmt) # Handle the associates associations = self.visit(assoc_stmt, **kwargs) # Create a scope for the associate parent_scope = kwargs['scope'] associate = 
ir.Associate(associations=associations, body=(), parent=parent_scope, label=kwargs.get('label'), source=source) kwargs['scope'] = associate # Put associate expressions into the right scope and determine type of new symbols associate._derive_local_symbol_types(parent_scope=parent_scope) # The body body = as_tuple(flatten(self.visit(c, **kwargs) for c in o.children[assoc_stmt_index+1:end_assoc_stmt_index])) associate._update(body=body) # Everything past the END ASSOCIATE (should be empty) assert not o.children[end_assoc_stmt_index+1:] return (*pre, associate) def visit_Associate_Stmt(self, o, **kwargs): """ The ASSOCIATE statement with the association list :class:`fparser.two.Fortran2003.Associate_Stmt` has two children: * The command `ASSOCIATE` (`str`) * The :class:`fparser.two.Fortran2003.Association_List` defining the associations """ assert o.children[0].upper() == 'ASSOCIATE' return self.visit(o.children[1], **kwargs) visit_Association_List = visit_List def visit_Association(self, o, **kwargs): """ A single association in an associate-stmt :class:`fparser.two.Fortran2003.Associate` has two children: * :class:`fparser.two.Fortran2003.Name` (the new assigned name) * the operator ``=>`` (`str`) * :class:`fparser.two.Fortran2003.Name` (the associated expression) """ assert o.children[1] == '=>' associate_name = self.visit(o.children[0], **kwargs) selector = self.visit(o.children[2], **kwargs) return (selector, associate_name) # (associate_name, selector) # # Interface block # def visit_Interface_Block(self, o, **kwargs): """ An ``INTERFACE`` block :class:`fparser.two.Fortran2003.Interface_Block` has variable number of children: * Any preceeding comments :class:`fparser.two.Fortran2003.Comment` * :class:`fparser.two.Fortran2003.Interface_Stmt` (the actual statement that begins the construct) * the body, made up of :class:`fparser.two.Fortran2003.Subroutine_Body`, :class:`fparser.two.Fortran2003.Function_Body`, :class:`fparser.two.Fortran2003.Procedure_Stmt` and, 
potentially, any interleaving comments :class:`fparser.two.Fortran2003.Comment` * the closing :class:`fparser.two.Fortran2003.End_Interface_Stmt` """ # Find start and end of construct interface_stmt = get_child(o, Fortran2003.Interface_Stmt) interface_stmt_index = o.children.index(interface_stmt) end_interface_stmt = get_child(o, Fortran2003.End_Interface_Stmt) end_interface_stmt_index = o.children.index(end_interface_stmt) # Everything before the construct pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:interface_stmt_index]) # Extract source object for construct source = self.get_source(interface_stmt, end_node=end_interface_stmt) # The interface spec abstract = False spec = self.visit(interface_stmt, **kwargs) if spec == 'ABSTRACT': # This is an abstract interface abstract = True spec = None elif spec is not None: # This has a generic specification (and we might need to update symbol table) scope = kwargs['scope'] spec_type = scope.symbol_attrs.lookup(spec.name) if not spec_type or spec_type.dtype == BasicType.DEFERRED: scope.symbol_attrs[spec.name] = SymbolAttributes( ProcedureType(name=spec.name, is_generic=True) ) spec = spec.rescope(scope=scope) # Traverse the body and build the object body = as_tuple(flatten( self.visit(c, **kwargs) for c in o.children[interface_stmt_index+1:end_interface_stmt_index] )) interface = ir.Interface( body=body, abstract=abstract, spec=spec, label=kwargs.get('label'), source=source ) # Everything past the END INTERFACE (should be empty) assert not o.children[end_interface_stmt_index+1:] return (*pre, interface) def visit_Interface_Stmt(self, o, **kwargs): """ The specification of the interface :class:`fparser.two.Fortran2003.Interface_Stmt` has one child, which is either: * `None`, if no further specification exists * ``'ABSTRACT'`` (`str`) for an abstract interface * :class:`fparser.two.Fortran2003.Generic_Spec` for other specifications """ if o.children[0] == 'ABSTRACT': return 'ABSTRACT' if o.children[0] is not 
None: return self.visit(o.children[0], **kwargs) return None def visit_Generic_Spec(self, o, **kwargs): """ The generic-spec of an interface :class:`fparser.two.Fortran2003.Generic_Spec` has two children, which is either: * ``'OPERATOR'`` (`str`) followed by * :class:`fparser.two.Fortran2003.Defined_Operator` -or- * ``'ASSIGNMENT'`` (`str`) followed by * ``'='`` (`str`) """ return sym.Variable(name=str(o)) def visit_Procedure_Stmt(self, o, **kwargs): """ Procedure statement :class:`fparser.two.Fortran2003.Procedure_Stmt` has 1 child: * :class:`fparser.two.Fortran2003.Procedure_Name_List`: the names of the procedures """ module_proc = o.string.upper().startswith('MODULE') symbols = self.visit(o.children[0], **kwargs) symbols = AttachScopesMapper()(symbols, scope=kwargs['scope']) return ir.ProcedureDeclaration( symbols=symbols, module=module_proc, source=kwargs.get('source'), label=kwargs.get('label') ) visit_Procedure_Name_List = visit_List visit_Procedure_Name = visit_Name def visit_Import_Stmt(self, o, **kwargs): """ An import statement for named entities in an interface body :class:`fparser.two.Fortran2003.Import_Stmt` has two children: * The string ``'IMPORT'`` * :class:`fparser.two.Fortran2003.Import_Name_List` with the names of imported entities """ assert o.children[0] == 'IMPORT' symbols = self.visit(o.children[1], **kwargs) symbols = AttachScopesMapper()(symbols, scope=kwargs['scope']) return ir.Import( module=None, symbols=symbols, f_import=True, source=kwargs.get('source'), label=kwargs.get('label') ) visit_Import_Name_List = visit_List visit_Import_Name = visit_Name # # Subroutine and Function definitions # def visit_Main_Program(self, o, **kwargs): """ The entire block that comprises a ``PROGRAM`` definition Loki does currently not have support for ``PROGRAM`` blocks, and this will raise a :any:`NotImplementedError` """ self.warn_or_fail('No support for PROGRAM') def visit_Subroutine_Subprogram(self, o, **kwargs): """ The entire block that comprises a 
``SUBROUTINE`` definition, i.e. everything from the subroutine-stmt to the end-stmt :class:`fparser.two.Fortran2003.Subroutine_Subprogram` has variable number of children, where the internal nodes may be optional: * :class:`fparser.two.Fortran2003.Subroutine_Stmt` (the opening statement) * :class:`fparser.two.Fortran2003.Specification_Part` (variable declarations, module imports etc.); due to an fparser bug, this can appear multiple times interleaved with the execution-part * :class:`fparser.two.Fortran2003.Execution_Part` (the body of the routine) * :class:`fparser.two.Fortran2003.Internal_Subprogram_Part` (any member procedures declared inside the procedure) * :class:`fparser.two.Fortran2003.End_Subroutine_Stmt` (the final statement) """ # Find start and end of construct subroutine_stmt = get_child(o, Fortran2003.Subroutine_Stmt) subroutine_stmt_index = o.children.index(subroutine_stmt) end_subroutine_stmt = get_child(o, Fortran2003.End_Subroutine_Stmt) end_subroutine_stmt_index = o.children.index(end_subroutine_stmt) # Everything before the construct pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:subroutine_stmt_index]) # ...and there shouldn't be anything after the construct assert end_subroutine_stmt_index + 1 == len(o.children) # Instantiate the object routine, _ = self.visit(subroutine_stmt, **kwargs) kwargs['scope'] = routine # Extract source object for construct source = self.get_source(subroutine_stmt, end_node=end_subroutine_stmt) # Pre-populate internal procedure scopes in the type hierarchy self.create_contained_procedures(get_child(o, Fortran2003.Internal_Subprogram_Part), **kwargs) # Hack: Collect all spec and body parts and use all but the # last body as spec. 
Reason is that Fparser misinterprets statement # functions as array assignments and thus breaks off spec early part_asts = [ c for c in o.children if isinstance(c, (Fortran2003.Specification_Part, Fortran2003.Execution_Part)) ] if not part_asts: spec_asts = [] body_ast = None elif isinstance(part_asts[-1], Fortran2003.Execution_Part): *spec_asts, body_ast = part_asts else: spec_asts = part_asts body_ast = None # Build the spec by parsing all relevant parts of the AST and appending them # to the same section object spec_parts = [self.visit(spec_ast, **kwargs) for spec_ast in spec_asts] spec_parts = flatten([part.body for part in spec_parts if part is not None]) spec = ir.Section(body=as_tuple(spec_parts)) spec = sanitize_ir(spec, FP, pp_registry=sanitize_registry[FP], pp_info=self.pp_info) # As variables may be defined out of sequence, we need to re-generate # symbols in the spec part to make them coherent with the symbol table spec = AttachScopes().visit(spec, scope=routine, recurse_to_declaration_attributes=True) # Now all declarations are well-defined and we can parse the member routines contains = self.visit(get_child(o, Fortran2003.Internal_Subprogram_Part), **kwargs) # Finally, take care of the body if body_ast is None: body = ir.Section(body=()) else: body = self.visit(body_ast, **kwargs) body = sanitize_ir(body, FP, pp_registry=sanitize_registry[FP], pp_info=self.pp_info) # Workaround for lost StatementFunctions: # Since FParser has no means to identify StmtFuncs, the last set of them # can get lumped in with the body, and we simply need to shift them over. 
stmt_funcs = tuple(n for n in body.body if isinstance(n, ir.StatementFunction)) if stmt_funcs: idx = body.body.index(stmt_funcs[-1]) + 1 spec._update(body=spec.body + body.body[:idx]) body._update(body=body.body[idx:]) # Extract the leading comments of the specification as "docstring" section docs = _get_comments_from_section(spec) if spec else () # Move trailing comments from spec to the body as those can be pragmas. body.prepend(_get_comments_from_section(spec, include_pragmas=True, reverse=True)) # To complete spec and body, build source objects once we're done moving things around if config['frontend-store-source']: if spec.body: spec_lines = (spec.body[0].source.lines[0], spec.body[-1].source.lines[1]) spec_string = ''.join(self.raw_source[spec_lines[0]-1:spec_lines[1]]).strip('\n') spec._update(source=Source(lines=spec_lines, string=spec_string)) else: # Empty spec source object line = source.lines[0] + 1 spec._update(source=Source(lines=(line, line), string='')) if body.body: body_lines = (body.body[0].source.lines[0], body.body[-1].source.lines[1]) body_string = ''.join(self.raw_source[body_lines[0]-1:body_lines[1]]).rstrip('\n') body._update(source=Source(lines=body_lines, string=body_string)) else: # Empty body source object line = spec.source.lines[1] + 1 body._update(source=Source(lines=(line, line), string='')) # Finally, call the subroutine constructor on the object again to register all # bits and pieces in place and rescope all symbols # pylint: disable=unnecessary-dunder-call routine.__initialize__( name=routine.name, args=routine._dummies, docstring=docs, spec=spec, body=body, contains=contains, ast=o, prefix=routine.prefix, bind=routine.bind, rescope_symbols=False, source=source, incomplete=False ) # Once statement functions are in place, we need to update the original declaration so that it # contains ProcedureSymbols rather than Scalars for decl in FindNodes(ir.VariableDeclaration).visit(spec): if any(routine.symbol_attrs[s.name].is_stmt_func 
for s in decl.symbols): decl._update(symbols=tuple(s.clone() if routine.symbol_attrs[s.name].is_stmt_func else s for s in decl.symbols)) # Update array shapes with Loki dimension pragmas with pragmas_attached(routine, ir.VariableDeclaration): routine.spec = process_dimension_pragmas(routine.spec, scope=routine) return (*pre, routine) def visit_Function_Subprogram(self, o, **kwargs): """ The entire block that comprises a ``FUNCTION`` definition, i.e. everything from the function-stmt to the end-stmt :class:`fparser.two.Fortran2003.Function_Subprogram` has variable number of children, where the internal nodes may be optional: * :class:`fparser.two.Fortran2003.Function_Stmt` (the opening statement) * :class:`fparser.two.Fortran2003.Specification_Part` (variable declarations, module imports etc.); due to an fparser bug, this can appear multiple times interleaved with the execution-part * :class:`fparser.two.Fortran2003.Execution_Part` (the body of the routine) * :class:`fparser.two.Fortran2003.Internal_Subprogram_Part` (any member procedures declared inside the procedure) * :class:`fparser.two.Fortran2003.End_Function_Stmt` (the final statement) """ # Find start and end of construct function_stmt = get_child(o, Fortran2003.Function_Stmt) function_stmt_index = o.children.index(function_stmt) end_function_stmt = get_child(o, Fortran2003.End_Function_Stmt) end_function_stmt_index = o.children.index(end_function_stmt) # Everything before the construct pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:function_stmt_index]) # ...and there shouldn't be anything after the construct assert end_function_stmt_index + 1 == len(o.children) # Instantiate the object (routine, return_type) = self.visit(function_stmt, **kwargs) kwargs['scope'] = routine # Extract source object for construct source = self.get_source(function_stmt, end_node=end_function_stmt) # Pre-populate internal procedure scopes in the type hierarchy self.create_contained_procedures(get_child(o, 
Fortran2003.Internal_Subprogram_Part), **kwargs) # Hack: Collect all spec and body parts and use all but the # last body as spec. Reason is that Fparser misinterprets statement # functions as array assignments and thus breaks off spec early part_asts = [ c for c in o.children if isinstance(c, (Fortran2003.Specification_Part, Fortran2003.Execution_Part)) ] if not part_asts: spec_asts = [] body_ast = None elif isinstance(part_asts[-1], Fortran2003.Execution_Part): *spec_asts, body_ast = part_asts else: spec_asts = part_asts body_ast = None # Build the spec by parsing all relevant parts of the AST and appending them # to the same section object spec_parts = [self.visit(spec_ast, **kwargs) for spec_ast in spec_asts] spec_parts = flatten([part.body for part in spec_parts if part is not None]) spec = ir.Section(body=as_tuple(spec_parts)) spec = sanitize_ir(spec, FP, pp_registry=sanitize_registry[FP], pp_info=self.pp_info) # As variables may be defined out of sequence, we need to re-generate # symbols in the spec part to make them coherent with the symbol table spec = AttachScopes().visit(spec, scope=routine, recurse_to_declaration_attributes=True) # If the return type is given, inject it into the symbol table if return_type: routine.symbol_attrs[routine.result_name] = return_type # Now all declarations are well-defined and we can parse the member routines contains = self.visit(get_child(o, Fortran2003.Internal_Subprogram_Part), **kwargs) # Finally, take care of the body if body_ast is None: body = ir.Section(body=()) else: body = self.visit(body_ast, **kwargs) body = sanitize_ir(body, FP, pp_registry=sanitize_registry[FP], pp_info=self.pp_info) # Workaround for lost StatementFunctions: # Since FParser has no means to identify StmtFuncs, the last set of them # can get lumped in with the body, and we simply need to shift them over. 
stmt_funcs = tuple(n for n in body.body if isinstance(n, ir.StatementFunction)) if stmt_funcs: idx = body.body.index(stmt_funcs[-1]) + 1 spec._update(body=spec.body + body.body[:idx]) body._update(body=body.body[idx:]) # Extract the leading comments of the specification as "docstring" section docs = _get_comments_from_section(spec) if spec else () # Move trailing comments from spec to the body as those can be pragmas. body.prepend(_get_comments_from_section(spec, include_pragmas=True, reverse=True)) # Finally, call the subroutine constructor on the object again to register all # bits and pieces in place and rescope all symbols # pylint: disable=unnecessary-dunder-call routine.__initialize__( name=routine.name, args=routine._dummies, docstring=docs, spec=spec, body=body, contains=contains, ast=o, prefix=routine.prefix, bind=routine.bind, result_name=routine.result_name, rescope_symbols=False, source=source, incomplete=False ) # Once statement functions are in place, we need to update the original declaration so that it # contains ProcedureSymbols rather than Scalars for decl in FindNodes(ir.VariableDeclaration).visit(spec): if any(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols): decl._update(symbols=tuple(s.clone() if routine.symbol_attrs[s.name].is_stmt_func else s for s in decl.symbols)) # Update array shapes with Loki dimension pragmas with pragmas_attached(routine, ir.VariableDeclaration): routine.spec = process_dimension_pragmas(routine.spec, scope=routine) return (*pre, routine) visit_Subroutine_Body = visit_Subroutine_Subprogram visit_Function_Body = visit_Function_Subprogram @staticmethod def _get_procedure_from_scope(name, scope=None): """ """ if not scope: return None, None if proc_type := scope.symbol_attrs.get(name): # Look-up only in current scope! 
if proc_type and proc_type.dtype != BasicType.DEFERRED and \ proc_type.dtype.procedure != BasicType.DEFERRED: return proc_type.dtype.procedure, proc_type return None, None def visit_Function_Stmt(self, o, **kwargs): """ The ``FUNCTION`` statement :class:`fparser.two.Fortran2003.Function_Stmt` has four children: * prefix :class:`fparser.two.Fortran2003.Prefix` * name :class:`fparser.two.Fortran2003.Subroutine_Name` * dummy argument list :class:`fparser.two.Fortran2003.Dummy_Arg_List` * suffix :class:`fparser.two.Fortran2003.Suffix` or language binding spec :class:`fparser.two.Fortran2003.Proc_Language_Binding_Spec` """ from loki.function import Function # pylint: disable=import-outside-toplevel,cyclic-import # Parse the prefix prefix = () return_type = None if o.children[0] is not None: prefix = self.visit(o.children[0], **kwargs) return_type = [i for i in prefix if not isinstance(i, str)] prefix = [i for i in prefix if isinstance(i, str)] assert len(return_type) in (0, 1) return_type = return_type[0] if return_type else None name = self.visit(o.children[1], **kwargs) name = name.name # Check if the Subroutine node has been created before by looking it up in the scope function, proc_type = self._get_procedure_from_scope(name, scope=kwargs.get('scope')) if function and not function._incomplete: # We return the existing object right away, unless it exists from a # previous incomplete parse for which we have to make sure we get a # full parse first return (function, proc_type.dtype.return_type) # Build the dummy argument list if o.children[2] is None: args = () else: dummy_arg_list = self.visit(o.children[2], **kwargs) args = tuple(str(arg) for arg in dummy_arg_list) # Parse suffix, such as result name or language binding specs if isinstance(o.children[3], Fortran2003.Suffix): result, bind = self.visit(o.children[3], **kwargs) else: # Fparser inlines the language-binding spec directly if there is not other suffix result = None bind = None if o.children[3] is None else 
self.visit(o.children[3], **kwargs) # Instantiate the object if function is None: function = Function( name=name, args=args, prefix=prefix, bind=bind, result_name=result, parent=kwargs['scope'] ) else: function.__initialize__( name=name, args=args, docstring=function.docstring, spec=function.spec, prefix=prefix, bind=bind, result_name=result, incomplete=function._incomplete ) return (function, return_type) def visit_Subroutine_Stmt(self, o, **kwargs): """ The ``SUBROUTINE`` statement :class:`fparser.two.Fortran2003.Subroutine_Stmt` has four children: * prefix :class:`fparser.two.Fortran2003.Prefix` * name :class:`fparser.two.Fortran2003.Subroutine_Name` * dummy argument list :class:`fparser.two.Fortran2003.Dummy_Arg_List` * suffix :class:`fparser.two.Fortran2003.Suffix` or language binding spec :class:`fparser.two.Fortran2003.Proc_Language_Binding_Spec` """ from loki.subroutine import Subroutine # pylint: disable=import-outside-toplevel,cyclic-import # Parse the prefix prefix = () if o.children[0] is not None: prefix = self.visit(o.children[0], **kwargs) prefix = [i for i in prefix if isinstance(i, str)] name = self.visit(o.children[1], **kwargs) name = name.name # Check if the Subroutine node has been created before by looking it up in the scope routine, _ = self._get_procedure_from_scope(name, scope=kwargs.get('scope')) if routine and not routine._incomplete: # We return the existing object right away, unless it exists from a # previous incomplete parse for which we have to make sure we get a # full parse first return routine, None # Build the dummy argument list if o.children[2] is None: args = () else: dummy_arg_list = self.visit(o.children[2], **kwargs) args = tuple(str(arg) for arg in dummy_arg_list) # Parse suffix, such as result name or language binding specs if isinstance(o.children[3], Fortran2003.Suffix): _, bind = self.visit(o.children[3], **kwargs) else: # Fparser inlines the language-binding spec directly if there is not other suffix bind = None if 
o.children[3] is None else self.visit(o.children[3], **kwargs) # Instantiate the object if routine is None: routine = Subroutine( name=name, args=args, prefix=prefix, bind=bind, parent=kwargs['scope'] ) else: routine.__initialize__( name=name, args=args, docstring=routine.docstring, spec=routine.spec, body=routine.body, contains=routine.contains, prefix=prefix, bind=bind, ast=routine._ast, source=routine._source, incomplete=routine._incomplete ) return (routine, None) visit_Subroutine_Name = visit_Name visit_Function_Name = visit_Name visit_Dummy_Arg_List = visit_List def visit_Prefix(self, o, **kwargs): """ The prefix of a subprogram definition :class:`fparser.two.Fortran2003.Prefix` has variable number of children that have the type * :class:`fparser.two.Fortran2003.Prefix_Spec` to declare attributes * :class:`fparser.two.Fortran2003.Declaration_Type_Spec` (or any of its variations) to declare the return type of a function """ attrs = [self.visit(c, **kwargs) for c in o.children] return as_tuple(attrs) def visit_Prefix_Spec(self, o, **kwargs): """ A prefix keyword in a subprogram definition :class:`fparser.two.Fortran2003.Prefix_Spec` has no children """ return o.string def visit_Suffix(self, o, **kwargs): """ The suffix of a subprogram statement :class:`fparser.two.Fortran2003.Suffix` has two children: * A :class:`fparser.two.Fortran2003.Result_Name` if specified, or None * a :class:`fparser.two.Fortran2003.Language_Binding_Spec` if specified, or None """ result = o.children[0].tostr() if o.children[0] is not None else None bind = self.visit(o.children[1], **kwargs) if o.children[1] is not None else None return result, bind def visit_Language_Binding_Spec(self, o, **kwargs): """ A language binding spec suffix :class:`fparser.two.Fortran2003.Language_Binding_Spec` has a single child: * :class:`fparser.two.Fortran2003.Char_Literal_Constant` with the name of the C routine it binds to """ return self.visit(o.children[0], **kwargs) # # Module definition # def 
visit_Module(self, o, **kwargs): """ The definition of a Fortran module :class:`fparser.two.Fortran2003.Module` has up to four children: * The opening :class:`fparser.two.Fortran2003.Module_Stmt` * The specification part :class:`fparser.two.Fortran2003.Specification_Part` * The module subprogram part :class:`fparser.two.Fortran2003.Module_Subprogram_Part` * the closing :class:`fparser.two.Fortran2003.End_Module_Stmt` """ # Find start and end of construct module_stmt = get_child(o, Fortran2003.Module_Stmt) module_stmt_index = o.children.index(module_stmt) end_module_stmt = get_child(o, Fortran2003.End_Module_Stmt) end_module_stmt_index = o.children.index(end_module_stmt) # Everything before the construct pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:module_stmt_index]) # ...and there shouldn't be anything after the construct assert end_module_stmt_index + 1 == len(o.children) # Extract source object for construct source = self.get_source(module_stmt, end_node=end_module_stmt) # Instantiate the object module = self.visit(module_stmt, **kwargs) kwargs['scope'] = module # Pre-populate internal procedure scopes in the type hierarchy self.create_contained_procedures(get_child(o, Fortran2003.Module_Subprogram_Part), **kwargs) # Build the spec spec = self.visit(get_child(o, Fortran2003.Specification_Part), **kwargs) spec = sanitize_ir(spec, FP, pp_registry=sanitize_registry[FP], pp_info=self.pp_info) # Infer any additional shape information from `!$loki dimension` pragmas spec = attach_pragmas(spec, ir.VariableDeclaration) spec = process_dimension_pragmas(spec) spec = detach_pragmas(spec, ir.VariableDeclaration) # Extract the leading comments of the specification as "docstring" section docs = _get_comments_from_section(spec) if spec else () # As variables may be defined out of sequence, we need to re-generate # symbols in the spec part to make them coherent with the symbol table spec = AttachScopes().visit(spec, scope=module, 
recurse_to_declaration_attributes=True) # Now that all declarations are well-defined we can parse the member routines contains = self.visit(get_child(o, Fortran2003.Module_Subprogram_Part), **kwargs) # To complete spec and contains, build source objects once we have everything if config['frontend-store-source']: if spec: if spec.body: spec_lines = (spec.body[0].source.lines[0], spec.body[-1].source.lines[1]) spec_string = ''.join(self.raw_source[spec_lines[0]-1:spec_lines[1]]).strip('\n') spec._update(source=Source(lines=spec_lines, string=spec_string)) else: # Empty spec source object line = source.lines[0] + 1 spec._update(source=Source(lines=(line, line), string='')) if contains: if contains.body: contains_lines = (contains.body[0].source.lines[0], contains.body[-1].source.lines[1]) contains_string = ''.join(self.raw_source[contains_lines[0]-1:contains_lines[1]]).strip('\n') contains._update(source=Source(lines=contains_lines, string=contains_string)) else: # Empty body source object line = spec.source.lines[1] + 1 contains._update(source=Source(lines=(line, line), string='')) # Finally, call the module constructor on the object again to register all # bits and pieces in place and rescope all symbols # pylint: disable=unnecessary-dunder-call module.__initialize__( name=module.name, docstring=docs, spec=spec, contains=contains, default_access_spec=module.default_access_spec, public_access_spec=module.public_access_spec, private_access_spec=module.private_access_spec, ast=o, rescope_symbols=False, source=source, incomplete=False ) return (*pre, module) def visit_Module_Stmt(self, o, **kwargs): """ The ``MODULE`` statement :class:`fparser.two.Fortran2003.Module_Stmt` has 2 children: * keyword `MODULE` (str) * name :class:`fparser.two.Fortran2003.Module_Name` """ from loki.module import Module # pylint: disable=import-outside-toplevel,cyclic-import name = self.visit(o.children[1], **kwargs) name = name.name # Check if the Module node has been created before by 
looking it up in the scope if kwargs['scope'] is not None and name in kwargs['scope'].symbol_attrs: module_type = kwargs['scope'].symbol_attrs[name] # Look-up only in current scope! if module_type and module_type.dtype.module != BasicType.DEFERRED: return module_type.dtype.module module = Module(name=name, parent=kwargs['scope']) self.definitions[name] = module return module visit_Module_Name = visit_Name # # Conditional # def visit_If_Construct(self, o, **kwargs): """ The entire ``IF`` construct :class:`fparser.two.Fortran2003.If_Construct` has variable number of children: * Any preceeding comments :class:`fparser.two.Fortran2003.Comment` * :class:`fparser.two.Fortran2003.If_Then_Stmt` (the actual statement that begins the construct with the first condition) * the body of the conditional branch * Optionally, one or more :class:`fparser.two.Fortran2003.Else_If_Stmt` followed by their corresponding bodies * Optionally, a :class:`fparser.two.Fortran2003.Else_Stmt` followed by its body * :class:`fparser.two.Fortran2003.End_If_Stmt` """ # Find start and end of construct if_then_stmt = get_child(o, Fortran2003.If_Then_Stmt) if_then_stmt_index = o.children.index(if_then_stmt) end_if_stmt = get_child(o, Fortran2003.End_If_Stmt) end_if_stmt_index = o.children.index(end_if_stmt) # Everything before the IF statement pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:if_then_stmt_index]) # Find all branches else_if_stmts = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_If_Stmt)) if else_if_stmts: else_if_stmt_index, else_if_stmts = zip(*else_if_stmts) else: else_if_stmt_index = () # Note: we need to use here the same method as for else-if because finding Else_Stmt # directly and checking its position via o.children.index may give the wrong result. # This is because Else_Stmt may erronously compare equal to other node types. 
# See https://github.com/stfc/fparser/issues/400 else_stmt = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_Stmt)) if else_stmt: assert len(else_stmt) == 1 else_stmt_index, else_stmt = else_stmt[0] else: else_stmt_index = end_if_stmt_index conditions = as_tuple(self.visit(c, **kwargs) for c in (if_then_stmt,) + else_if_stmts) bodies = tuple( tuple(flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop]))) for start, stop in zip( (if_then_stmt_index,) + else_if_stmt_index, else_if_stmt_index + (else_stmt_index,) ) ) else_body = flatten([self.visit(c, **kwargs) for c in o.children[else_stmt_index+1:end_if_stmt_index]]) # Extract source objects for branches sources, labels = [], [] for conditional in (if_then_stmt,) + else_if_stmts: sources += [self.get_source(conditional, end_node=end_if_stmt)] labels += [self.get_label(conditional)] # Build IR nodes backwards using else-if branch as else body body = bodies[-1] node = ir.Conditional(condition=conditions[-1], body=body, else_body=as_tuple(else_body), inline=False, has_elseif=False, label=labels[-1], source=sources[-1]) for idx in reversed(range(len(conditions)-1)): node = ir.Conditional(condition=conditions[idx], body=bodies[idx], else_body=as_tuple(node), inline=False, has_elseif=True, label=labels[idx], source=sources[idx]) # Update with construct name name = if_then_stmt.get_start_name() node._update(name=name) # Everything past the END IF (should be empty) assert not o.children[end_if_stmt_index+1:] return (*pre, node) def visit_If_Then_Stmt(self, o, **kwargs): """ The first conditional in a ``IF`` construct :class:`fparser.two.Fortran2003.If_Then_Stmt` has one child: the condition expression """ return self.visit(o.children[0], **kwargs) visit_Else_If_Stmt = visit_If_Then_Stmt def visit_If_Stmt(self, o, **kwargs): """ An inline ``IF`` statement with a single statement as body :class:`fparser.two.Fortran2003.If_Stmt` has two children: * the condition expression 
* the body """ cond = self.visit(o.items[0], **kwargs) body = as_tuple(self.visit(o.items[1], **kwargs)) return ir.Conditional(condition=cond, body=body, else_body=(), inline=True, label=kwargs.get('label'), source=kwargs.get('source')) # # SELECT CASE constructs # def visit_Case_Construct(self, o, **kwargs): """ The entire ``SELECT CASE`` construct :class:`fparser.two.Fortran2003.Case_Construct` has variable number of children: * Any preceeding comments :class:`fparser.two.Fortran2003.Comment` * :class:`fparser.two.Fortran2003.Select_Case_Stmt` (the actual statement with the selection expression) * the body of the case-construct, containing one or multiple :class:`fparser.two.Fortran2003.Case_Stmt` followed by their corresponding bodies * :class:`fparser.two.Fortran2003.End_Select_Stmt` """ # Find start and end of case construct select_case_stmt = get_child(o, Fortran2003.Select_Case_Stmt) select_case_stmt_index = o.children.index(select_case_stmt) end_select_stmt = get_child(o, Fortran2003.End_Select_Stmt) end_select_stmt_index = o.children.index(end_select_stmt) # Everything before the SELECT CASE statement pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:select_case_stmt_index]) # Extract source object for construct source = self.get_source(select_case_stmt, end_node=end_select_stmt) # Handle the SELECT CASE statement expr = self.visit(select_case_stmt, **kwargs) name = select_case_stmt.get_start_name() label = self.get_label(select_case_stmt) # Find all CASE statements and corresponding bodies case_stmts, case_stmt_index = zip(*[(c, i) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Case_Stmt)]) # Retain any comments between `SELECT CASE` and the first `CASE` statement if case_stmt_index[0] > select_case_stmt_index + 1: # Our IR doesn't provide a means to store them in the right place, so # we'll just put them before the `SELECT CASE` pre += as_tuple(self.visit(c, **kwargs) for c in 
o.children[select_case_stmt_index+1:case_stmt_index[0]]) values = as_tuple(self.visit(c, **kwargs) for c in case_stmts) bodies = tuple( as_tuple(flatten(as_tuple(self.visit(c, **kwargs)) for c in o.children[start+1:stop])) for start, stop in zip(case_stmt_index, case_stmt_index[1:] + (end_select_stmt_index,)) ) if 'DEFAULT' in values: default_index = values.index('DEFAULT') else_body = bodies[default_index] values = values[:default_index] + values[default_index+1:] bodies = bodies[:default_index] + bodies[default_index+1:] else: else_body = () # Everything past the END ASSOCIATE (should be empty) assert not o.children[end_select_stmt_index+1:] case_construct = ir.MultiConditional(expr=expr, values=values, bodies=bodies, else_body=else_body, label=label, name=name, source=source) return (*pre, case_construct) def visit_Select_Case_Stmt(self, o, **kwargs): """ A ``SELECT CASE`` statement for a case-construct :class:`fparser.two.Fortran2003.Select_Case_Stmt` has only one child: the selection expression. """ return self.visit(o.children[0], **kwargs) def visit_Case_Stmt(self, o, **kwargs): """ A ``CASE`` statement in a case-construct :class:`fparser.two.Fortran2003.Case_Stmt` has two children: * the selection expression :class:`fparser.two.Fortran2003.Case_Selector`. * the construct name :class:`fparser.two.Fortran2003.Case_Construct_Name` or `None` """ return self.visit(o.children[0], **kwargs) def visit_Case_Selector(self, o, **kwargs): """ The selector in a ``CASE`` statement :class:`fparser.two.Fortran2003.Case_Selector` has one child: the value-range-list :class:`fparser.two.Fortran2003.Case_Value_Range_List` or `None` for the ``DEFAULT`` case. 
""" if o.children[0] is None: return 'DEFAULT' return self.visit(o.children[0], **kwargs) def visit_Case_Value_Range(self, o, **kwargs): """ The range of values in a ``CASE`` statement :class:`fparser.two.Fortran2003.Case_Value_Range` has two children: * start :class:`fparser.two.Fortran2003.Case_Value` or `None` * stop :class:`fparser.two.Fortran2003.Case_Value` or `None` """ start = self.visit(o.children[0], **kwargs) if o.children[0] is not None else None stop = self.visit(o.children[1], **kwargs) if o.children[1] is not None else None source = kwargs.get('source') if source: source = source.clone_with_string(o.string) return sym.RangeIndex((start, stop)) visit_Case_Value_Range_List = visit_List # # SELECT TYPE constructs # def visit_Select_Type_Construct(self, o, **kwargs): """ The entire ``SELECT TYPE`` construct :class:`fparser.two.Fortran2003.Select_Type_Construct` has variable number of children: * Any preceeding comments :class:`fparser.two.Fortran2003.Comment` * :class:`fparser.two.Fortran2003.Select_Type_Stmt` (the actual statement with the selection expression) * the body of the case-construct, containing one or multiple :class:`fparser.two.Fortran2003.Type_Guard_Stmt` followed by their corresponding bodies * :class:`fparser.two.Fortran2003.End_Select_Type_Stmt` """ # Find start and end of construct select_type_stmt = get_child(o, Fortran2003.Select_Type_Stmt) select_type_stmt_index = o.children.index(select_type_stmt) end_select_stmt = get_child(o, Fortran2003.End_Select_Type_Stmt) end_select_stmt_index = o.children.index(end_select_stmt) # Everything before the SELECT TYPE statement pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:select_type_stmt_index]) # Extract source object for construct source = self.get_source(select_type_stmt, end_node=end_select_stmt) # Handle the SELECT TYPE statement expr = self.visit(select_type_stmt, **kwargs) name = select_type_stmt.get_start_name() label = self.get_label(select_type_stmt) # Find all CLASS 
IS/TYPE IS statements and corresponding bodies case_stmts, case_stmt_index = zip(*[(c, i) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Type_Guard_Stmt)]) # Retain any comments between `SELECT TYPE` and the first `CLASS IS`/`TYPE IS` statement if case_stmt_index[0] > select_type_stmt_index + 1: # Our IR doesn't provide a means to store them in the right place, so # we'll just put them before the `SELECT TYPE` pre += as_tuple(self.visit(c, **kwargs) for c in o.children[select_type_stmt_index+1:case_stmt_index[0]]) # Extract all cases values = as_tuple(self.visit(c, **kwargs) for c in case_stmts) bodies = tuple( as_tuple(flatten(as_tuple(self.visit(c, **kwargs)) for c in o.children[start+1:stop])) for start, stop in zip(case_stmt_index, case_stmt_index[1:] + (end_select_stmt_index,)) ) # Type_Name in the Type_Guard_Stmts will be converted to DerivedType objects, # thus we need to convert them to DerivedTypeSymbol values = tuple( ( sym.DerivedTypeSymbol(name=t.name, scope=kwargs['scope'], type=SymbolAttributes(dtype=t)) if isinstance(t, DerivedType) else t, i ) for (t, i) in values ) if (None, None) in values: # CLASS DEFAULT default_index = values.index((None, None)) else_body = bodies[default_index] values = values[:default_index] + values[default_index+1:] bodies = bodies[:default_index] + bodies[default_index+1:] else: else_body = () # Everything past the END ASSOCIATE (should be empty) assert not o.children[end_select_stmt_index+1:] type_construct = ir.TypeConditional(expr=expr, values=values, bodies=bodies, else_body=else_body, label=label, name=name, source=source) return (*pre, type_construct) def visit_Select_Type_Stmt(self, o, **kwargs): """ A ``SELECT TYPE`` statement for a select-type-construct :class:`fparser.two.Fortran2003.Select_Type_Stmt` has two children: * the associate name or None * the selection expression """ if o.children[0] is not None: raise NotImplementedError('Associate name in Select_Type_Stmt not yet implemented') return 
self.visit(o.children[1], **kwargs) def visit_Type_Guard_Stmt(self, o, **kwargs): """ A ``CLASS`` or ``TYPE`` statement in a select-type-construct :class:`fparser.two.Fortran2003.Type_Guard_Stmt` has 3 children: * the selection keyword ``CLASS IS`` or ``TYPE IS`` or ``CLASS DEFAULT`` * the selection expression, a :class:`fparser.two.Fortran2003.Type_Name` * the construct name :class:`fparser.two.Fortran2003.Select_Construct_Name` or None """ if o.children[0] == 'CLASS IS': is_polymorphic = True elif o.children[0] == 'TYPE IS': is_polymorphic = False elif o.children[0] == 'CLASS DEFAULT': is_polymorphic = None else: raise ValueError(f'Unsupported first child of Type_Guard_Stmt: {o.children[0]}') return self.visit(o.children[1], **kwargs), is_polymorphic # # Allocation statements # def visit_Allocate_Stmt(self, o, **kwargs): """ A call to ``ALLOCATE`` :class:`fparser.two.Fortran2003.Allocate_Stmt` has three children: * :class:`fparser.two.Fortran2003.Type_Spec` or `None` * :class:`fparser.two.Fortran2003.Allocation_List` * :class:`fparser.two.Fortran2003.Alloc_Opt_List` or `None` """ if o.children[0] is not None: # We can't handle type spec at the moment self.warn_or_fail('type-spec in allocate-stmt not implemented') # Any allocation options. 
We can only deal with "source" at the moment alloc_opts = {} if o.children[2] is not None: alloc_opts = self.visit(o.children[2], **kwargs) # We need to filter out any options we can't handle currently (and which returned None) alloc_opts = [opt for opt in alloc_opts if opt is not None] alloc_opts = dict(alloc_opts) variables = self.visit(o.children[1], **kwargs) return ir.Allocation( variables=variables, data_source=alloc_opts.get('source'), status_var=alloc_opts.get('stat'), source=kwargs.get('source'), label=kwargs.get('label') ) visit_Allocation_List = visit_List def visit_Allocation(self, o, **kwargs): """ An allocation specification in an allocate-stmt :class:`fparser.two.Fortran2003.Allocation` has two children: * the name of the data object to be allocated: :class:`fparser.two.Fortran2003.Allocate_Object` * the shape of the object: :class:`fparser.two.Fortran2003.Allocate_Shape_Spec_List` """ name = self.visit(o.children[0], **kwargs) shape = self.visit(o.children[1], **kwargs) return name.clone(dimensions=shape) visit_Allocate_Shape_Spec = visit_Explicit_Shape_Spec visit_Allocate_Shape_Spec_List = visit_List visit_Alloc_Opt_List = visit_List visit_Dealloc_Opt_List = visit_List visit_Allocate_Object_List = visit_List def visit_Alloc_Opt(self, o, **kwargs): """ An allocation option in an allocate-stmt :class:`fparser.two.Fortran2003.Alloc_Opt` has two children: * the keyword (`str`) * the option value """ keyword = o.children[0].lower() if keyword in ('source', 'stat'): return keyword, self.visit(o.children[1], **kwargs) # TODO: implement other alloc options self.warn_or_fail(f'Unsupported allocation option: {o.children[0]}') return None def visit_Deallocate_Stmt(self, o, **kwargs): """ A call to ``DEALLOCATE`` :class:`fparser.two.Fortran2003.Deallocate_Stmt` has two children: * the list of objects :class:`fparser.two.Fortran2003.Allocate_Object_List` * list of options :class:`fparser.two.Fortran2003.Dealloc_Opt_list` """ variables = 
self.visit(o.children[0], **kwargs) dealloc_opts = {} if o.children[1] is not None: dealloc_opts = self.visit(o.children[1], **kwargs) # We need to filter out any options we can't handle currently (and which returned None) dealloc_opts = [opt for opt in dealloc_opts if opt is not None] dealloc_opts = dict(dealloc_opts) return ir.Deallocation( variables=variables, status_var=dealloc_opts.get('stat'), source=kwargs.get('source'), label=kwargs.get('label') ) def visit_Dealloc_Opt(self, o, **kwargs): """ A deallocation option in a deallocate-stmt :class:`fparser.two.Fortran2003.Dealloc_Opt` has two children: * the keyword (`str`) * the option value """ keyword = o.children[0].lower() if keyword == 'stat': return keyword, self.visit(o.children[1], **kwargs) # TODO: implement other alloc options self.warn_or_fail(f'Unsupported deallocation option: {o.children[0]}') return None # # Subroutine and function calls # def visit_Call_Stmt(self, o, **kwargs): """ A ``CALL`` statement :class:`fparser.two.Fortran2003.Call_Stmt` has two children: * the subroutine name :class:`fparser.two.Fortran2003.Procedure_Designator` * the argument list :class:`fparser.two.Fortran2003.Actual_Arg_Spec_List` """ name = self.visit(o.children[0], **kwargs) if o.children[1] is not None: arguments = self.visit(o.children[1], **kwargs) kwarguments = tuple(arg for arg in arguments if isinstance(arg, tuple)) arguments = tuple(arg for arg in arguments if not isinstance(arg, tuple)) else: arguments, kwarguments = (), () return ir.CallStatement(name=name, arguments=arguments, kwarguments=kwarguments, label=kwargs.get('label'), source=kwargs.get('source')) def visit_Procedure_Designator(self, o, **kwargs): """ The function or subroutine designator This appears only when a type-bound procedure is called (as otherwise Fparser hands through the relevant names directly). 
:class:`fparser.two.Fortran2003.Procedure_Designator` has three children: * Parent name :class:`fparser.two.Fortran2003.Data_Ref` * '%' (`str`) * procedure name :class:`fparser.two.Fortran2003.Binding_Name` """ assert o.children[1] == '%' scope = kwargs.get('scope', None) parent = self.visit(o.children[0], **kwargs) if parent: scope = parent.scope name = self.visit(o.children[2], **kwargs) # Update the name with the type-bound parent symbol return name.clone(name=f'{parent.name}%{name.name}', parent=parent, scope=scope) visit_Actual_Arg_Spec_List = visit_List def visit_Actual_Arg_Spec(self, o, **kwargs): """ A single argument in a subroutine call :class:`fparser.two.Fortran2003.Actual_Arg_Spec` has two children: * keyword :class:`fparser.two.Fortran2003.Keyword` * argument :class:`fparser.two.Fortran2003.Actual_Arg` """ keyword = o.children[0].tostr() if o.children[0] is not None else None arg = self.visit(o.children[1], **kwargs) return (keyword, arg) def visit_Function_Reference(self, o, **kwargs): """ An inline function call :class:`fparser.two.Fortran2003.Actual_Arg_Spec` has two children: * the function name :class:fparser.two.Fortran2003.ProcedureDesignator` * the argument list :class:`fparser.two.Fortran2003.Actual_Arg_Spec_List` """ name = self.visit(o.children[0], **kwargs) if o.children[1] is not None: arguments = self.visit(o.children[1], **kwargs) kwarguments = tuple(arg for arg in arguments if isinstance(arg, tuple)) arguments = tuple(arg for arg in arguments if not isinstance(arg, tuple)) else: arguments, kwarguments = (), () return sym.InlineCall(name, parameters=arguments, kw_parameters=kwarguments) def visit_Intrinsic_Function_Reference(self, o, **kwargs): # Register the ProcedureType in the scope before the name lookup pname = o.children[0].string scope = kwargs['scope'] if not scope.get_symbol_scope(pname): # No known alternative definition; register a true intrinsic procedure type proc_type = ProcedureType( name=pname, is_function=True, 
is_intrinsic=True, procedure=None ) kwargs['scope'].symbol_attrs[pname] = SymbolAttributes(dtype=proc_type, is_intrinsic=True) # Look up the function symbol name = self.visit(o.children[0], **kwargs) if o.children[1] is not None: arguments = self.visit(o.children[1], **kwargs) kwarguments = tuple(arg for arg in arguments if isinstance(arg, tuple)) arguments = tuple(arg for arg in arguments if not isinstance(arg, tuple)) else: arguments, kwarguments = (), () if str(name).upper() in ('REAL', 'INT'): assert arguments expr = arguments[0] if kwarguments: assert len(arguments) == 1 assert len(kwarguments) == 1 and kwarguments[0][0].lower() == 'kind' kind = kwarguments[0][1] else: kind = arguments[1] if len(arguments) > 1 else None return sym.Cast(name, expr, kind=kind) return sym.InlineCall(name, parameters=arguments, kw_parameters=kwarguments) visit_Intrinsic_Name = visit_Name def visit_Structure_Constructor(self, o, **kwargs): """ Call to the constructor of a derived type :class:`fparser.two.Fortran2003.Structure_Constructor` has two children: * the structure name :class:`fparser.two.Fortran2003.Derived_Type_Spec` * the argument list :class:`fparser.two.Fortran2003.Component_Spec_List` """ # Note: Fparser wrongfully interprets function calls as Structure_Constructor # sometimes. However, we represent constructor calls in the same way, so it # doesn't really matter for us. 
# This should go away once fparser has a basic symbol table, see # https://github.com/stfc/fparser/issues/201 for some details name = self.visit(o.children[0], **kwargs) assert isinstance(name, DerivedType) scope = kwargs.get('scope', None) # `name` is a DerivedType but we represent a constructor call as InlineCall for # which we need ProcedureSymbol name = sym.Variable(name=name.name, scope=scope) if o.children[1] is not None: arguments = self.visit(o.children[1], **kwargs) kwarguments = tuple(arg for arg in arguments if isinstance(arg, tuple)) arguments = tuple(arg for arg in arguments if not isinstance(arg, tuple)) else: arguments, kwarguments = (), () return sym.InlineCall(name, parameters=arguments, kw_parameters=kwarguments) visit_Component_Spec = visit_Actual_Arg_Spec visit_Component_Spec_List = visit_List # # ENUM declaration # def visit_Enum_Def(self, o, **kwargs): """ The definition of an ``ENUM`` :class:`fparser.two.Fortran2003.Enum_Def` has variable number of children: * Any preceeding comments :class:`fparser.two.Fortran2003.Comment` * :class:`fparser.two.Fortran2003.Enum_Def_Stmt` (the statement indicating the beginning of the enum) * the body of the enum, containing one or multiple :class:`fparser.two.Fortran2003.Enumerator_Def_Stmt` * :class:`fparser.two.Fortran2003.End_Enum_Stmt` """ # Find start end end of construct enum_def_stmt = get_child(o, Fortran2003.Enum_Def_Stmt) enum_def_stmt_index = o.children.index(enum_def_stmt) end_enum_stmt = get_child(o, Fortran2003.End_Enum_Stmt) end_enum_stmt_index = o.children.index(end_enum_stmt) # Everything before the construct pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:enum_def_stmt_index]) # Take out any comments (and other stuff which shouldn't be there) # from inside the enum and put them behind it post = as_tuple( self.visit(c, **kwargs) for c in o.children[enum_def_stmt_index+1:end_enum_stmt_index] if not isinstance(c, Fortran2003.Enumerator_Def_Stmt) ) # Find the constant definitions 
inside the enum symbols = flatten( self.visit(c, **kwargs) for c in o.children[enum_def_stmt_index+1:end_enum_stmt_index] if isinstance(c, Fortran2003.Enumerator_Def_Stmt) ) # Update type information for symbols with deferred type # (applies to all constant that are defined without explicit value) symbols = tuple( s.clone(type=SymbolAttributes(BasicType.INTEGER)) if s.type.dtype is BasicType.DEFERRED else s for s in symbols ) # Put symbols in the right scope (that should register their type in that scope's symbol table) symbols = tuple(s.rescope(scope=kwargs['scope']) for s in symbols) # Create the enum and make sure there's nothing else left to do source = self.get_source(enum_def_stmt, end_node=end_enum_stmt) enum = ir.Enumeration(symbols=symbols, source=source, label=kwargs['label']) assert end_enum_stmt_index + 1 == len(o.children) return (*pre, enum, *post) def visit_Enumerator_Def_Stmt(self, o, **kwargs): """ A definition inside an ``ENUM`` :class:`fparser.two.Fortran2003.Enumerator_Def_Stmt` has 2 children: * ``'ENUMERATOR'`` (str) * :class:`fparser.two.Fortran2003.Enumerator_List` (the constants) """ return self.visit(o.children[1], **kwargs) visit_Enumerator_List = visit_List def visit_Enumerator(self, o, **kwargs): """ A constant definition within an ``ENUM``'s definition stmt :class:`fparser.two.Fortran2003.Enumerator` has 3 children: * :class:`fparser.two.Fortran2003.Name` (the constant's name) * ``'='`` (str) * the constant's value given as some constant expression that must evaluate to an integer """ assert o.children[1] == '=' symbol = self.visit(o.children[0], **kwargs) initial = self.visit(o.children[2], **kwargs) _type = SymbolAttributes(BasicType.INTEGER, initial=initial) return symbol.clone(type=_type) # # FORALL construct # def visit_Forall_Stmt(self, o, **kwargs): """ Visit and process a single-line FORALL statement: FORALL ( = [, = ] ... 
[, ]) assign-stmt """ named_bounds, mask = self.visit(o.children[0], **kwargs) # At this point, the body should contain one child. This will be validated during the construction of ir.Forall body = as_tuple(self.visit(child, **kwargs) for child in o.children[1:]) return ir.Forall(named_bounds=named_bounds, mask=mask, body=body, inline=True, source=kwargs.get("source")) def visit_Forall_Construct(self, o, **kwargs): """ Visit and process a multi-line FORALL construct: [name:] FORALL ( = [, = ] ... [, ]) ...body... END FORALL [name] Notes: * Optional `name` of the construct is stored by fparser only in the End_Forall_Stmt at the end, and not in the beginning of the whole statement. * The body can consist of not only assignment statements, but also comments and nested FORALLs """ start = get_child(o, Fortran2003.Forall_Construct_Stmt) start_idx = o.children.index(start) # Anything before the construct (comments and/or pragmas) prelude = as_tuple(self.visit(c, **kwargs) for c in o.children[:start_idx]) # Analyse body of the construct body = node_sublist(o.children, Fortran2003.Forall_Construct_Stmt, Fortran2003.End_Forall_Stmt) # The construct name is the second child of the End_Forall_Stmt (it is not stored in the header by fparser!) 
end = get_child(o, Fortran2003.End_Forall_Stmt) if name := end.children[1]: name = name.string # In the visit() below, skip the Forall_Constrct_Stmt and go directly to the Forall_Header named_bounds, mask = self.visit(start.children[1], **kwargs) body = as_tuple(self.visit(c, **kwargs) for c in body) source = self.get_source(start, end_node=end) return *prelude, ir.Forall(name=name, named_bounds=named_bounds, mask=mask, body=body, inline=False, source=source) def visit_Forall_Header(self, o, **kwargs): """ Visit FORALL header consisting of variables with their bounds and an optional mask """ # Skip the Forall_Triplet_Spec_List, and go directly into each Forall_Triplet_Spec (named bounds) named_bounds = as_tuple(self.visit(c, **kwargs) for c in o.children[0].children) mask = self.visit(o.children[1], **kwargs) return named_bounds, mask def visit_Forall_Triplet_Spec(self, o, **kwargs): """ Visit a triplet specification consisting of named variable, `=`, and a range (hence, the triplet!) """ # The optional [type::] (integer data type) is not handled by fparser2, # so, the first child is always the variable name variable = self.visit(o.children[0], **kwargs) bounds = as_tuple((self.visit(a, **kwargs) for a in (o.children[1:]))) return variable, sym.Range(bounds) # # WHERE construct # def visit_Where_Construct(self, o, **kwargs): """ Fortran's masked array assignment construct :class:`fparser.two.Fortran2003.Where_Construct` has variable number of children: * Any preceeding comments :class:`fparser.two.Fortran2003.Comment` * :class:`fparser.two.Fortran2003.Where_Construct_Stmt` (the statement that marks the beginning of the construct) * body of the where-construct, usually an assignment * (optional) :class:`fparser.two.Fortran2003.Masked_Elsewhere_Stmt` (essentially an "else-if"), followed by its body; this can appear more than once * (optional) :class:`fparser.two.Fortran2003.Elsewhere_Stmt` (essentially an "else"), followed by its body * 
:class:`fparser.two.Fortran2003.End_Where_Stmt` """ # Find start and end of construct where_stmt = get_child(o, Fortran2003.Where_Construct_Stmt) where_stmt_index = o.children.index(where_stmt) end_where_stmt = get_child(o, Fortran2003.End_Where_Stmt) end_where_stmt_index = o.children.index(end_where_stmt) # The banter before the construct... pre = as_tuple(self.visit(c, **kwargs) for c in o.children[:where_stmt_index]) # Extract source object for construct source = self.get_source(where_stmt, end_node=end_where_stmt) # Find all ELSEWHERE statements where_stmts, where_stmts_index = zip(*( [(where_stmt, where_stmt_index)] + [ (c, i) for i, c in enumerate(o.children) if isinstance(c, (Fortran2003.Masked_Elsewhere_Stmt, Fortran2003.Elsewhere_Stmt)) ] )) where_stmts_index = where_stmts_index + (end_where_stmt_index,) # Handle all cases conditions = tuple(self.visit(c, **kwargs) for c in where_stmts) bodies = tuple( flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop])) for start, stop in zip(where_stmts_index[:-1], where_stmts_index[1:]) ) # Extract the default case if any if conditions[-1] == 'DEFAULT': conditions = conditions[:-1] *bodies, default = bodies else: default = () # Make sure there's nothing left to do assert not o.children[end_where_stmt_index+1:] masked_statement = ir.MaskedStatement( conditions=conditions, bodies=as_tuple(bodies), default=default, label=kwargs.get('label'), source=source ) return (*pre, masked_statement) def visit_Where_Construct_Stmt(self, o, **kwargs): """ The ``WHERE`` statement that marks the beginning of a where-construct :class:`fparser.two.Fortran2003.Where_Construct_Stmt` has 1 child: * the expression that marks the condition """ return self.visit(o.children[0], **kwargs) def visit_Masked_Elsewhere_Stmt(self, o, **kwargs): """ An ``ELSEWHERE`` statement with a condition in a where-construct :class:`fparser.two.Fortran2003.Masked_Elsewhere_Stmt` has 2 children: * the expression that marks the condition * the 
construct name or `None` """ if o.children[1] is not None: self.warn_or_fail('where-construct-names not yet implemented') return self.visit(o.children[0], **kwargs) def visit_Elsewhere_Stmt(self, o, **kwargs): """ An unconditional ``ELSEWHERE`` statement :class:`fparser.two.Fortran2003.Elsewhere_Stmt` has 2 children: * ``'ELSEWHERE'`` (str) * the construct name or `None` """ if o.children[1] is not None: self.warn_or_fail('where-construct-names not yet implemented') assert o.children[0] == 'ELSEWHERE' return 'DEFAULT' def visit_Where_Stmt(self, o, **kwargs): """ An inline ``WHERE`` assignment :class:`fparser.two.Fortran2003.Where_Stmt` has 2 children: * the expression that marks the condition * the assignment """ condition = self.visit(o.children[0], **kwargs) body = as_tuple(self.visit(o.children[1], **kwargs)) return ir.MaskedStatement( conditions=(condition, ), bodies=(body, ), default=(), inline=True, label=kwargs.get('label'), source=kwargs.get('source') ) ### Below functions have not yet been revisited ### def visit_Base(self, o, **kwargs): """ Universal default for ``Base`` FParser-AST nodes """ self.warn_or_fail(f'No specific handler for node type {o.__class__}') children = tuple(self.visit(c, **kwargs) for c in o.items if c is not None) if len(children) == 1: return children[0] # Flatten hierarchy if possible return children if len(children) > 0 else None def visit_BlockBase(self, o, **kwargs): """ Universal default for ``BlockBase`` FParser-AST nodes """ self.warn_or_fail(f'No specific handler for node type {o.__class__}') children = tuple(self.visit(c, **kwargs) for c in o.content) children = tuple(c for c in children if c is not None) if len(children) == 1: return children[0] # Flatten hierarchy if possible return children if len(children) > 0 else None def visit_literal(self, o, _type, kind=None, **kwargs): source = kwargs.get('source') if source: source = source.clone_with_string(str(o.items[0])) val = source.string else: val = o.items[0] if kind is 
not None: if kind.isdigit(): kind = sym.Literal(value=int(kind)) else: kind = AttachScopesMapper()(sym.Variable(name=kind), scope=kwargs['scope']) return sym.Literal(value=val, type=_type, kind=kind) return sym.Literal(value=val, type=_type) def visit_Char_Literal_Constant(self, o, **kwargs): return self.visit_literal(o, BasicType.CHARACTER, **kwargs) def visit_Int_Literal_Constant(self, o, **kwargs): kind = o.items[1] if o.items[1] is not None else None return self.visit_literal(o, BasicType.INTEGER, kind=kind, **kwargs) visit_Signed_Int_Literal_Constant = visit_Int_Literal_Constant def visit_Real_Literal_Constant(self, o, **kwargs): kind = o.items[1] if o.items[1] is not None else None return self.visit_literal(o, BasicType.REAL, kind=kind, **kwargs) visit_Signed_Real_Literal_Constant = visit_Real_Literal_Constant def visit_Logical_Literal_Constant(self, o, **kwargs): return self.visit_literal(o, BasicType.LOGICAL, **kwargs) def visit_Complex_Literal_Constant(self, o, **kwargs): source = kwargs.get('source') if source: source = source.clone_with_string(o.string) val = source.string else: val = o.string return sym.IntrinsicLiteral(value=val) visit_Binary_Constant = visit_Complex_Literal_Constant visit_Octal_Constant = visit_Complex_Literal_Constant visit_Hex_Constant = visit_Complex_Literal_Constant def visit_Include_Stmt(self, o, **kwargs): fname = o.items[0].tostr() return ir.Import(module=fname, f_include=True, source=kwargs.get('source'), label=kwargs.get('label')) def visit_Implicit_Stmt(self, o, **kwargs): return ir.Intrinsic(text=f'IMPLICIT {o.items[0]}', source=kwargs.get('source'), label=kwargs.get('label')) def visit_Print_Stmt(self, o, **kwargs): # NOTE: fparser returns None for an empty print (`PRINT *`) instead of # the usual `Output_Item_List` entity. return ir.Intrinsic(text=f'PRINT {", ".join(str(i) for i in o.items if i is not None)}', source=kwargs.get('source'), label=kwargs.get('label')) # TODO: Deal with line-continuation pragmas! 
# Matches pragma-style comment lines of the form ``!$<keyword> <content>``
# (case-insensitive), e.g. ``!$omp parallel do`` -> keyword ``omp``,
# content ``parallel do``. The named groups ``keyword`` and ``content`` are
# read back by ``visit_Comment`` through ``groupdict()``, so they must be
# named in the pattern (an unnamed ``(?P\w+)`` is rejected by ``re.compile``).
_re_pragma = re.compile(r'^\s*\!\$(?P<keyword>\w+)\s*(?P<content>.*)', re.IGNORECASE)

def visit_Comment(self, o, **kwargs):
    """
    A comment line :class:`fparser.two.Fortran2003.Comment`

    Comments whose text matches ``_re_pragma`` (``!$<keyword> <content>``)
    are returned as :any:`Pragma` nodes; every other comment becomes a
    :any:`Comment` node holding the full comment text.
    """
    source = kwargs.get('source', None)
    match_pragma = self._re_pragma.search(o.tostr())
    if match_pragma:
        # Found pragma, generate this instead
        gd = match_pragma.groupdict()
        return ir.Pragma(keyword=gd['keyword'], content=gd['content'], source=source)
    return ir.Comment(text=o.tostr(), source=source)

def visit_Data_Pointer_Object(self, o, **kwargs):
    """
    A data pointer object :class:`fparser.two.Fortran2003.Data_Pointer_Object`

    Its ``items`` hold the component names interleaved with ``'%'``
    separators (which are skipped). Each component is visited with the
    previously built symbol as its ``parent``, so the result is the chain of
    derived-type member references ending in the last component.
    """
    v = self.visit(o.items[0], source=kwargs.get('source'), scope=kwargs['scope'])
    for i in o.items[1:-1]:
        if i == '%':
            continue
        # Careful not to propagate type or dims here
        v = self.visit(i, parent=v, source=kwargs.get('source'), scope=kwargs['scope'])
    # Attach types and dims to final leaf variable
    return self.visit(o.items[-1], parent=v, **kwargs)

def visit_Proc_Component_Ref(self, o, **kwargs):
    '''
    This is the compound object for accessing procedure components of a variable.

    The first item is turned into a lower-case, scope-attached variable; the
    following non-``'%'`` items are visited with the running symbol as
    ``parent``, and all remaining kwargs are passed only to the final item.
    '''
    pname = o.items[0].tostr().lower()
    v = AttachScopesMapper()(sym.Variable(name=pname), scope=kwargs['scope'])
    for i in o.items[1:-1]:
        if i != '%':
            v = self.visit(i, parent=v, source=kwargs.get('source'), scope=kwargs['scope'])
    return self.visit(o.items[-1], parent=v, **kwargs)

def visit_Block_Nonlabel_Do_Construct(self, o, **kwargs):
    """
    A ``DO`` construct (also used for labeled ``DO`` constructs)

    Everything in ``o.content`` before the ``DO`` statement (e.g. comments
    or pragmas) is visited and returned ahead of the loop node. The loop end
    is the last ``END DO``; if there is none, the last ``CONTINUE`` is used
    and must carry the ``DO`` statement's label. Returns the preceding nodes
    followed by an :any:`Loop` (counted loop, ``bounds`` present) or an
    :any:`WhileLoop` (loop control is a condition only).
    """
    do_stmt_types = (Fortran2003.Nonlabel_Do_Stmt, Fortran2003.Label_Do_Stmt)
    # In the banter before the loop, Pragmas are hidden...
    banter = []
    for ch in o.content:
        if isinstance(ch, do_stmt_types):
            do_stmt = ch
            break
        banter += [self.visit(ch, **kwargs)]
    else:
        # No DO statement among the direct content: fall back to a child lookup
        do_stmt = get_child(o, do_stmt_types)
    # Extract source by looking at everything between DO and END DO statements
    end_do_stmt = rget_child(o, Fortran2003.End_Do_Stmt)
    has_end_do = True
    if end_do_stmt is None:
        # We may have a labeled loop with an explicit CONTINUE statement
        has_end_do = False
        end_do_stmt = rget_child(o, Fortran2003.Continue_Stmt)
        assert str(end_do_stmt.item.label) == do_stmt.label.string
    source = self.get_source(do_stmt, end_node=end_do_stmt)
    label = self.get_label(do_stmt)
    construct_name = do_stmt.item.name
    # Extract loop header and get stepping info
    variable, bounds = self.visit(do_stmt, **kwargs)
    # Extract and process the loop body
    body_nodes = node_sublist(o.content, do_stmt.__class__, Fortran2003.End_Do_Stmt)
    body = as_tuple(flatten(self.visit(node, **kwargs) for node in body_nodes))
    # Loop label for labeled do constructs
    loop_label = str(do_stmt.items[1]) if isinstance(do_stmt, Fortran2003.Label_Do_Stmt) else None
    # Select loop type
    if bounds:
        obj = ir.Loop(
            variable=variable, body=body, bounds=bounds, loop_label=loop_label,
            label=label, name=construct_name, has_end_do=has_end_do, source=source
        )
    else:
        obj = ir.WhileLoop(
            condition=variable, body=body, loop_label=loop_label, label=label,
            name=construct_name, has_end_do=has_end_do, source=source
        )
    return (*banter, obj, )

visit_Block_Label_Do_Construct = visit_Block_Nonlabel_Do_Construct

def visit_Nonlabel_Do_Stmt(self, o, **kwargs):
    """
    The ``DO`` statement heading a loop (also used for labeled ``DO`` statements)

    Returns the ``(variable, bounds)`` pair produced by the statement's
    :class:`fparser.two.Fortran2003.Loop_Control`, or ``(None, None)`` when
    there is no loop control (an unbounded ``DO``).
    """
    variable, bounds = None, None
    loop_control = get_child(o, Fortran2003.Loop_Control)
    if loop_control:
        variable, bounds = self.visit(loop_control, **kwargs)
    return variable, bounds

visit_Label_Do_Stmt = visit_Nonlabel_Do_Stmt

def visit_Loop_Control(self, o, **kwargs):
    """
    The loop control of a ``DO`` statement :class:`fparser.two.Fortran2003.Loop_Control`

    Returns ``(condition, None)`` if ``items[0]`` holds a scalar logical
    expression (``DO WHILE``); otherwise returns the loop variable and a
    :any:`LoopRange` built from the bounds in ``items[1][1]``.
    """
    if o.items[0]:
        # Scalar logical expression
        return self.visit(o.items[0], **kwargs), None
    variable = self.visit(o.items[1][0], **kwargs)
    bounds = as_tuple(flatten(self.visit(a, **kwargs) for a in as_tuple(o.items[1][1])))
    return variable, sym.LoopRange(bounds)

def visit_Assignment_Stmt(self, o, **kwargs):
    """
    An assignment :class:`fparser.two.Fortran2003.Assignment_Stmt` (also used
    for :class:`fparser.two.Fortran2003.Pointer_Assignment_Stmt`)

    ``items[0]`` is the left-hand side and ``items[2]`` the right-hand side.
    If the left-hand side looks like an array access on a symbol that is
    declared without shape/length, is not a derived type member, dummy
    argument, import or associate, the statement is treated as a statement
    function: an :any:`StatementFunction` node is returned and the symbol's
    type in the current scope is replaced by a procedure type. Otherwise an
    :any:`Assignment` is returned (``ptr=True`` for pointer assignments).
    """
    ptr = isinstance(o, Fortran2003.Pointer_Assignment_Stmt)
    lhs = self.visit(o.items[0], **kwargs)
    rhs = self.visit(o.items[2], **kwargs)

    # Special-case: Identify statement functions using our internal symbol table
    symbol_attrs = kwargs['scope'].symbol_attrs
    if isinstance(lhs, sym.Array) and symbol_attrs.lookup(lhs.name) is not None:
        # If this looks like an array but we have an explicit scalar declaration then
        # this might in fact be a statement function.
        # To avoid the costly lookup for declarations on each array assignment, we run through
        # some sanity checks instead that allow us to bail out early in most cases
        lhs_type = lhs.type
        could_be_a_statement_func = not (
            lhs_type.shape or lhs_type.length  # Declaration with length or dimensions
            or lhs.parent  # Derived type member (we might lack information from enrichment)
            or lhs_type.intent or lhs_type.imported  # Dummy argument or imported from module
            or isinstance(lhs.scope, ir.Associate)  # Symbol stems from an associate
        )

        if could_be_a_statement_func:
            def _create_stmt_func_type(stmt_func):
                # Procedure type whose `procedure` is resolved lazily by searching
                # the scope's spec for the matching StatementFunction node
                name = str(stmt_func.variable)
                procedure = LazyNodeLookup(
                    anchor=kwargs['scope'],
                    query=lambda x: [
                        f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name
                    ][0]
                )
                proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
                return SymbolAttributes(dtype=proc_type, is_stmt_func=True)

            f_symbol = sym.ProcedureSymbol(name=lhs.name, scope=kwargs['scope'])
            stmt_func = ir.StatementFunction(
                variable=f_symbol, arguments=lhs.dimensions,
                rhs=rhs, return_type=symbol_attrs[lhs.name],
                label=kwargs.get('label'), source=kwargs.get('source')
            )

            # Update the type in the local scope and return stmt func node
            symbol_attrs[str(stmt_func.variable)] = _create_stmt_func_type(stmt_func)
            return stmt_func

    # Not a statement function: plain (or pointer) assignment
    return ir.Assignment(
        lhs=lhs, rhs=rhs, ptr=ptr,
        label=kwargs.get('label'), source=kwargs.get('source')
    )

visit_Pointer_Assignment_Stmt = visit_Assignment_Stmt

def create_operation(self, op, exprs):
    """
    Construct expressions from individual operations.

    Parameters
    ----------
    op : str
        The operator as given by fparser (arithmetic, ``//``, or a comparison/
        logical operator in either symbolic or dotted form, matched
        case-insensitively for the latter).
    exprs : expression or tuple of expressions
        The operand(s); a single operand denotes a unary operation.

    Returns
    -------
    The corresponding expression node. Binary minus is expressed as
    ``a + (-1)*b``, unary minus as ``(-1)*a``; ``.EQV.`` and ``.NEQV.`` are
    expanded into equivalent AND/OR/NOT combinations.

    Raises
    ------
    RuntimeError
        If the operator is not recognised.
    """
    exprs = as_tuple(exprs)
    if op == '*':
        return sym.Product(exprs)
    if op == '/':
        return sym.Quotient(numerator=exprs[0], denominator=exprs[1])
    if op == '+':
        return sym.Sum(exprs)
    if op == '-':
        if len(exprs) > 1:
            # Binary minus
            return sym.Sum((exprs[0], sym.Product((-1, exprs[1]))))
        # Unary minus
        return sym.Product((-1, exprs[0]))
    if op == '**':
        return sym.Power(base=exprs[0], exponent=exprs[1])

    # Logical and comparison operators may appear in upper or lower case
    op_lower = op.lower()
    if op_lower == '.and.':
        return sym.LogicalAnd(exprs)
    if op_lower == '.or.':
        return sym.LogicalOr(exprs)
    if op_lower in ('==', '.eq.'):
        return sym.Comparison(exprs[0], '==', exprs[1])
    if op_lower in ('/=', '.ne.'):
        return sym.Comparison(exprs[0], '!=', exprs[1])
    if op_lower in ('>', '.gt.'):
        return sym.Comparison(exprs[0], '>', exprs[1])
    if op_lower in ('<', '.lt.'):
        return sym.Comparison(exprs[0], '<', exprs[1])
    if op_lower in ('>=', '.ge.'):
        return sym.Comparison(exprs[0], '>=', exprs[1])
    if op_lower in ('<=', '.le.'):
        return sym.Comparison(exprs[0], '<=', exprs[1])
    if op_lower == '.not.':
        return sym.LogicalNot(exprs[0])
    if op_lower == '.eqv.':
        # a .EQV. b  ==  (a .AND. b) .OR. .NOT. (a .OR. b)
        return sym.LogicalOr((sym.LogicalAnd(exprs), sym.LogicalNot(sym.LogicalOr(exprs))))
    if op_lower == '.neqv.':
        # a .NEQV. b  ==  .NOT. (a .AND. b) .AND. (a .OR. b)
        return sym.LogicalAnd((sym.LogicalNot(sym.LogicalAnd(exprs)), sym.LogicalOr(exprs)))
    if op == '//':
        return StringConcat(exprs)
    raise RuntimeError('FParser: Error parsing generic expression')

def visit_Add_Operand(self, o, **kwargs):
    """
    An operand node (shared by add/mult/and/or/equiv operands)

    With more than two ``items`` this is a binary operation
    ``items[0] <items[1]> items[2]``; otherwise it is a unary operation
    ``<items[0]> items[1]``. The expression is built by :meth:`create_operation`.
    """
    source = kwargs.get('source')
    if source:
        # NOTE(review): `source` is derived here but never attached to the result
        source = source.clone_with_string(o.string)
    if len(o.items) > 2:
        # Binary operand
        exprs = [self.visit(o.items[0], **kwargs)]
        exprs += [self.visit(o.items[2], **kwargs)]
        return self.create_operation(op=o.items[1], exprs=exprs)
    # Unary operand
    exprs = [self.visit(o.items[1], **kwargs)]
    return self.create_operation(op=o.items[0], exprs=exprs)

visit_Mult_Operand = visit_Add_Operand
visit_And_Operand = visit_Add_Operand
visit_Or_Operand = visit_Add_Operand
visit_Equiv_Operand = visit_Add_Operand

def visit_Level_2_Expr(self, o, **kwargs):
    """
    A binary expression ``items[0] <items[1]> items[2]`` (shared by the
    level-3/4/5 expression nodes), built via :meth:`create_operation`.
    """
    source = kwargs.get('source')
    if source:
        # NOTE(review): `source` is derived here but never attached to the result
        source = source.clone_with_string(o.string)
    e1 = self.visit(o.items[0], **kwargs)
    e2 = self.visit(o.items[2], **kwargs)
    return self.create_operation(op=o.items[1], exprs=(e1, e2))

def visit_Level_2_Unary_Expr(self, o, **kwargs):
    """
    A unary expression ``<items[0]> items[1]``, built via :meth:`create_operation`.
    """
    source = kwargs.get('source')
    if source:
        # NOTE(review): `source` is derived here but never attached to the result
        source = source.clone_with_string(o.string)
    exprs = as_tuple(self.visit(o.items[1], **kwargs))
    return self.create_operation(op=o.items[0], exprs=exprs)

visit_Level_3_Expr = visit_Level_2_Expr
visit_Level_4_Expr = visit_Level_2_Expr
visit_Level_5_Expr = visit_Level_2_Expr

def visit_Parenthesis(self, o, **kwargs):
    """
    A parenthesised expression :class:`fparser.two.Fortran2003.Parenthesis`

    The inner expression (``items[1]``) is visited; sums, products, quotients
    and powers are re-wrapped in their ``Parenthesised*`` counterparts so the
    parentheses are retained. Any other expression is returned unchanged.
    """
    source = kwargs.get('source')
    expression = self.visit(o.items[1], **kwargs)
    if source:
        # NOTE(review): `source` is derived here but never attached to the result
        source = source.clone_with_string(o.string)
    if isinstance(expression, sym.Sum):
        expression = ParenthesisedAdd(expression.children)
    if isinstance(expression, sym.Product):
        expression = ParenthesisedMul(expression.children)
    if isinstance(expression, sym.Quotient):
        expression = ParenthesisedDiv(expression.numerator, expression.denominator)
    if isinstance(expression, sym.Power):
        expression = ParenthesisedPow(expression.base, expression.exponent)
    return expression

# Statements that are kept verbatim as intrinsic (text-only) nodes
visit_Format_Stmt = visit_Intrinsic_Stmt
visit_Write_Stmt = visit_Intrinsic_Stmt
visit_Goto_Stmt = visit_Intrinsic_Stmt
visit_Return_Stmt = visit_Intrinsic_Stmt
visit_Continue_Stmt = visit_Intrinsic_Stmt
visit_Cycle_Stmt = visit_Intrinsic_Stmt
visit_Exit_Stmt = visit_Intrinsic_Stmt
visit_Save_Stmt = visit_Intrinsic_Stmt
visit_Read_Stmt = visit_Intrinsic_Stmt
visit_Open_Stmt = visit_Intrinsic_Stmt
visit_Close_Stmt = visit_Intrinsic_Stmt
visit_Inquire_Stmt = visit_Intrinsic_Stmt
visit_Namelist_Stmt = visit_Intrinsic_Stmt
visit_Parameter_Stmt = visit_Intrinsic_Stmt
visit_Dimension_Stmt = visit_Intrinsic_Stmt visit_Equivalence_Stmt = visit_Intrinsic_Stmt visit_Common_Stmt = visit_Intrinsic_Stmt visit_Stop_Stmt = visit_Intrinsic_Stmt visit_Error_Stop_Stmt = visit_Intrinsic_Stmt visit_Backspace_Stmt = visit_Intrinsic_Stmt visit_Rewind_Stmt = visit_Intrinsic_Stmt visit_Entry_Stmt = visit_Intrinsic_Stmt visit_Cray_Pointer_Stmt = visit_Intrinsic_Stmt def visit_Cpp_If_Stmt(self, o, **kwargs): return ir.PreprocessorDirective(text=o.tostr(), source=kwargs.get('source')) visit_Cpp_Elif_Stmt = visit_Cpp_If_Stmt visit_Cpp_Else_Stmt = visit_Cpp_If_Stmt visit_Cpp_Endif_Stmt = visit_Cpp_If_Stmt visit_Cpp_Macro_Stmt = visit_Cpp_If_Stmt visit_Cpp_Undef_Stmt = visit_Cpp_If_Stmt visit_Cpp_Line_Stmt = visit_Cpp_If_Stmt visit_Cpp_Warning_Stmt = visit_Cpp_If_Stmt visit_Cpp_Error_Stmt = visit_Cpp_If_Stmt visit_Cpp_Null_Stmt = visit_Cpp_If_Stmt def visit_Cpp_Include_Stmt(self, o, **kwargs): fname = o.items[0].tostr() return ir.Import(module=fname, c_import=True, source=kwargs.get('source')) def visit_Nullify_Stmt(self, o, **kwargs): if not o.items[1]: return () variables = as_tuple(flatten(self.visit(v, **kwargs) for v in o.items[1].items)) return ir.Nullify(variables=variables, label=kwargs.get('label'), source=kwargs.get('source')) loki-ecmwf-0.3.6/loki/frontend/source.py0000664000175000017500000005044715167130205020402 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
""" Implementation of :any:`Source` and adjacent utilities """ from bisect import bisect_left from enum import Enum, auto from itertools import accumulate, takewhile import re from codetiming import Timer try: from fparser.common.readfortran import Comment as ReadComment, FortranStringReader except ImportError: ReadComment = None FortranStringReader = None from loki.logging import debug, warning __all__ = [ 'Source', 'SourceStatus', 'FortranReader', 'source_to_lines', 'join_source_list' ] class SourceStatus(Enum): """ The node and its children are unchanged - :any:`Source` is valid! """ VALID = auto() """ Interior node properties (expressions) have changed - :any:`Source` is invalid! """ INVALID_NODE = auto() """ Interior node properties have not changed, but child nodes have been altered - :any:`Source` is invalid! """ INVALID_CHILDREN = auto() class Source: """ Store information about the original source for an IR node. Parameters ---------- line : tuple Start and (optional) end line number in original source file string : str (optional) The original raw source string file : str (optional) The file name status : :any:`SourceStatus` Flag indicating if the associated node has been altered """ def __init__(self, lines, string=None, file=None, status=None): assert lines and len(lines) == 2 and (lines[1] is None or lines[1] >= lines[0]) self.lines = lines self.string = string self.file = file self.status = status if status else SourceStatus.VALID def clone(self, **kwargs): """ Replicate the object with the provided overrides. 
""" if 'lines' not in kwargs: kwargs['lines'] = self.lines if self.string is not None and 'string' not in kwargs: kwargs['string'] = self.string if self.file is not None and 'file' not in kwargs: kwargs['file'] = self.file if self.status is not None and 'status' not in kwargs: kwargs['status'] = self.status return type(self)(**kwargs) def __repr__(self): line_end = f'-{self.lines[1]}' if self.lines[1] else '' return f'Source' def __eq__(self, o): if isinstance(o, Source): return self.__dict__ == o.__dict__ return super().__eq__(o) def __hash__(self): return hash((self.lines, self.string, self.file)) def find(self, string, ignore_case=True, ignore_space=True): """ Find the given string in the source and return start and end index or None if not found. """ if not self.string: return None, None if ignore_case: string = string.lower() self_string = self.string.lower() else: self_string = self.string if string in self_string: # string is contained as is idx = self_string.find(string) return idx, idx + len(string) if ignore_space: # Split the key and try to find individual parts strings = string.strip().split() if strings[0] in self_string: if all(substr in self_string for substr in strings): return (self_string.find(strings[0]), self_string.find(strings[-1]) + len(strings[-1])) return None, None def clone_with_string(self, string, ignore_case=True, ignore_space=True): """ Clone the source object and extract the given string from the original source string or use the provided string. 
""" cstart, cend = self.find(string, ignore_case=ignore_case, ignore_space=ignore_space) if None not in (cstart, cend): string = self.string[cstart:cend] lstart = self.lines[0] + self.string[:cstart].count('\n') lend = lstart + string.count('\n') lines = (lstart, lend) else: lines = self.lines return Source(lines=lines, string=string, file=self.file) def clone_with_span(self, span): """ Clone the source object and extract the given line span from the original source string (relative to the string length). """ string = self.string[span[0]:span[1]] lstart = self.lines[0] + self.string[:span[0]].count('\n') lend = lstart + string.count('\n') return Source(lines=(lstart, lend), string=string, file=self.file) def clone_lines(self, span=None): """ Create source object clones for each line. """ if span is not None: return self.clone_with_span(span).clone_lines() return [ Source(lines=(self.lines[0]+idx,)*2, string=line, file=self.file) for idx, line in enumerate(self.string.splitlines()) ] def invalidate(self, children=False): """ Set the status of this source to ``SourceStatus.INVALID_NODE`` or ``SourceStatus.INVALID_CHILDREN``. Calling ``source.invalidate()`` marks the entire source object as invalid, while ``source.invalidate(children=True)`` denotes that interior properties the associated node have not been changed, but its children have been invalidated. """ self.status = SourceStatus.INVALID_CHILDREN if children else SourceStatus.INVALID_NODE return self def is_valid(self): """ Returns ``True`` if the :any:`Source` object is still valid. """ return self.status == SourceStatus.VALID class FortranReader: """ Reader for Fortran source strings that provides a sanitized version of the source code It performs the following sanitizer steps: - Remove all comments and preprocessor directives - Remove empty lines - Remove all whitespace at the beginning and end of lines - Resolve all line continuations This enables easier pattern matching in the source code. 
The original source code can be recovered (with some restrictions) for each position in the sanitized source string. Parameters ---------- raw_source : str The Fortran source code Attributes ---------- source_lines : list The lines of the original source code sanitized_lines : list of :class:`fparser.common.Line` Lines in the sanitized source code sanitized_string : str The sanitized source code sanitized_spans : list of int Start index of each line in the sanitized string """ def __init__(self, raw_source): self.line_offset = 0 raw_source = raw_source.strip() self.source_lines = raw_source.splitlines() self._sanitize_raw_source(raw_source) @Timer(logger=debug, text=lambda s: f'[Loki::Frontend] Executed _sanitize_raw_source in {s:.2f}s') def _sanitize_raw_source(self, raw_source): """ Helper routine to create a sanitized Fortran source string with comments removed and whitespace stripped from line beginning and end """ if FortranStringReader is None: raise RuntimeError('FortranReader needs fparser2') # do not ignore comments during reading as this would also ignore pragmas ... 
reader = FortranStringReader(raw_source, ignore_comments=False) lines = tuple(l for l in reader) # ...but remove all comments that do not look like pragmas and are not inline comments def is_not_comment(line, prev=None): prev_end = prev.span[1] if prev else 0 return not isinstance(line, ReadComment) or ( line.span[0] > prev_end and line.line[:2] == '!$' ) self.sanitized_lines = tuple(takewhile(is_not_comment, lines[:1])) + tuple( l for l, prev in zip(lines[1:], lines) if is_not_comment(l, prev) ) self.sanitized_spans = (0,) + tuple(accumulate(len(item.line)+1 for item in self.sanitized_lines)) self.sanitized_string = '\n'.join(item.line for item in self.sanitized_lines) def get_line_index(self, line_number): """ Yield the index in :attr:`source_lines` for the given :data:`line_number` """ return line_number - self.line_offset - 1 def get_line_indices_from_span(self, span, include_padding=False): """ Yield the line indices in :attr:`source_lines` and :attr:`sanitized_lines` for the given :data:`span` in the :attr:`sanitized_string` Parameters ---------- span : tuple Start and end in the :attr:`sanitized_string`. The end can optionally be `None`, which includes everything up to the end include_padding : bool (optional) Includes lines from the original source that are missing in the sanitized string (i.e. comments etc.) and that are located immediately before/after the specified span. Returns ------- sanitized_start, sanitized_end, source_start, source_end Start and end indices corresponding to :attr:`sanitized_lines` and :attr:`source_lines`, respectively. Indices for `start` are inclusive and for `end` exclusive (i.e. ``[start, end)``). 
""" # First, find the corresponding line indices in the sanitized string sanitized_start = bisect_left(self.sanitized_spans, span[0]) if span[1] is None: sanitized_end = len(self.sanitized_lines) else: sanitized_end = bisect_left(self.sanitized_spans, span[1], lo=sanitized_start) sanitized_end = min(len(self.sanitized_lines), sanitized_end) # Next, find the corresponding line indices in the original string if include_padding: if sanitized_start == 0: # Span starts at the beginning of the sanitized string: include everything # before as well source_start = 0 elif sanitized_start >= len(self.sanitized_lines): # Span starts after the sanitized string: include only lines after it source_start = self.get_line_index(self.sanitized_lines[-1].span[1] + 1) elif self.sanitized_lines[sanitized_start].span[0] - self.sanitized_lines[sanitized_start-1].span[1] > 1: # There are lines in the original string that are missing in the sanitized string # between the previous and the start line source_start = self.get_line_index(self.sanitized_lines[sanitized_start-1].span[1] + 1) else: source_start = self.get_line_index(self.sanitized_lines[sanitized_start].span[0]) if sanitized_end == len(self.sanitized_lines): # Span reaches until the end of the sanitized_string: include everything # after it as well source_end = len(self.source_lines) else: # Include everything until (but not including) the line corresponding to the # first line after the span in the sanitized string source_end = self.get_line_index(self.sanitized_lines[sanitized_end].span[0]) elif sanitized_start >= len(self.sanitized_lines): # Span starts after the sanitized string: Point to the first line after it source_start = self.get_line_index(self.sanitized_lines[-1].span[1] + 1) source_end = source_start else: source_start = self.get_line_index(self.sanitized_lines[sanitized_start].span[0]) source_end = self.get_line_index(self.sanitized_lines[sanitized_end-1].span[1] + 1) return sanitized_start, sanitized_end, 
source_start, source_end def to_source(self, include_padding=False): """ Create a :any:`Source` object with the content of the reader """ if not self.source_lines: string = '' lines = (self.line_offset + 1, self.line_offset + 1) elif include_padding: string = '\n'.join(self.source_lines) lines = (self.line_offset + 1, self.line_offset + len(self.source_lines)) else: lines = (self.sanitized_lines[0].span[0], self.sanitized_lines[-1].span[1]) index = (lines[0] - self.line_offset - 1, lines[1] - self.line_offset) string = '\n'.join(self.source_lines[index[0]:index[1]]) return Source(lines=lines, string=string) def source_from_head(self): """ Create a :any:`Source` object that contains raw source lines present in the original source string before the sanitized source string This means typically comments or preprocessor directives. Returns `None` if there is nothing. """ if not self.source_lines: return None if not self.sanitized_lines: string = '\n'.join(self.source_lines) lines = (self.line_offset + 1, self.line_offset + len(self.source_lines)) return Source(lines=lines, string=string) line_diff = self.sanitized_lines[0].span[0] - self.line_offset if line_diff == 1: return None assert line_diff > 0 string = '\n'.join(self.source_lines[:line_diff - 1]) lines = (self.line_offset + 1, self.sanitized_lines[0].span[0] - 1) return Source(lines=lines, string=string) def source_from_tail(self): """ Create a :any:`Source` object that contains raw source lines present in the original source string after the sanitized source string This means typically comments or preprocessor directives. Returns `None` if there is nothing. 
""" if not self.sanitized_lines: return None line_diff = len(self.source_lines) + self.line_offset - self.sanitized_lines[-1].span[1] if line_diff == 0: return None assert line_diff > 0 start = self.sanitized_lines[-1].span[1] + 1 string = '\n'.join(self.source_lines[self.get_line_index(start):]) lines = (start, start + line_diff - 1) return Source(lines=lines, string=string) def source_from_sanitized_span(self, span, include_padding=False): """ Create a :any:`Source` object containing the original source string corresponding to the given span in the sanitized string """ *_, source_start, source_end = self.get_line_indices_from_span(span, include_padding) string = '\n'.join(self.source_lines[source_start:source_end]) if not string: return None lines = (self.line_offset + source_start + 1, self.line_offset + source_end) return Source(lines=lines, string=string) def reader_from_sanitized_span(self, span, include_padding=False): """ Create a new :any:`FortranReader` object covering only the source code section corresponding to the given span in the sanitized string """ sanit_start, sanit_end, source_start, source_end = self.get_line_indices_from_span(span, include_padding) if sanit_start >= len(self.sanitized_lines): return None new_reader = FortranReader.__new__(FortranReader) new_reader.line_offset = self.line_offset + source_start new_reader.source_lines = self.source_lines[source_start:source_end] new_reader.sanitized_lines = self.sanitized_lines[sanit_start:sanit_end] span_offset = self.sanitized_spans[sanit_start] new_reader.sanitized_spans = tuple(span - span_offset for span in self.sanitized_spans[sanit_start:sanit_end+1]) if sanit_end + 1 < len(self.sanitized_spans): sanitized_span = [self.sanitized_spans[sanit_start], self.sanitized_spans[sanit_end + 1]] else: sanitized_span = [self.sanitized_spans[sanit_start], None] new_reader.sanitized_string = self.sanitized_string[sanitized_span[0]:sanitized_span[1]] return new_reader def __iter__(self): """Initialize 
iteration over lines in the sanitized string""" self._current_index = 0 return self def __next__(self): self._current_index += 1 if self._current_index > len(self.sanitized_lines): raise StopIteration return self.current_line @property def current_line(self): """ Return the current line of the iterator or `None` if outside of iteration range """ _current_index = getattr(self, '_current_index', 0) if _current_index <= 0 or _current_index > len(self.sanitized_lines): return None return self.sanitized_lines[_current_index - 1] def source_from_current_line(self): """ Return a :class:`Source` object for the current line """ line = self.current_line start = self.get_line_index(line.span[0]) end = self.get_line_index(line.span[1]) return Source(lines=line.span, string='\n'.join(self.source_lines[start:end+1])) def _merge_source_match_source(pre, match, post): """ Merge a triple of :class:`Source`, :class:`re.Match`, :class:`Source` objects into a single :class:`Source` object spanning multiple lines Helper routine for :any:`source_to_lines`. """ assert isinstance(pre, Source) assert isinstance(match, re.Match) assert isinstance(post, Source) lines = (pre.lines[0], post.lines[1]) return Source(lines, pre.string + post.string, pre.file) def _create_lines_and_merge(source_lines, source, span, lineno=None): """ Create line-wise :class:`Source` objects for the substring in :data:`source` given by :data:`span` If the existing list of source lines ends with (:class:`Source`, :class:`re.Match`), they are joined with the first line in the new substring. Helper routine for :any:`source_to_lines`. 
""" if lineno is None: new_lines = source.clone_lines(span) else: new_lines = Source((lineno, None), source.string[span[0]:span[1]], source.file).clone_lines() if len(source_lines) >= 2 and isinstance(source_lines[-1], re.Match): source_lines = ( source_lines[:-2] + [_merge_source_match_source(source_lines[-2], source_lines[-1], new_lines[0])] + new_lines[1:] ) else: source_lines += new_lines return source_lines _re_line_cont = re.compile(r'&([ \t]*)\n([ \t]*)(?:&|(?!\!)(?=\S))') """Pattern to match Fortran line continuation.""" def source_to_lines(source): """ Create line-wise :class:`Source` objects, resolving Fortran line-continuation. """ source_lines = [] ptr = 0 lineno = source.lines[0] for match in _re_line_cont.finditer(source.string): source_lines = _create_lines_and_merge(source_lines, source, (ptr, match.span()[0]), lineno=lineno) lineno = source_lines[-1].lines[1] + 1 source_lines += [match] ptr = match.span()[1] if ptr < len(source.string): source_lines = _create_lines_and_merge(source_lines, source, (ptr, len(source.string)), lineno=lineno) return source_lines def join_source_list(source_list): """ Combine a list of :class:`Source` objects into a single object containing the joined source string. This will annotate the joined source object with the maximum range of line numbers provided in :data:`source_list` objects and insert empty lines for any missing line numbers inbetween the provided source objects. 
""" if not source_list: return None string = source_list[0].string lines = [source_list[0].lines[0], source_list[0].lines[1] or source_list[0].lines[0]] for source in source_list[1:]: newlines = source.lines[0] - lines[1] if newlines < 0: warning('join_source_list: overlapping line range') newlines = 0 string += '\n' * newlines + source.string lines[1] = source.lines[1] if source.lines[1] else lines[1] + newlines + source.string.count('\n') return Source(tuple(lines), string, source_list[0].file) loki-ecmwf-0.3.6/loki/frontend/omni.py0000664000175000017500000017245515167130205020050 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. from pathlib import Path from shutil import which import xml.etree.ElementTree as ET from codetiming import Timer from loki.frontend.source import Source from loki.frontend.util import OMNI, sanitize_ir from loki import ir from loki.ir import ( GenericVisitor, FindNodes, Transformer, process_dimension_pragmas, pragmas_attached ) from loki.expression import ( symbols as sym, operations as op, ExpressionDimensionsMapper, StringConcat, AttachScopesMapper ) from loki.logging import debug, info, warning, error from loki.config import config from loki.tools import ( as_tuple, execute, gettempdir, filehash, CaseInsensitiveDict ) from loki.types import BasicType, DerivedType, ProcedureType, SymbolAttributes __all__ = ['HAVE_OMNI', 'parse_omni_source', 'parse_omni_file', 'parse_omni_ast'] HAVE_OMNI = which('F_Front') is not None """Indicate whether OMNI frontend is available.""" @Timer(logger=debug, text=lambda s: f'[Loki::OMNI] Executed parse_omni_file in {s:.2f}s') def parse_omni_file(filename, 
xmods=None): """ Deploy the OMNI compiler's frontend (F_Front) to generate the OMNI AST. Note that the intermediate XML files can be dumped to file via by setting the environment variable ``LOKI_OMNI_DUMP_XML``. """ if not HAVE_OMNI: error('OMNI is not available. Is "F_Front" in the search path?') dump_xml_files = config['omni-dump-xml'] filepath = Path(filename) info(f'[Loki::OMNI] Parsing {filepath}') xml_path = filepath.with_suffix('.xml') xmods = xmods or [] cmd = ['F_Front', '-fleave-comment'] for m in xmods: cmd += ['-M', f'{Path(m)}'] cmd += [f'{filepath}'] if dump_xml_files: # Parse AST from xml file dumped to disk cmd += ['-o', f'{xml_path}'] execute(cmd) return ET.parse(str(xml_path)).getroot() result = execute(cmd, silent=False, capture_output=True, text=True) return ET.fromstring(result.stdout) @Timer(logger=debug, text=lambda s: f'[Loki::OMNI] Executed parse_omni_source in {s:.2f}s') def parse_omni_source(source, filepath=None, xmods=None): """ Deploy the OMNI compiler's frontend (F_Front) to AST for a source string. """ # Use basename of filepath if given if filepath is None: filepath = Path(filehash(source, prefix='omni-', suffix='.f90')) else: filepath = filepath.with_suffix(f'.omni{filepath.suffix}') # Always store intermediate flies in tmp dir filepath = gettempdir()/filepath.name debug(f'[Loki::OMNI] Writing temporary source {filepath}') with filepath.open('w') as f: f.write(source) return parse_omni_file(filename=filepath, xmods=xmods) @Timer(logger=debug, text=lambda s: f'[Loki::OMNI] Executed parse_omni_ast in {s:.2f}s') def parse_omni_ast(ast, definitions=None, type_map=None, symbol_map=None, raw_source=None, scope=None): """ Generate an internal IR from the raw OMNI parser AST. 
""" # Parse the raw OMNI language AST _ir = OMNI2IR(type_map=type_map, definitions=definitions, symbol_map=symbol_map, raw_source=raw_source, scope=scope).visit(ast) # Perform some minor sanitation tasks _ir = sanitize_ir(_ir, OMNI) return _ir class OMNI2IR(GenericVisitor): # pylint: disable=unused-argument # Stop warnings about unused arguments _omni_types = { 'Fint': 'INTEGER', 'Freal': 'REAL', 'Flogical': 'LOGICAL', 'Fcharacter': 'CHARACTER', 'Fcomplex': 'COMPLEX', 'int': 'INTEGER', 'real': 'REAL', } def __init__(self, definitions=None, type_map=None, symbol_map=None, raw_source=None, scope=None): super().__init__() self.definitions = CaseInsensitiveDict((d.name, d) for d in as_tuple(definitions)) self.type_map = type_map or {} self.symbol_map = symbol_map or {} self.raw_source = raw_source.splitlines(keepends=True) self.default_scope = scope self.lineno = None # use to save lineno of last element with attribute lineno @staticmethod def warn_or_fail(msg): if config['frontend-strict-mode']: error(msg) raise NotImplementedError warning(msg) def type_from_type_attrib(self, type_attrib, **kwargs): """ Helper routine to derive :any:`SymbolAttributes` for a given type name/hash/id """ if type_attrib in self._omni_types: typename = self._omni_types[type_attrib] _type = SymbolAttributes(BasicType.from_fortran_type(typename)) elif type_attrib in self.type_map: _type = self.visit(self.type_map[type_attrib], **kwargs) dims = self.type_map[type_attrib].findall('indexRange') if dims: dimensions = as_tuple(self.visit(d, **kwargs) for d in dims) _type = _type.clone(shape=dimensions) else: _type = SymbolAttributes(BasicType.from_fortran_type(type_attrib)) return _type def lookup_method(self, instance): """ Alternative lookup method for XML element types, identified by ``element.tag`` """ tag = instance.tag.replace('-', '_') if tag in self._handlers: return self._handlers[tag] return super().lookup_method(instance) def get_source(self, o): """Helper method that builds the source 
object for a node""" file = o.attrib.get('file', None) lineno = o.attrib.get('lineno', self.lineno) if lineno: self.lineno = int(lineno) lines = (self.lineno, self.lineno) string = self.raw_source[self.lineno-1] else: lines = (None, None) string = None return Source(lines=lines, string=string, file=file) def visit(self, o, **kwargs): # pylint: disable=arguments-differ """ Generic dispatch method that tries to generate meta-data from source. """ kwargs['source'] = self.get_source(o) kwargs.setdefault('scope', self.default_scope) kwargs.setdefault('symbol_map', self.symbol_map) return super().visit(o, **kwargs) def visit_Element(self, o, **kwargs): """ Universal default for XML element types """ warning('No specific handler for node type %s', o.__class__.name) children = tuple(self.visit(c, **kwargs) for c in o) children = tuple(c for c in children if c is not None) if len(children) == 1: return children[0] # Flatten hierarchy if possible return children if len(children) > 0 else None def visit_XcodeProgram(self, o, **kwargs): body = [self.visit(c, **kwargs) for c in o.find('globalDeclarations')] return ir.Section(body=as_tuple(body)) def visit_FuseDecl(self, o, **kwargs): # No ONLY list nature = 'intrinsic' if o.attrib.get('intrinsic') == 'true' else None name = o.attrib['name'] scope = kwargs['scope'] # Rename list rename_list = dict(self.visit(s, **kwargs) for s in o.findall('rename')) module = self.definitions.get(name, None) if module is not None: # Import symbol attributes from module, if available for k, v in module.symbol_attrs.items(): # Don't import private module symbols if v.private or (module.default_access_spec == "private" and not v.public): continue if k in rename_list: local_name = rename_list[k].name scope.symbol_attrs[local_name] = v.clone(imported=True, module=module, use_name=k) else: # Need to explicitly reset use_name in case we are importing a symbol # that stems from an import with a rename-list scope.symbol_attrs[k] = v.clone(imported=True, 
module=module, use_name=None) elif rename_list: # Module not available but some information via rename-list scope.symbol_attrs.update({v.name: v.type.clone(imported=True, use_name=k) for k, v in rename_list.items()}) rename_list = tuple(rename_list.items()) if rename_list else None return ir.Import(module=name, nature=nature, rename_list=rename_list, c_import=False, source=kwargs['source']) def visit_FuseOnlyDecl(self, o, **kwargs): # ONLY list given (import only selected symbols) nature = 'intrinsic' if o.attrib.get('intrinsic') == 'true' else None name = o.attrib['name'] scope = kwargs['scope'] symbols = tuple(self.visit(c, **kwargs) for c in o.findall('renamable')) if nature == 'intrinsic': module = None else: module = self.definitions.get(name, None) deferred_type = SymbolAttributes(BasicType.DEFERRED, imported=True) if module is None: # Initialize symbol attributes as DEFERRED for s in symbols: if isinstance(s, tuple): # Renamed symbol scope.symbol_attrs[s[1].name] = deferred_type.clone(use_name=s[0]) else: scope.symbol_attrs[s.name] = deferred_type else: # Import symbol attributes from module for s in symbols: if isinstance(s, tuple): # Renamed symbol _type = module.symbol_attrs.get(s[0], deferred_type) scope.symbol_attrs[s[1].name] = _type.clone( imported=True, module=module, use_name=s[0] ) else: # Need to explicitly reset use_name in case we are importing a symbol # that stems from an import with a rename-list _type = module.symbol_attrs.get(s.name, deferred_type) scope.symbol_attrs[s.name] = _type.clone( imported=True, module=module, use_name=None ) symbols = tuple( s[1].rescope(scope=scope) if isinstance(s, tuple) else s.rescope(scope=scope) for s in symbols ) return ir.Import(module=name, symbols=symbols, nature=nature, c_import=False, source=kwargs['source']) def visit_renamable(self, o, **kwargs): name = o.attrib['use_name'] if o.attrib.get('is_operator') == 'true': if name == '=': name = 'ASSIGNMENT(=)' else: name = f'OPERATOR({name})' if 
o.attrib.get('local_name'): return (name, sym.Variable(name=o.attrib['local_name'])) return sym.Variable(name=name) visit_rename = visit_renamable def visit_FinterfaceDecl(self, o, **kwargs): abstract = o.get('is_abstract') == 'true' if o.get('is_assignment') == 'true': name = 'ASSIGNMENT(=)' elif o.get('is_operator') == 'true': name = f'OPERATOR({o.get("name")})' else: name = o.get('name') if name is not None: scope = kwargs['scope'] if name not in scope.symbol_attrs: scope.symbol_attrs[name] = SymbolAttributes(ProcedureType(name, is_generic=True)) spec = sym.Variable(name=name, scope=kwargs['scope']) else: spec = None body = tuple(self.visit(c, **kwargs) for c in o) return ir.Interface(body=body, abstract=abstract, spec=spec, source=kwargs['source']) def _create_Procedure_object(self, o, scope, symbol_map): """Helper method to instantiate a Subroutine object""" from loki.function import Function # pylint: disable=import-outside-toplevel,cyclic-import from loki.subroutine import Subroutine # pylint: disable=import-outside-toplevel,cyclic-import assert o.tag in ('FfunctionDefinition', 'FfunctionDecl') name = o.find('name').text # Check if the Subroutine node has been created before by looking it up in the scope procedure = None if scope is not None and name in scope.symbol_attrs: proc_type = scope.symbol_attrs[name] # Look-up only in current scope! 
if proc_type and proc_type.dtype.procedure != BasicType.DEFERRED: procedure = proc_type.dtype.procedure if not procedure._incomplete: # We return the existing object right away, unless it exists from a # previous incomplete parse for which we have to make sure we get a # full parse first return procedure # Return type and dummy args ftype = self.type_map[o.find('name').attrib['type']] if ftype.attrib.get('is_program') == 'true': self.warn_or_fail('No support for PROGRAM') return None proc_type = self.visit(ftype, scope=scope, symbol_map=symbol_map) is_function = ftype.attrib['return_type'] != 'Fvoid' args = tuple(a.text for a in ftype.findall('params/name')) # Function/Subroutine prefix prefix = proc_type.prefix or () if prefix: # We store the prefix on the Subroutine object, so let's remove it from the symbol attrs proc_type = proc_type.clone(prefix=None) # Function suffix (result name and language binding, but no support for the latter in OMNI) result = ftype.attrib.get('result_name') # Instantiate the object if is_function: if procedure is None: procedure = Function( name=name, args=args, prefix=prefix, bind=None, result_name=result, parent=scope, ast=o, source=self.get_source(o) ) else: procedure.__initialize__( name=name, args=args, docstring=procedure.docstring, spec=procedure.spec, body=procedure.body, contains=procedure.contains, prefix=prefix, bind=None, result_name=result, ast=o, source=self.get_source(o), incomplete=procedure._incomplete ) else: if procedure is None: procedure = Subroutine( name=name, args=args, prefix=prefix, bind=None, parent=scope, ast=o, source=self.get_source(o) ) else: procedure.__initialize__( name=name, args=args, docstring=procedure.docstring, spec=procedure.spec, body=procedure.body, contains=procedure.contains, prefix=prefix, bind=None, ast=o, source=self.get_source(o), incomplete=procedure._incomplete ) return procedure def visit_FfunctionDefinition(self, o, **kwargs): # Update the symbol map with local entries 
kwargs['symbol_map'] = kwargs['symbol_map'].copy() kwargs['symbol_map'].update({s.attrib['type']: s for s in o.find('symbols')}) # Instantiate the object routine = self._create_Procedure_object(o, kwargs['scope'], kwargs['symbol_map']) if routine is None: return None kwargs['scope'] = routine # Parse the spec spec = self.visit(o.find('declarations'), **kwargs) spec = sanitize_ir(spec, OMNI) # Filter out the declaration for the subroutine name but keep it for functions (since # this declares the return type) spec_map = {} if not routine.is_function: spec_map.update({ d: None for d in FindNodes((ir.ProcedureDeclaration, ir.VariableDeclaration)).visit(spec) if routine.name in d.symbols }) # Hack: We remove comments from the beginning of the spec to get the docstring docstring = [] for node in spec.body: if node in spec_map: continue if not isinstance(node, (ir.Comment, ir.CommentBlock)): break docstring.append(node) spec_map[node] = None docstring = as_tuple(docstring) spec = Transformer(spec_map, invalidate_source=False).visit(spec) # Insert the `implicit none` statement OMNI omits (slightly hacky!) 
f_imports = [im for im in FindNodes(ir.Import).visit(spec) if not im.c_import] if not f_imports: spec.prepend(ir.Intrinsic(text='IMPLICIT NONE')) else: spec.insert(spec.body.index(f_imports[-1])+1, ir.Intrinsic(text='IMPLICIT NONE')) # Parse member functions body_ast = o.find('body') contains_ast = None if body_ast is None else body_ast.find('FcontainsStatement') if contains_ast is not None: contains = self.visit(contains_ast, **kwargs) # Strip contains part from the XML before we proceed body_ast.remove(contains_ast) else: contains = None # Finally, take care of the body if body_ast is None: body = ir.Section(body=()) else: body = ir.Section(body=self.visit(body_ast, **kwargs)) body = sanitize_ir(body, OMNI) # Finally, call the subroutine constructor on the object again to register all # bits and pieces in place and rescope all symbols # pylint: disable=unnecessary-dunder-call if routine.is_function: routine.__initialize__( name=routine.name, args=routine._dummies, docstring=docstring, spec=spec, body=body, contains=contains, ast=o, prefix=routine.prefix, bind=routine.bind, result_name=routine.result_name, rescope_symbols=True, source=routine.source, incomplete=False ) else: routine.__initialize__( name=routine.name, args=routine._dummies, docstring=docstring, spec=spec, body=body, contains=contains, ast=o, prefix=routine.prefix, bind=routine.bind, rescope_symbols=True, source=routine.source, incomplete=False ) # Update array shapes with Loki dimension pragmas with pragmas_attached(routine, ir.VariableDeclaration): routine.spec = process_dimension_pragmas(routine.spec, scope=routine) return routine visit_FfunctionDecl = visit_FfunctionDefinition def visit_FcontainsStatement(self, o, **kwargs): body = [self.visit(c, **kwargs) for c in o] body = [c for c in body if c is not None] body = [ir.Intrinsic('CONTAINS', source=kwargs['source'])] + body return ir.Section(body=as_tuple(body)) def visit_FmoduleProcedureDecl(self, o, **kwargs): symbols = 
as_tuple(self.visit(o.find('name'), **kwargs)) symbols = AttachScopesMapper()(symbols, scope=kwargs['scope']) return ir.ProcedureDeclaration(symbols=symbols, module=True, source=kwargs.get('source')) def _create_Module_object(self, o, scope): """Helper method to instantiate a Module object""" from loki.module import Module # pylint: disable=import-outside-toplevel,cyclic-import name = o.attrib['name'] # Check if the Module node has been created before by looking it up in the scope if scope is not None and name in scope.symbol_attrs: module_type = scope.symbol_attrs[name] # Look-up only in current scope if module_type and module_type.dtype.module != BasicType.DEFERRED: return module_type.dtype.module module = Module(name=name, parent=scope) self.definitions[name] = module return module def visit_FmoduleDefinition(self, o, **kwargs): # Update the symbol map with local entries kwargs['symbol_map'] = kwargs['symbol_map'].copy() kwargs['symbol_map'].update({s.attrib['type']: s for s in o.find('symbols')}) # Instantiate the object module = self._create_Module_object(o, kwargs['scope']) kwargs['scope'] = module # Pre-populate symbol table with procedure types declared in this module # to correctly classify inline function calls and type-bound procedures contains_ast = o.find('FcontainsStatement') if contains_ast is not None: # Note that we overwrite this variable subsequently with the fully parsed subroutines # where the visit-method for the subroutine/function statement will pick out the existing # subroutine objects using the weakref pointers stored in the symbol table. # I know, it's not pretty but alternatively we could hand down this array as part of # kwargs but that feels like carrying around a lot of bulk, too. 
contains = [ self._create_Procedure_object(member_ast, kwargs['scope'], kwargs['symbol_map']) for member_ast in contains_ast.findall('FfunctionDefinition') ] # Parse the spec spec = self.visit(o.find('declarations'), **kwargs) spec = sanitize_ir(spec, OMNI) # Hack: We remove comments from the beginning of the spec to get the docstring docstring = [] spec_map = {} for node in spec.body: if node in spec_map: continue if not isinstance(node, (ir.Comment, ir.CommentBlock)): break docstring.append(node) spec_map[node] = None docstring = as_tuple(docstring) spec = Transformer(spec_map, invalidate_source=False).visit(spec) # Parse member functions if contains_ast is not None: contains = self.visit(contains_ast, **kwargs) else: contains = None # Finally, call the module constructor on the object again to register all # bits and pieces in place and rescope all symbols # pylint: disable=unnecessary-dunder-call module.__initialize__( name=module.name, docstring=docstring, spec=spec, contains=contains, ast=o, rescope_symbols=True, source=kwargs['source'], incomplete=False ) return module def visit_declarations(self, o, **kwargs): body = tuple(self.visit(c, **kwargs) for c in o) body = tuple(c for c in body if c is not None) return ir.Section(body=body, source=kwargs['source']) def visit_body(self, o, **kwargs): body = tuple(self.visit(c, **kwargs) for c in o) body = tuple(c for c in body if c is not None) return body def visit_FimportDecl(self, o, **kwargs): symbols = tuple(self.visit(i, **kwargs) for i in o) symbols = AttachScopesMapper()(symbols, scope=kwargs['scope']) return ir.Import( module=None, symbols=symbols, f_import=True, source=kwargs['source'] ) def visit_varDecl(self, o, **kwargs): # OMNI has only one variable per declaration, find and create that name = o.find('name') variable = self.visit(name, **kwargs) interface = None scope = kwargs['scope'] # Create the declared type if name.attrib['type'] in self._omni_types: # Intrinsic scalar type t = 
self._omni_types[name.attrib['type']] _type = SymbolAttributes(BasicType.from_fortran_type(t)) dimensions = None elif name.attrib['type'] in self.type_map: # Type with attributes or derived type tast = self.type_map[name.attrib['type']] _type = self.visit(tast, **kwargs) dimensions = as_tuple(self.visit(d, **kwargs) for d in tast.findall('indexRange')) if dimensions: _type = _type.clone(shape=dimensions) variable = variable.clone(dimensions=dimensions) if isinstance(_type.dtype, ProcedureType): if _type.dtype.name == 'UNKNOWN': # _Probably_ a declaration with implicit interface dtype = ProcedureType( variable.name, is_function=_type.dtype.is_function, return_type=_type.dtype.return_type ) _type = _type.clone(dtype=dtype) interface = dtype.return_type.dtype if variable != scope.name: # Instantiate the symbol representing the procedure in the current scope to create # relevant symbol table entries, and then extract the dtype try: symbol_scope = scope.get_symbol_scope(_type.dtype.name) interface = symbol_scope.Variable(name=_type.dtype.name) _type = _type.clone(dtype=interface.type.dtype) except AttributeError: # Interface symbol could not be found pass elif _type.dtype.return_type is not None: # This is the declaration of the return type inside a function, which is # why we restore the return_type _type = _type.dtype.return_type # If the return type has a shape, we need to apply this as a dimension to the # variable, otherwise it will be missing from the declaration if _type.shape: variable = variable.clone(dimensions=_type.shape) if tast.attrib.get('is_external') == 'true': _type.external = True else: raise ValueError if o.find('value') is not None: _type = _type.clone(initial=AttachScopesMapper()(self.visit(o.find('value'), **kwargs), scope=scope)) if _type.kind is not None: _type = _type.clone(kind=AttachScopesMapper()(_type.kind, scope=scope)) scope.symbol_attrs[variable.name] = _type variable = variable.rescope(scope=scope) if isinstance(_type.dtype, 
ProcedureType): # This is actually a function or subroutine (EXTERNAL or PROCEDURE declaration) return ir.ProcedureDeclaration( symbols=(variable,), interface=interface, external=_type.external or False, source=kwargs['source'] ) return ir.VariableDeclaration(symbols=(variable,), source=kwargs['source']) def visit_FstructDecl(self, o, **kwargs): name = o.find('name') struct_type = self.type_map[name.attrib['type']] # Type attributes abstract = struct_type.get('is_abstract') == 'true' if 'extends' in struct_type.attrib: base_type = kwargs['symbol_map'][struct_type.attrib['extends']] extends = base_type.find('name').text else: extends = None bind_c = struct_type.get('bind', '').lower() == 'c' private = struct_type.get('is_private', '').lower() == 'true' public = struct_type.get('is_public', '').lower() == 'true' # Type Parameters if struct_type.find('typeParams') is not None: self.warn_or_fail('Parameterized types not implemented') # Instantiate the TypeDef without its body # Note: This creates the symbol table for the declarations and # the typedef object registers itself in the parent scope typedef = ir.TypeDef( name=name.text, body=(), abstract=abstract, extends=extends, bind_c=bind_c, private=private, public=public, parent=kwargs['scope'], source=kwargs['source'] ) kwargs['scope'] = typedef body = [] # Check if the type is marked as sequence if struct_type.get('is_sequence') == 'true': body += [ir.Intrinsic('SEQUENCE')] # Build the list of derived type members and individual body for each if struct_type.find('symbols') is not None: variables = self.visit(struct_type.find('symbols'), **kwargs) for v in variables: if isinstance(v.type.dtype, ProcedureType): if v.type.dtype.name == v and v.type.dtype.is_function: interface = v.type.dtype.return_type else: iface_name = v.type.dtype.name interface = sym.Variable(name=iface_name, scope=kwargs['scope'].get_symbol_scope(iface_name)) body += [ir.ProcedureDeclaration(symbols=(v,), interface=interface)] else: body += 
[ir.VariableDeclaration(symbols=(v,))] if struct_type.find('typeBoundProcedures') is not None: # See if components are marked private body += [ir.Intrinsic('CONTAINS')] if struct_type.attrib.get('is_internal_private') == 'true': body += [ir.Intrinsic('PRIVATE')] body += self.visit(struct_type.find('typeBoundProcedures'), **kwargs) # Finally: update the typedef with its body typedef._update(body=as_tuple(body)) typedef.rescope_symbols() return typedef def visit_symbols(self, o, **kwargs): """ Build the list of variables for a `FstructType` node """ variables = [] for s in o: var = self.visit(s.find('name'), **kwargs) _type = self.type_from_type_attrib(s.attrib['type'], **kwargs) kwargs['scope'].symbol_attrs[var.name] = _type if _type.shape: var = var.clone(dimensions=_type.shape) variables += [var.rescope(scope=kwargs['scope'])] return variables def visit_typeBoundProcedures(self, o, **kwargs): procedures = [] for i in o: proc = self.visit(i, **kwargs) if i.get('is_deferred') == 'true': assert proc.type.deferred is True assert proc.type.bind_names and len(proc.type.bind_names) == 1 intf = proc.type.bind_names[0] procedures += [ir.ProcedureDeclaration(interface=intf, symbols=(proc,))] elif i.tag == 'typeBoundGenericProcedure': procedures += [ir.ProcedureDeclaration(symbols=(proc,), generic=True)] elif i.tag == 'finalProcedure': procedures += [ir.ProcedureDeclaration(symbols=(proc,), final=True)] else: procedures += [ir.ProcedureDeclaration(symbols=(proc,))] return procedures def visit_typeBoundProcedure(self, o, **kwargs): scope = kwargs['scope'] var = self.visit(o.find('name'), **kwargs) _type = self.type_from_type_attrib(o.attrib['type'], **kwargs) if o.get('pass') == 'pass': _type = _type.clone(pass_attr=o.get('pass_arg_name', True)) elif o.get('pass') == 'nopass': _type = _type.clone(pass_attr=False) if o.get('is_deferred') == 'true': _type = _type.clone(deferred=True) if o.get('is_non_overridable') == 'true': _type = _type.clone(non_overridable=True) if 
o.get('is_private') == 'true': _type = _type.clone(private=True) if o.get('is_public') == 'true': _type = _type.clone(public=True) if o.find('binding') is not None: bind_name = self.visit(o.find('binding/name'), **kwargs) bind_name_scope = scope.get_symbol_scope(bind_name.name) # Set correct type for interface/binding if bind_name_scope is not None: bind_name = bind_name.rescope(scope=bind_name_scope) else: bind_name = bind_name.clone(type=bind_name.type.clone(dtype=ProcedureType(bind_name.name))) if bind_name.name.lower() == var.name.lower() and not _type.deferred: # No need to assign bind_names property _type = _type.clone(dtype=bind_name.type.dtype) else: # Assign the binding as bind_nameial (and park the interface here for # declarations with deferred attribute) _type = _type.clone(dtype=bind_name.type.dtype, bind_names=(bind_name,)) scope.symbol_attrs[var.name] = _type return var.rescope(scope=scope) def visit_typeBoundGenericProcedure(self, o, **kwargs): scope = kwargs['scope'] var = self.visit(o.find('name'), **kwargs) _type = SymbolAttributes(ProcedureType(name=var.name, is_generic=True)) if o.get('is_private') == 'true': _type = _type.clone(private=True) if o.get('is_public') == 'true': _type = _type.clone(public=True) assert o.find('binding') is not None bind_names = [] for name in o.findall('binding/name'): bind_name = self.visit(name, **kwargs) bind_name_scope = scope.get_symbol_scope(bind_name.name) # Set correct type for interface/binding if bind_name_scope is not None: bind_name = bind_name.rescope(scope=bind_name_scope) else: bind_name = bind_name.clone(type=bind_name.type.clone(dtype=ProcedureType(bind_name.name))) bind_names += [bind_name] _type = _type.clone(bind_names=as_tuple(bind_names)) scope.symbol_attrs[var.name] = _type return var.rescope(scope=scope) def visit_finalProcedure(self, o, **kwargs): scope = kwargs['scope'] var = self.visit(o.find('name'), **kwargs) _type = scope.symbol_attrs.lookup(var.name) scope.symbol_attrs[var.name] = 
_type return var.rescope(scope=scope) def visit_FdataDecl(self, o, **kwargs): variable = self.visit(o.find('varList'), **kwargs) values = self.visit(o.find('valueList'), **kwargs) return ir.DataDeclaration(variable=variable, values=values, source=kwargs['source']) def visit_varList(self, o, **kwargs): children = tuple(self.visit(c, **kwargs) for c in o) children = tuple(c for c in children if c is not None) return children visit_valueList = visit_varList def visit_FbasicType(self, o, **kwargs): ref = o.attrib.get('ref', None) if ref in self._omni_types: dtype = BasicType.from_fortran_type(self._omni_types[ref]) kind = self.visit(o.find('kind'), **kwargs) if o.find('kind') is not None else None length = o.find('len') if length is not None: if length == '*': pass elif length.attrib.get('is_assumed_size') == 'true': length = '*' elif length.attrib.get('is_assumed_shape') == 'true': length = ':' else: length = self.visit(length, **kwargs) _type = SymbolAttributes(dtype, kind=kind, length=length) elif ref in self.type_map: if o.find('name') is not None: _type = self.visit(self.type_map[ref], name=o.find('name').text, **kwargs) else: _type = self.visit(self.type_map[ref], **kwargs) if o.attrib.get('is_class') == 'true': _type = _type.clone(polymorphic=True) elif ref == 'FnumericAll': _type = SymbolAttributes(BasicType.DEFERRED) else: raise ValueError shape = o.findall('indexRange') if shape: _type.shape = tuple(self.visit(s, **kwargs) for s in shape) # OMNI types are build recursively from references (Matroshka-style) if o.get('intent') is not None: _type.intent = o.get('intent') if o.get('is_allocatable') == 'true': _type.allocatable = True if o.get('is_pointer') == 'true': _type.pointer = True if o.get('is_optional') == 'true': _type.optional = True if o.get('is_parameter') == 'true': _type.parameter = True if o.get('is_target') == 'true': _type.target = True if o.get('is_contiguous') == 'true': _type.contiguous = True if o.get('is_private') == 'true': _type.private = 
True if o.get('is_public') == 'true': _type.public = True if o.get('is_save') == 'true': _type.save = True if o.get('is_protected') == 'true': _type.protected = True return _type def visit_FfunctionType(self, o, **kwargs): if o.attrib['return_type'] == 'Fvoid': return_type = None elif o.attrib['return_type'] in self._omni_types: return_type = SymbolAttributes(BasicType.from_fortran_type(self._omni_types[o.attrib['return_type']])) elif o.attrib['return_type'] in self.type_map: return_type = self.visit(self.type_map[o.attrib['return_type']], **kwargs) else: raise ValueError if o.attrib['type'] in kwargs['symbol_map']: name = kwargs['symbol_map'][o.attrib['type']].find('name').text else: name = kwargs.get('name', 'UNKNOWN') dtype = ProcedureType(name, is_function=return_type is not None, return_type=return_type) prefix = [] if o.attrib.get('is_pure') == 'true': prefix += ['PURE'] if o.attrib.get('is_elemental') == 'true': prefix += ['ELEMENTAL'] if o.attrib.get('is_recursive') == 'true': prefix += ['RECURSIVE'] return SymbolAttributes(dtype, prefix=prefix or None) def visit_FstructType(self, o, **kwargs): # We have encountered a derived type as part of the declaration in the spec # of a routine. 
name = o.attrib['type'] if name in kwargs['symbol_map']: name = kwargs['symbol_map'][name].find('name').text # Check if we know that type already dtype = kwargs['scope'].symbol_attrs.lookup(name, recursive=True) if dtype is None or dtype.dtype == BasicType.DEFERRED: dtype = DerivedType(name=name, typedef=BasicType.DEFERRED) else: dtype = dtype.dtype return SymbolAttributes(dtype) def visit_value(self, o, **kwargs): return self.visit(o[0], **kwargs) visit_kind = visit_value visit_len = visit_value def visit_associateStatement(self, o, **kwargs): associations = tuple(self.visit(c, **kwargs) for c in o.findall('symbols/id')) # Create a scope for the associate parent_scope = kwargs['scope'] associate = ir.Associate(associations=(), body=(), parent=parent_scope, source=kwargs['source']) kwargs['scope'] = associate # Put associate expressions into the right scope and determine type of new symbols rescoped_associations = [] for expr, name in associations: # Put symbols in associated expression into the right scope expr = AttachScopesMapper()(expr, scope=parent_scope) # Determine type of new names if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)): # Use the type of the associated variable _type = expr.type.clone(parent=None) if isinstance(expr, sym.Array) and expr.dimensions is not None: shape = ExpressionDimensionsMapper()(expr) if shape == (sym.IntLiteral(1),): # For a scalar expression, we remove the shape shape = None _type = _type.clone(shape=shape) else: # TODO: Handle data type and shape of complex expressions shape = ExpressionDimensionsMapper()(expr) if shape == (sym.IntLiteral(1),): # For a scalar expression, we remove the shape shape = None _type = SymbolAttributes(BasicType.DEFERRED, shape=shape) name = name.clone(scope=associate, type=_type) rescoped_associations += [(expr, name)] associations = as_tuple(rescoped_associations) body = self.visit(o.find('body'), **kwargs) associate._update(associations=associations, body=body) return associate def 
visit_id(self, o, **kwargs): expr = self.visit(o.find('value'), **kwargs) name = self.visit(o.find('name'), **kwargs) return expr, name def visit_exprStatement(self, o, **kwargs): return self.visit(o[0], **kwargs) def visit_FcommentLine(self, o, **kwargs): return ir.Comment(text=o.text, source=kwargs['source']) def visit_FpragmaStatement(self, o, **kwargs): keyword = o.text.split(' ')[0] content = ' '.join(o.text.split(' ')[1:]) return ir.Pragma(keyword=keyword, content=content, source=kwargs['source']) def visit_FassignStatement(self, o, **kwargs): lhs = self.visit(o[0], **kwargs) rhs = self.visit(o[1], **kwargs) return ir.Assignment(lhs=lhs, rhs=rhs, source=kwargs['source']) def visit_FallocateStatement(self, o, **kwargs): variables = tuple(self.visit(c, **kwargs) for c in o.findall('alloc')) alloc_opts = {} if o.find('allocOpt') is not None: alloc_opts = [self.visit(opt, **kwargs) for opt in o.findall('allocOpt')] alloc_opts = [opt for opt in alloc_opts if opt is not None] alloc_opts = dict(alloc_opts) return ir.Allocation(variables=variables, source=kwargs['source'], data_source=alloc_opts.get('source'), status_var=alloc_opts.get('stat')) def visit_allocOpt(self, o, **kwargs): keyword = o.attrib['kind'].lower() if keyword in ('source', 'stat'): return keyword, self.visit(o[0], **kwargs) self.warn_or_fail(f'Unsupported allocation option: {keyword}') return None def visit_FdeallocateStatement(self, o, **kwargs): variables = tuple(self.visit(c, **kwargs) for c in o.findall('alloc')) alloc_opts = {} if o.find('allocOpt') is not None: alloc_opts = [self.visit(opt, **kwargs) for opt in o.findall('allocOpt')] alloc_opts = [opt for opt in alloc_opts if opt is not None] alloc_opts = dict(alloc_opts) return ir.Deallocation(variables=variables, source=kwargs['source'], status_var=alloc_opts.get('stat')) def visit_FnullifyStatement(self, o, **kwargs): variables = tuple(self.visit(c, **kwargs) for c in o.findall('alloc')) return ir.Nullify(variables=variables, 
source=kwargs['source']) def visit_alloc(self, o, **kwargs): variable = self.visit(o[0], **kwargs) if o.find('arrayIndex') is not None: dimensions = tuple(self.visit(c, **kwargs) for c in o.findall('arrayIndex')) variable = variable.clone(dimensions=dimensions) return variable def visit_FwhereStatement(self, o, **kwargs): conditions = tuple(self.visit(c, **kwargs) for c in o.findall('condition')) bodies = tuple(self.visit(b, **kwargs) for b in o.findall('then/body')) if o.find('else') is not None: default = self.visit(o.find('else/body'), **kwargs) else: default = () return ir.MaskedStatement(conditions=conditions, bodies=bodies, default=default, source=kwargs['source']) def visit_blockStatement(self, o, **kwargs): if (forall_stmt := o.find('body/forallStatement')) is not None: return self.visit(forall_stmt, **kwargs) self.warn_or_fail('Unsupported blockStatement') return None def visit_forallStatement(self, o, **kwargs): body = self.visit(o.find('body'), **kwargs) named_bounds = () for var, index_range in zip(o.findall('Var'), o.findall('indexRange')): variable = self.visit(var, **kwargs) lower = self.visit(index_range.find('lowerBound'), **kwargs) upper = self.visit(index_range.find('upperBound'), **kwargs) bounds = sym.RangeIndex((lower, upper)) named_bounds += ((variable, bounds),) if (condition := o.find('condition')) is not None: mask = self.visit(condition, **kwargs) else: mask = None return ir.Forall(name=None, named_bounds=named_bounds, body=body, mask=mask, inline=False, source=kwargs.get('source')) def visit_FpointerAssignStatement(self, o, **kwargs): target = self.visit(o[0], **kwargs) expr = self.visit(o[1], **kwargs) return ir.Assignment(lhs=target, rhs=expr, ptr=True, source=kwargs['source']) def visit_FdoWhileStatement(self, o, **kwargs): assert o.find('condition') is not None assert o.find('body') is not None condition = self.visit(o.find('condition'), **kwargs) body = self.visit(o.find('body'), **kwargs) return ir.WhileLoop(condition=condition, 
body=body, source=kwargs['source']) def visit_FdoStatement(self, o, **kwargs): assert o.find('body') is not None body = self.visit(o.find('body'), **kwargs) if o.find('Var') is None: # We are in an unbound do loop return ir.WhileLoop(condition=None, body=body, source=kwargs['source']) variable = self.visit(o.find('Var'), **kwargs) lower = self.visit(o.find('indexRange/lowerBound'), **kwargs) upper = self.visit(o.find('indexRange/upperBound'), **kwargs) step = self.visit(o.find('indexRange/step'), **kwargs) # Drop OMNI's `:1` step counting for ranges in the name of consistency step = None if step == '1' else step bounds = sym.LoopRange((lower, upper, step)) return ir.Loop(variable=variable, body=body, bounds=bounds, source=kwargs['source']) def visit_FdoLoop(self, o, **kwargs): variable = self.visit(o.find('Var'), **kwargs) lower = self.visit(o.find('indexRange/lowerBound'), **kwargs) upper = self.visit(o.find('indexRange/upperBound'), **kwargs) step = self.visit(o.find('indexRange/step'), **kwargs) # Drop OMNI's `:1` step counting for ranges in the name of consistency step = None if step == '1' else step bounds = sym.LoopRange((lower, upper, step)) values = as_tuple(self.visit(o.find('value'), **kwargs)) return sym.InlineDo(values, variable, bounds) def visit_FifStatement(self, o, **kwargs): condition = self.visit(o.find('condition'), **kwargs) body = self.visit(o.find('then/body'), **kwargs) if o.find('else') is not None: else_body = self.visit(o.find('else/body'), **kwargs) else: else_body = () return ir.Conditional(condition=condition, body=body, else_body=else_body, source=kwargs['source']) def visit_condition(self, o, **kwargs): return self.visit(o[0], **kwargs) def visit_FselectCaseStatement(self, o, **kwargs): expr = self.visit(o.find('value'), **kwargs) cases = tuple(self.visit(case, **kwargs) for case in o.findall('FcaseLabel')) values, bodies = zip(*cases) if None in values: else_index = values.index(None) else_body = as_tuple(bodies[else_index]) values = 
values[:else_index] + values[else_index+1:] bodies = bodies[:else_index] + bodies[else_index+1:] else: else_body = () # Retain comments before the first case value_idx, case_idx = list(o).index(o.find('value')), list(o).index(o.find('FcaseLabel')) pre = as_tuple(self.visit(c, **kwargs) for c in o[value_idx+1:case_idx]) return ( *pre, ir.MultiConditional(expr=expr, values=values, bodies=bodies, else_body=else_body, source=kwargs['source']) ) def visit_FcaseLabel(self, o, **kwargs): values = [self.visit(value, **kwargs) for value in list(o) if value.tag in ('value', 'indexRange')] if not values: values = None elif len(values) == 1: values = values.pop() body = self.visit(o.find('body'), **kwargs) return as_tuple(values) or None, as_tuple(body) def visit_FenumDecl(self, o, **kwargs): enum_type = self.type_map[o.attrib['type']] # Build the list of symbols symbols = [] for i in enum_type.findall('symbols/id'): var = self.visit(i.find('name'), **kwargs) initial = i.find('value') if initial is not None: initial = self.visit(initial, **kwargs) _type = SymbolAttributes(BasicType.INTEGER, initial=initial) symbols += [var.clone(type=_type)] # Put symbols in the right scope (that should register their type in that scope's symbol table) symbols = tuple(s.rescope(scope=kwargs['scope']) for s in symbols) # Create the enum return ir.Enumeration(symbols=symbols, source=kwargs['source']) def visit_FmemberRef(self, o, **kwargs): parent = self.visit(o.find('varRef'), **kwargs) name = f'{parent.name}%{o.attrib["member"]}' variable = sym.Variable(name=name, parent=parent) return variable def visit_name(self, o, **kwargs): return sym.Variable(name=o.text) visit_Var = visit_name def visit_FarrayRef(self, o, **kwargs): var = self.visit(o.find('varRef'), **kwargs) dimensions = as_tuple(self.visit(i, **kwargs) for i in o[1:]) var = var.clone(dimensions=dimensions) return var def visit_varRef(self, o, **kwargs): return self.visit(o[0], **kwargs) def visit_arrayIndex(self, o, **kwargs): return 
self.visit(o[0], **kwargs) def visit_indexRange(self, o, **kwargs): lbound = o.find('lowerBound') lower = self.visit(lbound, **kwargs) if lbound is not None else None ubound = o.find('upperBound') upper = self.visit(ubound, **kwargs) if ubound is not None else None st = o.find('step') step = self.visit(st, **kwargs) if st is not None else None # Drop OMNI's `:1` step counting for ranges in the name of consistency step = None if step == '1' else step return sym.RangeIndex((lower, upper, step)) def visit_FcharacterRef(self, o, **kwargs): var = self.visit(o.find('varRef'), **kwargs) dimensions = self.visit(o.find('indexRange'), **kwargs) return sym.StringSubscript(var, dimensions) def visit_lowerBound(self, o, **kwargs): return self.visit(o[0], **kwargs) visit_upperBound = visit_lowerBound visit_step = visit_lowerBound def visit_FrealConstant(self, o, **kwargs): if 'kind' in o.attrib and not 'd' in o.text.lower(): _type = self.visit(self.type_map[o.attrib.get('type')], **kwargs) return sym.Literal(value=o.text, type=BasicType.REAL, kind=_type.kind) return sym.Literal(value=o.text, type=BasicType.REAL) def visit_FlogicalConstant(self, o, **kwargs): return sym.Literal(value=o.text, type=BasicType.LOGICAL) def visit_FcharacterConstant(self, o, **kwargs): return sym.Literal(value=f'"{o.text}"', type=BasicType.CHARACTER) def visit_FintConstant(self, o, **kwargs): if 'kind' in o.attrib: _type = self.visit(self.type_map[o.attrib.get('type')], **kwargs) return sym.Literal(value=int(o.text), type=BasicType.INTEGER, kind=_type.kind) return sym.Literal(value=int(o.text), type=BasicType.INTEGER) def visit_FcomplexConstant(self, o, **kwargs): value = ', '.join(f'{self.visit(v, **kwargs)}' for v in list(o)) return sym.IntrinsicLiteral(value=f'({value})') def visit_FarrayConstructor(self, o, **kwargs): values = as_tuple(self.visit(v, **kwargs) for v in o) if 'element_type' in o.attrib: dtype = self.type_from_type_attrib(o.attrib['element_type']) else: dtype = None return 
sym.LiteralList(values=values, dtype=dtype) def visit_functionCall(self, o, **kwargs): if 'is_intrinsic' in o.attrib: # Register the ProcedureType in the scope before the name lookup pname = o.find('name').text proc_type = ProcedureType( name=pname, is_function=True, is_intrinsic=True, procedure=None ) kwargs['scope'].symbol_attrs[pname] = SymbolAttributes(dtype=proc_type, is_intrinsic=True) if o.find('name') is not None: name = self.visit(o.find('name'), **kwargs) elif o.find('FmemberRef') is not None: name = self.visit(o.find('FmemberRef'), **kwargs) else: raise ValueError args = o.find('arguments') if args is not None: args = as_tuple(self.visit(a, **kwargs) for a in args) # Separate keyword argument from positional arguments kw_args = as_tuple(arg for arg in args if isinstance(arg, tuple)) args = as_tuple(arg for arg in args if not isinstance(arg, tuple)) else: args, kw_args = (), () if o.attrib.get('type', 'Fvoid') == 'Fvoid': # Subroutine call return ir.CallStatement(name=name, arguments=args, kwarguments=kw_args, source=kwargs['source']) if name.name.lower() in ('real', 'int'): assert args expr = args[0] if kw_args: assert len(args) == 1 assert len(kw_args) == 1 and kw_args[0][0] == 'kind' kind = kw_args[0][1] else: kind = args[1] if len(args) > 1 else None return sym.Cast(name, expr, kind=kind) return sym.InlineCall(name, parameters=args, kw_parameters=kw_args) def visit_FstructConstructor(self, o, **kwargs): _type = self.type_from_type_attrib(o.attrib['type'], **kwargs) assert isinstance(_type.dtype, DerivedType) name = sym.Variable(name=_type.dtype.name) args = [self.visit(a, **kwargs) for a in o] # Separate keyword argument from positional arguments kw_args = as_tuple(arg for arg in args if isinstance(arg, tuple)) args = as_tuple(arg for arg in args if not isinstance(arg, tuple)) return sym.InlineCall(name, parameters=args, kw_parameters=kw_args) def visit_FcycleStatement(self, o, **kwargs): # TODO: do-construct-name is not preserved return 
ir.Intrinsic(text='cycle', source=kwargs['source']) def visit_continueStatement(self, o, **kwargs): return ir.Intrinsic(text='continue', source=kwargs['source']) def visit_FexitStatement(self, o, **kwargs): # TODO: do-construct-name is not preserved return ir.Intrinsic(text='exit', source=kwargs['source']) def visit_FopenStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) return ir.Intrinsic(text=f'open({nargs})', source=kwargs['source']) def visit_FcloseStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) return ir.Intrinsic(text=f'close({nargs})', source=kwargs['source']) def visit_FreadStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] values = [self.visit(v, **kwargs) for v in o.find('valueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) args = ', '.join(f'{v}' for v in values) return ir.Intrinsic(text=f'read({nargs}) {args}', source=kwargs['source']) def visit_FwriteStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] values = [self.visit(v, **kwargs) for v in o.find('valueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) args = ', '.join(f'{v}' for v in values) return ir.Intrinsic(text=f'write({nargs}) {args}', source=kwargs['source']) def visit_FprintStatement(self, o, **kwargs): values = [self.visit(v, **kwargs) for v in o.find('valueList')] args = ', '.join(f'{v}' for v in values) args = f", {args}" if values else "" fmt = o.attrib['format'] return ir.Intrinsic(text=f'print {fmt}{args}', source=kwargs['source']) def visit_FformatDecl(self, o, **kwargs): fmt = f'FORMAT{o.attrib["format"]}' return ir.Intrinsic(text=fmt, source=kwargs['source']) def visit_namedValue(self, o, **kwargs): name = o.attrib['name'] if 'value' in o.attrib: return name, 
o.attrib['value'] return name, self.visit(list(o)[0], **kwargs) @staticmethod def parenthesize_if_needed(expr, enclosing_cls): # Other than FP/OFP, OMNI does not retain any information about parenthesis in the # original source. While the parse tree is semantically correct, # it may cause problems with some agressively optimising compilers. # We inject manual parenthesis here for nested expressions to make sure # we capture as much of the evaluation order of the original source as possible. # Note: this will result in an abundance of trivial/unnecessary parenthesis! if enclosing_cls in (sym.Product, sym.Quotient): if isinstance(expr, sym.Product): return op.ParenthesisedMul(expr.children) if isinstance(expr, sym.Quotient): return op.ParenthesisedDiv(expr.numerator, expr.denominator) if isinstance(expr, sym.Sum): return op.ParenthesisedAdd(expr.children) if isinstance(expr, sym.Power): return op.ParenthesisedPow(expr.base, expr.exponent) return expr def visit_plusExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Sum(exprs) def visit_minusExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Sum((exprs[0], sym.Product((-1, exprs[1])))) def visit_mulExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 exprs = tuple(self.parenthesize_if_needed(c, sym.Product) for c in exprs) return sym.Product(exprs) def visit_divExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 exprs = tuple(self.parenthesize_if_needed(c, sym.Quotient) for c in exprs) return sym.Quotient(*exprs) def visit_FpowerExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Power(base=exprs[0], exponent=exprs[1]) def visit_unaryMinusExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 1 return sym.Product((-1, 
exprs[0])) def visit_logOrExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) return sym.LogicalOr(exprs) def visit_logAndExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) return sym.LogicalAnd(exprs) def visit_logNotExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 1 return sym.LogicalNot(exprs[0]) def visit_logLTExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Comparison(exprs[0], '<', exprs[1]) def visit_logLEExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Comparison(exprs[0], '<=', exprs[1]) def visit_logGTExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Comparison(exprs[0], '>', exprs[1]) def visit_logGEExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Comparison(exprs[0], '>=', exprs[1]) def visit_logEQExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Comparison(exprs[0], '==', exprs[1]) def visit_logNEQExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.Comparison(exprs[0], '!=', exprs[1]) def visit_logEQVExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.LogicalOr((sym.LogicalAnd(exprs), sym.LogicalNot(sym.LogicalOr(exprs)))) def visit_logNEQVExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return sym.LogicalAnd((sym.LogicalNot(sym.LogicalAnd(exprs)), sym.LogicalOr(exprs))) def visit_FconcatExpr(self, o, **kwargs): exprs = tuple(self.visit(c, **kwargs) for c in o) assert len(exprs) == 2 return StringConcat(exprs) def visit_gotoStatement(self, o, **kwargs): label = int(o.attrib['label_name']) return 
ir.Intrinsic(text=f'go to {label: d}', source=kwargs['source']) def visit_FstopStatement(self, o, **kwargs): code = o.attrib['code'] return ir.Intrinsic(text=f'stop {code!s}', source=kwargs['source']) def visit_statementLabel(self, o, **kwargs): return ir.Comment('__STATEMENT_LABEL__', label=o.attrib['label_name'], source=kwargs['source']) def visit_FreturnStatement(self, o, **kwargs): return ir.Intrinsic(text='return', source=kwargs['source']) loki-ecmwf-0.3.6/loki/frontend/preprocessing.py0000664000175000017500000002270315167130205021757 0ustar alastairalastair# (C) Copyright 2018- ECMWF. # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. """ Preprocessing utilities for frontends. """ from collections import defaultdict, OrderedDict from pathlib import Path import io import re import pcpp from codetiming import Timer from loki.logging import debug, detail from loki.config import config from loki.tools import as_tuple, gettempdir, filehash from loki.ir import Intrinsic, FindNodes from loki.frontend.util import OMNI, FP, REGEX __all__ = ['preprocess_cpp', 'sanitize_input', 'sanitize_registry', 'PPRule'] def preprocess_cpp(source, filepath=None, includes=None, defines=None): """ Invoke an external C-preprocessor to sanitize input files. Note that the global option ``LOKI_CPP_DUMP_FILES`` will cause the intermediate preprocessed source to be written to a temporary file in ``LOKI_TMP_DIR``. Parameters ---------- source : str Source string to preprocess via ``pcpp`` filepath : str or pathlib.Path Optional filepath name, which will be used to derive the filename should intermediate file dumping be enabled via the global config. 
includes : (list of) str Include paths for the C-preprocessor. defines : (list of) str Symbol definitions to add to the C-preprocessor. """ class _LokiCPreprocessor(pcpp.Preprocessor): def on_comment(self, tok): # pylint: disable=unused-argument # Pass through C-style comments return True def on_error(self, file, line, msg): # Redirect CPP error to our logger and increment return code debug(f'[Loki-CPP] {file}:{line: d} error: {msg}') self.return_code += 1 # Add include paths to PP pp = _LokiCPreprocessor() # Suppress line directives pp.line_directive = None for i in as_tuple(includes): pp.add_path(str(i)) # Add and sanitize defines to PP for d in as_tuple(defines): if '=' not in d: d += '=1' d = d.replace('=', ' ', 1) pp.define(d) # Parse source through preprocessor pp.parse(source) if config['cpp-dump-files']: if filepath is None: pp_path = Path(filehash(source, suffix='.cpp.f90')) else: pp_path = filepath.with_suffix('.cpp.f90') pp_path = gettempdir()/pp_path.name debug(f'[Loki] C-preprocessor, writing {str(pp_path)}') # Dump preprocessed source to file and read it with pp_path.open('w') as f: pp.write(f) with pp_path.open('r') as f: preprocessed = f.read() return preprocessed # Return the preprocessed string s = io.StringIO() pp.write(s) return s.getvalue() @Timer(logger=detail, text=lambda s: f'[Loki::Frontend] Executed sanitize_input in {s:.2f}s') def sanitize_input(source, frontend): """ Apply internal regex-based sanitisation rules to filter out known frontend incompatibilities. Note that this will create a record of all things stripped (``pp_info``), which will be used to re-insert the dropped source info when converting the parsed AST to our IR. The ``sanitize_registry`` (see below) holds pre-defined rules for each frontend. 
""" # Apply preprocessing rules and store meta-information pp_info = OrderedDict() for name, rule in sanitize_registry[frontend].items(): # Apply rule filter over source file rule.reset() new_source = '' for ll, line in enumerate(source.splitlines(keepends=True)): ll += 1 # Correct for Fortran counting new_source += rule.filter(line, lineno=ll) # Store met-information from rule pp_info[name] = rule.info source = new_source return source, pp_info def reinsert_convert_endian(ir, pp_info): """ Reinsert the CONVERT='BIG_ENDIAN' or CONVERT='LITTLE_ENDIAN' arguments into calls to OPEN. """ if pp_info: for intr in FindNodes(Intrinsic).visit(ir): match = pp_info.get(intr.source.lines[0], [None])[0] if match is not None: source = intr.source assert source is not None text = match['ws'] + match['pre'] + match['convert'] + match['post'] if match['post'].rstrip().endswith('&'): cont_line_index = source.string.find(match['post']) + len(match['post']) text += source.string[cont_line_index:].rstrip() source.string = text intr._update(text=text, source=source) return ir def reinsert_open_newunit(ir, pp_info): """ Reinsert the NEWUNIT=... arguments into calls to OPEN. """ if pp_info: for intr in FindNodes(Intrinsic).visit(ir): match = pp_info.get(intr.source.lines[0], [None])[0] if match is not None: source = intr.source assert source is not None text = match['ws'] + match['open'] + match['args1'] + (match['delim'] or '') text += match['newunit_key'] + match['newunit_val'] + match['args2'] if match['args2'].rstrip().endswith('&'): cont_line_index = source.string.find(match['args2']) + len(match['args2']) text += source.string[cont_line_index:].rstrip() source.string = text intr._update(text=text, source=source) return ir class PPRule: """ A preprocessing rule that defines and applies a source replacement and collects associated meta-data. 
""" _empty_pattern = re.compile('') def __init__(self, match, replace, postprocess=None): self.match = match self.replace = replace self._postprocess = postprocess self._info = defaultdict(list) def reset(self): self._info = defaultdict(list) def filter(self, line, lineno): """ Filter a source line by matching the given rule and storing meta-content. """ if isinstance(self.match, type(self._empty_pattern)): # Apply a regex pattern to the line and return 'all' for info in self.match.finditer(line): self._info[lineno] += [info.groupdict()] return self.match.sub(self.replace, line) # Apply a regular string replacement if self.match in line: self._info[lineno] += [(self.match, self.replace)] return line.replace(self.match, self.replace) @property def info(self): """ Meta-information that will be dumped alongside preprocessed source files to re-insert information into a fully parsed IR tree. """ return self._info def postprocess(self, ir, info): if self._postprocess is not None: return self._postprocess(ir, info) return ir sanitize_registry = { REGEX: { # Strip line annotations from Fypp preprocessor 'FYPP ANNOTATIONS': PPRule(match=re.compile(r'(# [1-9].*\".*\.(?:fypp|hypp)\"(?:\s+\d+)?\n)'), replace=''), }, OMNI: {}, FP: { # Remove various IBM directives 'IBM_DIRECTIVES': PPRule(match=re.compile(r'(@PROCESS.*\n)'), replace='\n'), # Enquote string CPP directives in Fortran source lines to make them string constants # Note: this is a bit tricky as we need to make sure that we don't replace it inside CPP # directives as this can produce invalid code 'STRING_PP_DIRECTIVES': PPRule( match=re.compile(( r'(?P^\s*#.*__(?:FILE|FILENAME|DATE|VERSION)__)|' # Match inside a directive r'(?P__(?:FILE|FILENAME|DATE|VERSION)__)')), # Match elsewhere replace=lambda m: m['pp'] or f'"{m["else"]}"'), # Replace integer CPP directives by 0 'INTEGER_PP_DIRECTIVES': PPRule(match='__LINE__', replace='0'), # Replace CONVERT argument in OPEN calls 'CONVERT_ENDIAN': PPRule( 
match=re.compile((r'(?P^\s*)(?P
OPEN\s*\(.*?)'
                              r'(?P,?\s*CONVERT=[\'\"](?:BIG|LITTLE)_ENDIAN[\'\"]\s*)'
                              r'(?P.*?$)'), re.I),
            replace=r'\g\g
\g', postprocess=reinsert_convert_endian),

        # Replace NEWUNIT argument in OPEN calls
        'OPEN_NEWUNIT': PPRule(
            match=re.compile((r'(?P^\s*)(?POPEN\s*\()(?P.*?)(?P,)?'
                              r'(?P,?\s*NEWUNIT=)(?P.*?(?=,|\)|&))'
                              r'(?P.*?$)'), re.I),
            replace=lambda m: f'{m["ws"]}{m["open"]}{m["newunit_val"]}{m["delim"] or ""}' +
                              f'{m["args1"]}{m["args2"]}',
            postprocess=reinsert_open_newunit),

        # Strip line annotations from Fypp preprocessor
        'FYPP ANNOTATIONS': PPRule(match=re.compile(r'(# [1-9].*\".*\.(?:fypp|hypp)\"(?:\s+\d+)?\n)'), replace=''),
    }
}
"""
The frontend sanitization registry dict holds workaround rules for
Fortran features that cause bugs and failures in frontends. Each rule
is mostly a regular expression that removes certain strings and stores
them, so that they can be re-inserted into the IR by a callback.
"""
loki-ecmwf-0.3.6/loki/tests/0000775000175000017500000000000015167130205016041 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/__init__.py0000664000175000017500000000057015167130205020154 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/tests/test_sourcefile.py0000664000175000017500000003637315167130205021626 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
from subprocess import CalledProcessError
import pytest
import numpy as np

from loki import (
    Sourcefile, FindNodes, PreprocessorDirective, Intrinsic,
    Assignment, Import, fgen, ProcedureType, ProcedureSymbol,
    StatementFunction, Comment, CommentBlock, RawSource, Scalar
)
from loki.jit_build import jit_compile, clean_test
from loki.frontend import available_frontends, OMNI, FP, REGEX


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test module."""
    return Path(__file__).parents[0]


@pytest.mark.parametrize('frontend', available_frontends())
def test_sourcefile_properties(here, frontend, tmp_path):
    """
    Check procedure discovery via ``subroutines`` and ``all_subroutines``.

    ``subroutines`` must expose exactly ``routine_a``, ``routine_b`` and
    ``function_d``. ``all_subroutines`` additionally picks up
    ``module_routine`` and ``module_function``. The contained routine
    ``contained_c`` must appear in neither.
    """
    # pylint: disable=no-member
    source = Sourcefile.from_file(here/'sources/sourcefile.f90', frontend=frontend, xmods=[tmp_path])
    assert len(source.subroutines) == 3
    assert len(source.all_subroutines) == 5

    top_level = ['routine_a', 'routine_b', 'function_d']
    everything = top_level + ['module_routine', 'module_function']
    contained = ['contained_c']

    def count_in(routines, names):
        # Number of routines whose name is one of ``names``
        return sum(r.name in names for r in routines)

    assert count_in(source.subroutines, top_level) == 3
    assert count_in(source.subroutines, everything) == 3
    assert count_in(source.subroutines, contained) == 0

    assert count_in(source.all_subroutines, top_level) == 3
    assert count_in(source.all_subroutines, everything) == 5
    assert count_in(source.all_subroutines, contained) == 0


@pytest.mark.parametrize('frontend', available_frontends())
def test_sourcefile_from_source(frontend, tmp_path):
    """
    Exercise the ``from_source`` constructor of ``Sourcefile``.

    Checks the ordering of discovered procedures, that contained routines
    are excluded, and that all four file-level comments are preserved.
    """
    # pylint: disable=no-member

    fcode = """
! Some comment
subroutine routine_a
  integer a
  a = 1
end subroutine routine_a

! Some comment
module some_module
contains
  subroutine module_routine
    integer m
    m = 2
  end subroutine module_routine
  function module_function(n)
    integer n
    n = 3
  end function module_function
end module some_module
! Other comment

subroutine routine_b
  integer b
  b = 4
contains
  subroutine contained_c
    integer c
    c = 5
  end subroutine contained_c
end subroutine routine_b
! Other comment

function function_d(d)
  integer d
  d = 6
end function function_d
""".strip()
    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    assert len(source.subroutines) == 3
    assert len(source.all_subroutines) == 5

    expected_top = ['routine_a', 'routine_b', 'function_d']
    expected_all = expected_top + ['module_routine', 'module_function']

    # Compare case-insensitively since frontends may normalise name case
    top_names = [r.name.lower() for r in source.subroutines]
    all_names = [r.name.lower() for r in source.all_subroutines]
    assert top_names == expected_top
    assert all_names == expected_all
    assert 'contained_c' not in top_names
    assert 'contained_c' not in all_names

    comments = FindNodes((Comment, CommentBlock)).visit(source.ir)
    assert len(comments) == 4
    assert all(c.text.strip() in ('! Some comment', '! Other comment') for c in comments)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Files are preprocessed')]))
def test_sourcefile_pp_macros(here, frontend):
    """
    Preprocessor directives of a ``.F90`` file must show up as
    ``PreprocessorDirective`` nodes in the routine's IR.
    """
    sourcefile = Sourcefile.from_file(here/'sources/sourcefile_pp_macros.F90', frontend=frontend)
    directives = FindNodes(PreprocessorDirective).visit(sourcefile['routine_pp_macros'].ir)
    assert len(directives) == 8
    for directive in directives:
        assert directive.text.startswith('#')


@pytest.mark.parametrize('frontend', available_frontends(xfail=[
    (OMNI, 'Files are preprocessed')
]))
def test_sourcefile_pp_directives(here, frontend):
    """
    Check how string and integer PP constants inside Fortran statements
    are handled, and that a ``#define`` stays a directive node.
    """
    sourcefile = Sourcefile.from_file(here/'sources/sourcefile_pp_directives.F90', frontend=frontend)
    routine = sourcefile['routine_pp_directives']

    # These checks are deliberately loose: the original text is not restored,
    # the PP constants are merely turned into strings
    pp_nodes = FindNodes(PreprocessorDirective).visit(routine.body)
    assert len(pp_nodes) == 1
    assert pp_nodes[0].text == '#define __FILENAME__ __FILE__'

    intrinsic_texts = [node.text for node in FindNodes(Intrinsic).visit(routine.body)]
    assert '__FILENAME__' in intrinsic_texts[0] and '__DATE__' in intrinsic_texts[0]
    assert '__FILE__' in intrinsic_texts[1] and '__VERSION__' in intrinsic_texts[1]

    assignments = FindNodes(Assignment).visit(routine.body)
    assert len(assignments) == 1
    # Presumably ``__LINE__`` has been replaced by ``0`` during sanitisation
    assert fgen(assignments[0]) == 'y = 0*5 + 0'


@pytest.mark.parametrize('frontend', available_frontends())
def test_sourcefile_pp_include(here, frontend):
    """
    Parse a file that includes a header from ``here/'include'``.

    Checks the single assignment and, for frontends other than OMNI,
    the C-style import of ``some_header.h``.
    """
    sourcefile = Sourcefile.from_file(
        here/'sources/sourcefile_pp_include.F90', frontend=frontend, includes=[here/'include']
    )
    routine = sourcefile['routine_pp_include']

    assignments = FindNodes(Assignment).visit(routine.body)
    assert len(assignments) == 1
    # OMNI resolves the statement function, so the call is replaced
    expected = 'c = real(a + b, kind=4)' if frontend == OMNI else 'c = add(a, b)'
    assert fgen(assignments[0]) == expected

    if frontend == OMNI:
        # OMNI resolves the import in the frontend, nothing left to check
        return

    imports = FindNodes(Import).visit([routine.spec, routine.body])
    assert len(imports) == 1
    assert imports[0].c_import
    assert imports[0].module == 'some_header.h'


@pytest.mark.parametrize('frontend', available_frontends())
def test_sourcefile_cpp_preprocessing(here, frontend):
    """
    Run the external C-preprocessor on a source file, once plainly and once
    with ``FLAG_SMALL`` defined, and check the IR carries no directives.
    """
    filepath = here/'sources/sourcefile_cpp_preprocessing.F90'

    def load(**kwargs):
        # Preprocess and parse the file, returning the routine under test
        sourcefile = Sourcefile.from_file(filepath, preprocess=True, frontend=frontend, **kwargs)
        return sourcefile['sourcefile_external_preprocessing']

    routine = load()
    directives = FindNodes(PreprocessorDirective).visit(routine.ir)

    if frontend != OMNI:
        # The C-style include is kept as an Import (OMNI skips it in the frontend)
        c_imports = FindNodes(Import).visit([routine.spec, routine.body])
        assert len(c_imports) == 1
        assert c_imports[0].c_import
        assert c_imports[0].module == 'some_header.h'

    assert len(directives) == 0
    assert 'b = 123' in fgen(routine)

    # The extra ``define`` must reach the preprocessor and change the result
    routine = load(defines='FLAG_SMALL')
    directives = FindNodes(PreprocessorDirective).visit(routine.ir)

    assert len(directives) == 0
    assert 'b = 6' in fgen(routine)


@pytest.mark.parametrize('frontend', available_frontends())
def test_sourcefile_cpp_stmt_func(here, frontend, tmp_path):
    """
    Test the correct identification of statement functions
    after inlining by preprocessor.

    Besides inspecting the IR, the source is JIT-compiled and the
    ``cpp_stmt_func`` routine is executed to check its numerical output.
    Build artefacts are cleaned up even if that check fails.
    """
    sourcepath = here/'sources'
    filepath = sourcepath/'sourcefile_cpp_stmt_func.F90'

    source = Sourcefile.from_file(filepath, includes=sourcepath, preprocess=True, frontend=frontend, xmods=[tmp_path])
    module = source['cpp_stmt_func_mod']
    # Frontend-specific module name, presumably to keep the compiled objects
    # of different parametrizations apart
    module.name += f'_{frontend!s}'

    # OMNI inlines statement functions, so we can't check the representation
    if frontend != OMNI:
        routine = source['cpp_stmt_func']
        stmt_func_decls = FindNodes(StatementFunction).visit(routine.spec)
        assert len(stmt_func_decls) == 4

        for decl in stmt_func_decls:
            # Each statement function is a procedure symbol whose type
            # links back to its declaration node
            var = routine.variable_map[str(decl.variable)]
            assert isinstance(var, ProcedureSymbol)
            assert isinstance(var.type.dtype, ProcedureType)
            assert var.type.dtype.procedure is decl
            assert decl.source is not None

    # Generate code and compile
    filepath = tmp_path/f'{module.name}.f90'
    mod = jit_compile(source, filepath=filepath, objname=module.name)

    try:
        # Verify it produces correct results
        klon, klev = 10, 5
        kidia, kfdia = 1, klon
        zfoeew = np.zeros((klon, klev), order='F')
        mod.cpp_stmt_func(kidia, kfdia, klon, klev, zfoeew)
        assert (zfoeew == 0.25).all()
    finally:
        # Run cleanup regardless of whether the check above passed
        clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_sourcefile_lazy_construction(frontend, tmp_path):
    """
    Check delayed ("lazy") parsing of sourcefile content.

    A first pass with the REGEX frontend leaves ``RawSource`` placeholders
    in the IR. ``make_complete`` then runs the requested frontend, and the
    program unit objects handed out earlier must be the very same objects
    afterwards, now fully populated.
    """
    fcode = """
! A comment to test
subroutine routine_a
integer a
a = 1
end subroutine routine_a

module some_module
contains
subroutine module_routine
integer m
m = 2
end subroutine module_routine
function module_function(n)
integer n
n = 3
end function module_function
end module some_module

#ifndef SOME_PREPROC_VAR
subroutine routine_b
integer b
b = 4
contains
subroutine contained_c
integer c
c = 5
end subroutine contained_c
end subroutine routine_b
#endif

function function_d(d)
integer d
d = 6
end function function_d
    """.strip()
    source = Sourcefile.from_source(fcode, frontend=REGEX)
    assert len(source.subroutines) == 3
    assert len(source.all_subroutines) == 5

    # Keep handles to the program units created by the incomplete parse
    some_module = source['some_module']
    routine_b = source['routine_b']
    module_routine = some_module['module_routine']
    function_d = source['function_d']
    # Dummy arguments are not known before the full parse
    assert function_d.arguments == ()

    # Only an incomplete parse tree exists so far
    assert source._incomplete
    assert len(FindNodes(RawSource).visit(source.ir)) == 5
    assert len(FindNodes(RawSource).visit(source['routine_a'].ir)) == 1

    # Trigger the full parse
    try:
        source.make_complete(frontend=frontend, xmods=[tmp_path])
    except CalledProcessError as ex:
        if frontend == OMNI and ex.returncode == -11:
            pytest.xfail('F_Front segfault is a known issue on some platforms')
        raise
    assert not source._incomplete

    # No RawSource placeholder may remain
    assert not FindNodes(RawSource).visit(source.ir)
    # FP additionally treats some newlines as comments
    expected_comments = 2 if frontend == FP else 1
    assert len(FindNodes(Comment).visit(source.ir)) == expected_comments
    # OMNI keeps no preprocessor directives; the others keep #ifndef/#endif
    expected_directives = 0 if frontend == OMNI else 2
    assert len(FindNodes(PreprocessorDirective).visit(source.ir)) == expected_directives

    for sub in source.all_subroutines:
        assert not FindNodes(RawSource).visit(sub.ir)
        assert len(FindNodes(Assignment).visit(sub.ir)) == 1

    # Object identity is preserved across the completion step
    assert routine_b is source['routine_b']
    assert some_module is source['some_module']
    assert module_routine is source['some_module']['module_routine']
    assert function_d.arguments == ('d',)
    assert isinstance(function_d.arguments[0], Scalar)


@pytest.mark.parametrize('frontend', available_frontends())
def test_sourcefile_lazy_comments(frontend):
    """
    Lazy construction must cope with comments at source-file level,
    i.e. comments that sit outside any program unit.
    """
    fcode = """
! Comment outside
subroutine myroutine
    ! Comment inside
end subroutine myroutine
! Other comment outside
    """.strip()
    source = Sourcefile.from_source(fcode, frontend=REGEX)

    # Before completion the comments are opaque RawSource chunks
    assert isinstance(source.ir.body[0], RawSource)
    assert isinstance(source.ir.body[2], RawSource)

    myroutine = source['myroutine']
    assert isinstance(myroutine.spec.body[0], RawSource)

    source.make_complete(frontend=frontend)

    # After completion they are proper Comment nodes
    assert isinstance(source.ir.body[0], Comment)
    assert isinstance(source.ir.body[2], Comment)
    # OMNI places the inner comment in the body, other frontends in the docstring
    inner = myroutine.body.body[0] if frontend == OMNI else myroutine.docstring[0]
    assert isinstance(inner, Comment)

    fortran = source.to_fortran()
    for text in ('! Comment outside', '! Comment inside', '! Other comment outside'):
        assert text in fortran


@pytest.mark.parametrize('frontend', available_frontends(include_regex=True))
def test_sourcefile_clone(frontend, tmp_path):
    """
    Cloning a source file must yield independent copies.

    Renaming units in one clone must not affect the original or the
    other clone. For a fully parsed file, editing a node of the original
    must not leak into the clones. For a still-incomplete file, the clones
    must be incomplete too and share the parser classes.
    """
    fcode = """
! Comment outside
module my_mod
  implicit none
  contains
    subroutine my_routine
      implicit none
    end subroutine my_routine
end module my_mod

subroutine other_routine
  use my_mod, only: my_routine
  implicit none
  call my_routine()
end subroutine other_routine
    """.strip()
    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Clone the source file twice
    new_source = source.clone()
    new_new_source = source.clone()

    # Apply some changes that should only be affecting each clone
    new_source['other_routine'].name = 'new_name'
    new_new_source['my_mod']['my_routine'].name = 'new_mod_routine'

    trio = (source, new_source, new_new_source)

    def assert_membership(name, expected, container=lambda sf: sf):
        # ``expected`` gives presence of ``name`` in (source, new_source, new_new_source)
        for sourcefile, present in zip(trio, expected):
            assert (name in container(sourcefile)) == present

    assert_membership('other_routine', (True, False, True))
    assert_membership('new_name', (False, True, False))
    assert_membership('my_mod', (True, True, True))

    in_module = lambda sf: sf['my_mod']
    assert_membership('my_routine', (True, True, False), in_module)
    assert_membership('new_mod_routine', (False, False, True), in_module)

    if source._incomplete:
        # Still lazily parsed: clones stay incomplete and share parser classes
        assert new_source._incomplete
        assert new_new_source._incomplete

        assert source['other_routine']._incomplete
        assert new_source['new_name']._incomplete
        assert new_new_source['other_routine']._incomplete

        routine_classes = source['other_routine']._parser_classes
        assert new_source['new_name']._parser_classes == routine_classes
        assert new_new_source['other_routine']._parser_classes == routine_classes

        mod, new_mod, new_new_mod = (sf['my_mod'] for sf in trio)
        for module in (mod, new_mod, new_new_mod):
            assert module._incomplete

        assert new_mod._parser_classes == mod._parser_classes
        assert new_new_mod._parser_classes == mod._parser_classes

        assert mod['my_routine']._incomplete
        assert new_mod['my_routine']._incomplete
        assert new_new_mod['new_mod_routine']._incomplete

        member_classes = mod['my_routine']._parser_classes
        assert new_mod['my_routine']._parser_classes == member_classes
        assert new_new_mod['new_mod_routine']._parser_classes == member_classes
    else:
        # Fully parsed: mutating the original's comment must not touch the clones
        assert isinstance(source.ir.body[0], Comment)
        comment_text = source.ir.body[0].text
        new_comment_text = comment_text + ' some more text'
        source.ir.body[0]._update(text=new_comment_text)

        assert source.ir.body[0].text == new_comment_text
        assert new_source.ir.body[0].text == comment_text
        assert new_new_source.ir.body[0].text == comment_text
loki-ecmwf-0.3.6/loki/tests/test_interfaces.py0000664000175000017500000003767615167130205021620 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki import (
    Module, Subroutine, FindNodes, Interface, Import, fgen,
    ProcedureSymbol, ProcedureType
)
from loki.frontend import available_frontends, OMNI, REGEX


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test module."""
    return Path(__file__).parents[0]


@pytest.mark.parametrize('frontend', available_frontends(include_regex=True))
def test_interface_spec(frontend, tmp_path):
    """
    An ``interface`` block in a module spec becomes a single, non-abstract
    ``Interface`` node whose body is the declared subroutine.
    """
    fcode = """
module interface_spec_mod
    interface
        subroutine sub(a, b)
            integer, intent(in) :: a
            integer, intent(out) :: b
        end subroutine sub
    end interface
end module interface_spec_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    interfaces = FindNodes(Interface).visit(module.spec)
    assert len(interfaces) == 1
    (interface,) = interfaces

    # Declared symbol and lookup helpers
    assert interface.abstract is False
    assert interface.symbols == ('sub',)
    assert 'sub' in interface
    assert interface.symbol_map == {'sub': interface.symbols[0]}

    # The body holds exactly the declared subroutine
    assert len(interface.body) == 1
    assert isinstance(interface.body[0], Subroutine)

    # Sanity check of the generated Fortran
    fortran = module.to_fortran().lower()
    for snippet in ('interface', 'end interface', 'subroutine sub'):
        assert snippet in fortran

    assert repr(interface) == 'Interface:: sub'


@pytest.mark.parametrize('frontend', available_frontends(include_regex=True))
def test_interface_module_integration(frontend, tmp_path):
    """
    An abstract interface declared in a module must be reachable through
    the module's interface and symbol lookup properties.
    """
    fcode = """
module interface_module_integration_mod
    abstract interface
        subroutine sub(a, b)
            integer, intent(in) :: a
            integer, intent(out) :: b
        end subroutine sub
    end interface
end module interface_module_integration_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    assert len(module.interfaces) == 1
    interface = module.interfaces[0]
    assert isinstance(interface, Interface)

    assert interface.symbols == ('sub',)
    sub_symbol = interface.symbols[0]

    # Module-level views of the declared symbol
    assert module.interface_symbols == ('sub',)
    assert module.interface_map['sub'] is interface
    assert module.interface_symbol_map == {'sub': sub_symbol}
    assert 'sub' in module.symbols
    assert module.symbol_map['sub'] == sub_symbol

    # Sanity check of the generated Fortran
    fortran = module.to_fortran().lower()
    for snippet in ('abstract interface', 'subroutine sub'):
        assert snippet in fortran


@pytest.mark.parametrize('frontend', available_frontends())
def test_interface_subroutine_integration(frontend):
    """
    Test correct integration of interfaces into subroutines

    The dummy procedure ``PROC`` is described by an interface block in the
    routine spec; it must show up both as an interface symbol and as a
    routine argument whose procedure type points at the interface body.
    """
    fcode = """
subroutine interface_subroutine_integration(X, Y, N, PROC)
    INTEGER, INTENT(IN) :: X(:,:), N
    INTEGER, INTENT(OUT) :: Y(:)
    INTERFACE
        SUBROUTINE PROC(A, B)
            INTEGER, INTENT(IN) :: A(:)
            INTEGER, INTENT(OUT) :: B
        END SUBROUTINE PROC
    END INTERFACE
    INTEGER :: I

    DO I=1,N
        CALL PROC(X(:, I), Y(I))
    END DO
end subroutine interface_subroutine_integration
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # A single interface block is expected in the routine spec
    assert len(routine.interfaces) == 1
    proc_intf = routine.interfaces[0]
    assert isinstance(proc_intf, Interface)

    # Symbol lookups through the interface and through the routine agree
    assert proc_intf.symbols == ('proc',)
    assert routine.interface_symbols == ('proc',)
    assert routine.interface_map['proc'] is proc_intf
    proc_symbol = proc_intf.symbols[0]
    assert routine.interface_symbol_map == {'proc': proc_symbol}
    assert 'proc' in routine.symbols
    assert routine.symbol_map['proc'] == proc_symbol

    # PROC is also a dummy argument whose type links back to the interface body
    assert 'proc' in routine.arguments
    assert any(argname.lower() == 'proc' for argname in routine.argnames)
    assert routine.symbol_map['proc'].type.dtype.procedure is proc_intf.body[0]

    # Regenerated Fortran still contains the interface block
    generated = routine.to_fortran().lower()
    for fragment in ('interface', 'subroutine proc'):
        assert fragment in generated


@pytest.mark.parametrize('frontend', available_frontends())
def test_interface_import(frontend, tmp_path):
    """
    Test correct representation of ``IMPORT`` statements in interfaces

    Based on F2008, Note 12.5: an ``IMPORT`` statement makes the host's
    opaque type ``T`` visible inside the interface body of the dummy
    procedure ``MONITOR``. That procedure needs an explicit interface because
    it takes an assumed-shape array argument of that type.
    """
    fcode = """
module interface_import_mod
    TYPE T
        PRIVATE ! T is an opaque type
    END TYPE
CONTAINS
    SUBROUTINE PROCESS(X,Y,RESULT,MONITOR)
        TYPE(T),INTENT(IN) :: X(:,:),Y(:,:)
        TYPE(T),INTENT(OUT) :: RESULT(:,:)
        INTERFACE
            SUBROUTINE MONITOR(ITERATION_NUMBER,CURRENT_ESTIMATE)
                IMPORT T
                INTEGER,INTENT(IN) :: ITERATION_NUMBER
                TYPE(T),INTENT(IN) :: CURRENT_ESTIMATE(:,:)
            END SUBROUTINE
        END INTERFACE
    END SUBROUTINE
end module interface_import_mod
    """.strip()

    mod = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    monitor_intf = mod['process'].interface_map['monitor']

    # The interface wraps one procedure whose spec holds exactly one IMPORT
    assert len(monitor_intf.body) == 1
    import_nodes = FindNodes(Import).visit(monitor_intf.body[0].spec)
    assert len(import_nodes) == 1

    # The IMPORT statement round-trips through code generation
    assert fgen(import_nodes[0]).lower() == 'import t'
    assert 'import t' in fgen(monitor_intf).lower()


@pytest.mark.parametrize('frontend', available_frontends(include_regex=True))
def test_interface_multiple_routines(frontend, tmp_path):
    """
    Test interfaces with multiple subroutine/function declarations
    in the interface block

    Based on F2008, Note 12.4. Every declared procedure name must be
    resolvable to its interface, and all of them must appear as module
    symbols.
    """
    fcode = """
module interface_multiple_routines_mod
    INTERFACE
        SUBROUTINE EXT1 (X, Y, Z)
            REAL, DIMENSION (100, 100) :: X, Y, Z
        END SUBROUTINE EXT1
        SUBROUTINE EXT2 (X, Z)
            REAL X
            COMPLEX (KIND = 4) Z (2000)
        END SUBROUTINE EXT2
        FUNCTION EXT3 (P, Q)
            LOGICAL EXT3
            INTEGER P (1000)
            LOGICAL Q (1000)
        END FUNCTION EXT3
    END INTERFACE
end module interface_multiple_routines_mod
    """.strip()

    mod = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    declared_names = ('ext1', 'ext2', 'ext3')

    if frontend == OMNI:
        # OMNI splits the single interface block into one block per procedure,
        # each of which must be found under its own name
        assert len(mod.interfaces) == 3
        for single_intf in mod.interfaces:
            assert len(single_intf.symbols) == 1
            assert mod.interface_map[str(single_intf.symbols[0])] is single_intf
    else:
        # All other frontends keep one interface block declaring every name
        assert len(mod.interfaces) == 1
        shared_intf = mod.interfaces[0]
        assert all(mod.interface_map[name] is shared_intf for name in declared_names)
        assert shared_intf.symbols == declared_names

    assert all(name in mod.symbols for name in declared_names)

    # Regenerated Fortran contains each procedure declaration
    generated = mod.to_fortran().lower()
    for fragment in ('subroutine ext1', 'subroutine ext2', 'function ext3'):
        assert fragment in generated


@pytest.mark.parametrize('frontend', available_frontends(include_regex=True))
def test_interface_generic_spec(frontend, tmp_path):
    """
    Test interfaces with a generic identifier

    Based on F2008, Note 12.6: the generic name ``SWITCH`` groups three
    specific procedures. The generic interface must carry a generic
    procedure symbol as its spec.
    """
    fcode = """
module interface_generic_spec_mod
    IMPLICIT NONE
    INTERFACE SWITCH
        SUBROUTINE INT_SWITCH (X, Y)
        INTEGER, INTENT (INOUT) :: X, Y
        END SUBROUTINE INT_SWITCH
        SUBROUTINE REAL_SWITCH (X, Y)
            REAL, INTENT (INOUT) :: X, Y
        END SUBROUTINE REAL_SWITCH
        SUBROUTINE COMPLEX_SWITCH (X, Y)
            COMPLEX, INTENT (INOUT) :: X, Y
        END SUBROUTINE COMPLEX_SWITCH
    END INTERFACE SWITCH
end module interface_generic_spec_mod
    """.strip()

    mod = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    all_names = ('switch', 'int_switch', 'real_switch', 'complex_switch')

    # OMNI separates the specific subroutine interfaces from the generic one,
    # so it yields three extra interface blocks ahead of the generic block
    expected_count = 4 if frontend == OMNI else 1
    assert len(mod.interfaces) == expected_count

    *specific_intfs, generic_intf = mod.interfaces
    assert set(generic_intf.symbols) == set(all_names)

    # Only non-empty for OMNI: the split-off blocks have no generic spec
    for specific in specific_intfs:
        assert specific.spec is None

    # The generic interface itself
    generic_spec = generic_intf.spec
    assert isinstance(generic_spec, ProcedureSymbol)
    assert generic_spec.scope is mod
    assert generic_spec == 'switch'
    assert generic_spec.type.dtype.is_generic is True
    assert 'INTERFACE SWITCH' in fgen(generic_intf).upper()
    assert repr(generic_intf).upper() == 'INTERFACE SWITCH:: SWITCH, INT_SWITCH, REAL_SWITCH, COMPLEX_SWITCH'

    assert all(name in mod.symbols for name in all_names)


@pytest.mark.parametrize('frontend', available_frontends(include_regex=True))
def test_interface_operator_module_procedure(frontend, tmp_path):
    """
    Test interfaces that declare generic operators and refer to module procedures

    The module defines a generic ``ASSIGNMENT(=)`` interface listing two
    ``MODULE PROCEDURE`` entries and a generic ``OPERATOR(.EQV.)`` interface
    listing one plain ``PROCEDURE``. A second module then imports both
    generic names so the test can check that their type information carries
    over when the first module is supplied as a definition.
    """
    fcode = """
MODULE SPECTRAL_FIELDS_MOD
IMPLICIT NONE
PRIVATE
PUBLIC SPECTRAL_FIELD, ASSIGNMENT(=), OPERATOR(.EQV.)

! Trimmed-down version !
TYPE SPECTRAL_FIELD
    REAL, ALLOCATABLE :: SP2D(:,:)
    INTEGER :: NS2D
    INTEGER :: NSPEC2
END TYPE SPECTRAL_FIELD

INTERFACE ASSIGNMENT (=)
    MODULE PROCEDURE ASSIGN_SCALAR_SP, ASSIGN_SP_AR
END INTERFACE

INTERFACE OPERATOR (.EQV.)
    PROCEDURE EQUIV_SPEC
END INTERFACE

CONTAINS

SUBROUTINE ASSIGN_SCALAR_SP(YDSP,PVAL)
    TYPE (SPECTRAL_FIELD), INTENT(INOUT) :: YDSP
    REAL, INTENT(IN) :: PVAL
    YDSP%SP2D(:,:)  =PVAL
END SUBROUTINE ASSIGN_SCALAR_SP

SUBROUTINE ASSIGN_SP_AR(PFLAT,YDSP)
    REAL, INTENT(OUT) :: PFLAT(:)
    TYPE (SPECTRAL_FIELD), INTENT(IN) :: YDSP
    INTEGER :: I2D,ISHAPE2D(1)

    I2D=YDSP%NS2D*YDSP%NSPEC2
    ISHAPE2D(1)=I2D
    PFLAT(    1:    I2D)=RESHAPE(YDSP%SP2D(:,:)  ,ISHAPE2D)
END SUBROUTINE ASSIGN_SP_AR

LOGICAL FUNCTION EQUIV_SPEC(YDSP1,YDSP2)
    TYPE(SPECTRAL_FIELD), INTENT(IN) :: YDSP1
    TYPE(SPECTRAL_FIELD), INTENT(IN) :: YDSP2
    LOGICAL :: LL
    INTEGER :: JF, JM
    ! Modified for simplicity!

    LL = .TRUE.
    LL = LL .AND. (YDSP1%NS2D ==YDSP2%NS2D)
    LL = LL .AND. (YDSP1%NSPEC2 ==YDSP2%NSPEC2)
    IF (LL) THEN
        DO JF=1,YDSP1%NS2D
            DO JM=1,YDSP1%NSPEC2
                LL = LL .AND. (YDSP1%SP2D(JF, JM)==YDSP2%SP2D(JF, JM))
            ENDDO
        ENDDO
    ENDIF

    EQUIV_SPEC=LL
END FUNCTION EQUIV_SPEC
END MODULE SPECTRAL_FIELDS_MOD
    """.strip()

    mod = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    assert len(mod.interfaces) == 2

    # --- Generic assignment interface ---
    assign_intf = mod.interface_map['assignment(=)']
    assert assign_intf.spec == 'assignment(=)'
    assert set(assign_intf.symbols) == {'assignment(=)', 'assign_scalar_sp', 'assign_sp_ar'}

    assign_symbols = {s.name.lower(): s for s in assign_intf.symbols}
    assert assign_symbols['assignment(=)'].type.dtype.is_generic is True
    # Each specific procedure links to the module routine and is not generic
    for proc_name in ('assign_scalar_sp', 'assign_sp_ar'):
        proc_dtype = assign_symbols[proc_name].type.dtype
        assert proc_dtype.procedure is mod[proc_name]
        assert proc_dtype.is_generic is False

    # OMNI emits one procedure declaration per listed name
    expected_decl_count = 2 if frontend == OMNI else 1
    assert len(assign_intf.body) == expected_decl_count
    assert assign_intf.body[0].module is True

    # --- Generic operator interface ---
    op_intf = mod.interface_map['operator(.eqv.)']
    assert op_intf.spec == 'operator(.eqv.)'
    assert set(op_intf.symbols) == {'operator(.eqv.)', 'equiv_spec'}

    op_symbols = {s.name.lower(): s for s in op_intf.symbols}
    assert op_symbols['operator(.eqv.)'].type.dtype.is_generic is True
    equiv_dtype = op_symbols['equiv_spec'].type.dtype
    assert equiv_dtype.procedure is mod['equiv_spec']
    assert equiv_dtype.is_generic is False

    assert len(op_intf.body) == 1
    # OMNI is exempt here (and in the fgen check below), presumably because it
    # does not keep ``PROCEDURE`` distinct from ``MODULE PROCEDURE``
    if frontend != OMNI:
        assert op_intf.body[0].module is False

    # --- Code generation ---
    assign_code = fgen(assign_intf).lower().strip()
    assert assign_code.startswith('interface assignment(=)')
    assert assign_code.endswith('end interface assignment(=)')
    assert 'module procedure' in assign_code

    op_code = fgen(op_intf).lower().strip()
    assert op_code.startswith('interface operator(.eqv.)')
    assert op_code.endswith('end interface operator(.eqv.)')
    if frontend != OMNI:
        assert 'module' not in op_code

    # --- Importing the generic names into another module ---
    other_code = """
module use_spectral_fields_mod
    use spectral_fields_mod, only: assignment(=), operator(.eqv.)
end module use_spectral_fields_mod
    """.strip()

    other_mod = Module.from_source(other_code, frontend=frontend, definitions=[mod], xmods=[tmp_path])
    assert set(other_mod.symbols) == {'assignment(=)', 'operator(.eqv.)'}
    assert other_mod.imported_symbols == ('assignment(=)', 'operator(.eqv.)')

    imported = [
        other_mod.imported_symbol_map[generic_name]
        for generic_name in ('assignment(=)', 'operator(.eqv.)')
    ]
    for imported_sym in imported:
        assert imported_sym.type.imported is True

    # REGEX frontend doesn't use definitions and therefore doesn't import types
    if frontend != REGEX:
        for imported_sym in imported:
            assert isinstance(imported_sym, ProcedureSymbol)
            assert isinstance(imported_sym.type.dtype, ProcedureType)
            assert imported_sym.type.dtype.is_generic is True
            assert imported_sym.type.module is mod

    # The USE statement survives code generation unchanged
    use_line = other_code.splitlines()[1].strip()
    assert use_line in fgen(other_mod).lower()


@pytest.mark.parametrize('frontend', available_frontends(include_regex=True))
def test_interface_procedure_pointer(frontend, tmp_path):
    """
    Test interfaces whose procedures reference an abstract interface

    ``SIM_FUNC`` is declared in an abstract interface. ``SUB2`` takes a dummy
    procedure ``P`` declared as ``PROCEDURE(SIM_FUNC)``. Each interface
    symbol must link to its procedure body, and (except for REGEX) ``P``
    must carry a procedure type.
    """
    fcode = """
module my_interface_mod
implicit none
ABSTRACT INTERFACE
  FUNCTION SIM_FUNC (X)
    REAL, INTENT (IN) :: X
    REAL :: SIM_FUNC
  END FUNCTION SIM_FUNC
END INTERFACE

INTERFACE
  SUBROUTINE SUB2 (X, P)
    REAL, INTENT (INOUT) :: X
    PROCEDURE(SIM_FUNC) :: P
  END SUBROUTINE SUB2
END INTERFACE
end module my_interface_mod
    """.strip()

    mod = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Abstract interface: symbol links to the function body it declares
    sim_func_intf = mod.interface_map['sim_func']
    assert sim_func_intf.abstract
    assert sim_func_intf.symbols[0].type.dtype.procedure is sim_func_intf.body[0]
    assert repr(sim_func_intf).upper() == 'ABSTRACT INTERFACE:: SIM_FUNC'

    # Regular interface: symbol links to the subroutine body it declares
    sub2_intf = mod.interface_map['sub2']
    sub2 = sub2_intf.body[0]
    assert sub2_intf.symbols[0].type.dtype.procedure is sub2

    # The argument-type check is not run for the REGEX frontend
    if frontend == REGEX:
        return
    p_arg = sub2.arguments[1]
    assert isinstance(p_arg.type.dtype, ProcedureType)
loki-ecmwf-0.3.6/loki/tests/test_interprocedural_analysis.py0000664000175000017500000001666515167130205024575 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Various tests for interprocedural analysis features in Loki
"""

import pytest

from loki import (
    Sourcefile, FindNodes, FindInlineCalls, CallStatement, IntLiteral
)
from loki.frontend import available_frontends


@pytest.mark.parametrize('frontend', available_frontends())
def test_ipa_call_statement_arg_iter(frontend, tmp_path):
    """
    Test that :any:`CallStatement.arg_iter` works as expected

    ``arg_iter`` pairs each dummy argument of the callee with the actual
    argument passed by the caller. Dummies must be scoped in the callee;
    actuals are either literals or scoped in the caller.
    """
    fcode_caller = """
subroutine caller
    use some_mod, only: callee
    implicit none
    integer :: arg1, arg3(10)
    real :: arg2
    call callee(arg1, arg2, arg3, 4)
end subroutine caller
    """.strip()

    fcode_callee = """
module some_mod
    implicit none
contains
    subroutine callee(var, VAR2, arr, val)
        integer, intent(inout) :: var
        real, intent(in) :: var2
        integer, intent(in) :: arr(:)
        integer, intent(in) :: val
    end subroutine callee
end module some_mod
    """.strip()

    # Parse the callee first so its definitions can be attached to the caller
    callee_source = Sourcefile.from_source(fcode_callee, frontend=frontend, xmods=[tmp_path])
    caller_source = Sourcefile.from_source(
        fcode_caller, frontend=frontend, xmods=[tmp_path], definitions=callee_source.definitions
    )
    callee, caller = callee_source['callee'], caller_source['caller']

    call_nodes = FindNodes(CallStatement).visit(caller.body)
    assert len(call_nodes) == 1
    call = call_nodes[0]

    expected_pairs = [('var', 'arg1'), ('var2', 'arg2'), ('arr(:)', 'arg3'), ('val', '4')]
    assert list(call.arg_iter()) == expected_pairs

    for dummy, actual in call.arg_iter():
        assert dummy.scope is callee
        assert isinstance(actual, IntLiteral) or actual.scope is caller


@pytest.mark.parametrize('frontend', available_frontends())
def test_ipa_call_statement_arg_iter_optional(frontend, tmp_path):
    """
    Test that :any:`CallStatement.arg_iter` works as expected with optional arguments

    The three calls pass arguments positionally, by a mix of position and
    keyword (out of declaration order), and positionally including both
    optional arguments. Pairs are expected in the order the caller passes
    them.
    """
    fcode_caller = """
subroutine caller
    use some_mod, only: callee
    implicit none
    integer :: arg1, arg3(10)
    real :: arg2
    call callee(arg1, arg2, arg3, 4)
    call callee(arg1, arg2, VAL=4, arr=arg3, opt2=1)
    call callee(arg1, arg2, arg3, 4, 1, 2)
end subroutine caller
    """.strip()

    fcode_callee = """
module some_mod
    implicit none
contains
    subroutine callee(var, VAR2, arr, val, OPT1, opt2)
        integer, intent(inout) :: var
        real, intent(in) :: var2
        integer, intent(in) :: arr(:)
        integer, intent(in) :: val
        integer, intent(in), optional :: OPT1, opt2
    end subroutine callee
end module some_mod
    """.strip()

    # Parse the callee first so its definitions can be attached to the caller
    callee_source = Sourcefile.from_source(fcode_callee, frontend=frontend, xmods=[tmp_path])
    caller_source = Sourcefile.from_source(
        fcode_caller, frontend=frontend, xmods=[tmp_path], definitions=callee_source.definitions
    )
    callee, caller = callee_source['callee'], caller_source['caller']

    calls = FindNodes(CallStatement).visit(caller.body)
    assert len(calls) == 3

    expected_per_call = [
        [('var', 'arg1'), ('var2', 'arg2'), ('arr(:)', 'arg3'), ('val', '4')],
        [('var', 'arg1'), ('var2', 'arg2'), ('val', '4'), ('arr(:)', 'arg3'), ('opt2', '1')],
        [('var', 'arg1'), ('var2', 'arg2'), ('arr(:)', 'arg3'), ('val', '4'), ('opt1', '1'), ('opt2', '2')],
    ]
    for call, expected_pairs in zip(calls, expected_per_call):
        assert list(call.arg_iter()) == expected_pairs

    for call in calls:
        for dummy, actual in call.arg_iter():
            assert dummy.scope is callee
            assert isinstance(actual, IntLiteral) or actual.scope is caller


@pytest.mark.parametrize('frontend', available_frontends())
def test_ipa_inline_call_arg_iter(frontend, tmp_path):
    """
    Test that :any:`InlineCall.arg_iter` works as expected

    Mirrors the ``CallStatement`` variant for a function invoked inside an
    expression: each dummy argument of the callee is paired with the actual
    argument from the call site.
    """
    fcode_caller = """
subroutine caller
    use some_mod, only: callee
    implicit none
    integer :: arg1, arg3(10)
    real :: arg2, ret
    ret = callee(arg1, arg2, arg3, 4)
end subroutine caller
    """.strip()

    fcode_callee = """
module some_mod
    implicit none
contains
    function callee(var, VAR2, arr, val)
        integer, intent(inout) :: var
        real, intent(in) :: var2
        integer, intent(in) :: arr(:)
        integer, intent(in) :: val
        real :: callee
    end function callee
end module some_mod
    """.strip()

    # Parse the callee first so its definitions can be attached to the caller
    callee_source = Sourcefile.from_source(fcode_callee, frontend=frontend, xmods=[tmp_path])
    caller_source = Sourcefile.from_source(
        fcode_caller, frontend=frontend, xmods=[tmp_path], definitions=callee_source.definitions
    )
    callee, caller = callee_source['callee'], caller_source['caller']

    inline_calls = list(FindInlineCalls().visit(caller.body))
    assert len(inline_calls) == 1
    call = inline_calls[0]

    expected_pairs = [('var', 'arg1'), ('var2', 'arg2'), ('arr(:)', 'arg3'), ('val', '4')]
    assert list(call.arg_iter()) == expected_pairs

    for dummy, actual in call.arg_iter():
        assert dummy.scope is callee
        assert isinstance(actual, IntLiteral) or actual.scope is caller


@pytest.mark.parametrize('frontend', available_frontends())
def test_ipa_inline_call_arg_iter_optional(frontend, tmp_path):
    """
    Test that :any:`InlineCall.arg_iter` works as expected with optional arguments

    The three inline calls cover the following cases:
    - purely positional arguments that omit both optional arguments;
    - a mix of positional and keyword arguments in non-declaration order that
      omits ``opt1``;
    - all six arguments passed positionally.
    Pairs are expected in the order the caller passes them. The callee
    declares ``opt1``/``OPT2`` as ``optional`` because the calls omit them;
    without the attribute the Fortran would be invalid.
    """
    fcode_caller = """
subroutine caller
    use some_mod, only: callee
    implicit none
    integer :: arg1, arg3(10)
    real :: arg2, ret
    ret = callee(arg1, arg2, arg3, 4)
    ret = ret + callee(arg1, ARR=arg3, var2=arg2, opt2=2, val=4)
    ret = ret + callee(arg1, arg2, arg3, 4, 1, 2)
end subroutine caller
    """.strip()

    fcode_callee = """
module some_mod
    implicit none
contains
    function callee(var, VAR2, arr, val, opt1, OPT2)
        integer, intent(inout) :: var
        real, intent(in) :: var2
        integer, intent(in) :: arr(:)
        integer, intent(in) :: val
        integer, intent(in), optional :: opt1, OPT2
        real :: callee
    end function callee
end module some_mod
    """.strip()

    # Parse the callee first so its definitions can be attached to the caller
    callee_source = Sourcefile.from_source(fcode_callee, frontend=frontend, xmods=[tmp_path])
    caller_source = Sourcefile.from_source(
        fcode_caller, frontend=frontend, xmods=[tmp_path],
        definitions=callee_source.definitions
    )

    callee = callee_source['callee']
    caller = caller_source['caller']

    # ``unique=False`` keeps all three calls; by default identical calls are merged
    calls = list(FindInlineCalls(unique=False).visit(caller.body))
    assert len(calls) == 3
    arg_iter = list(calls[0].arg_iter())
    assert arg_iter == [
        ('var', 'arg1'), ('var2', 'arg2'), ('arr(:)', 'arg3'), ('val', '4')
    ]
    arg_iter = list(calls[1].arg_iter())
    assert arg_iter == [
        ('var', 'arg1'), ('arr(:)', 'arg3'), ('var2', 'arg2'), ('opt2', '2'), ('val', '4')
    ]
    arg_iter = list(calls[2].arg_iter())
    assert arg_iter == [
        ('var', 'arg1'), ('var2', 'arg2'), ('arr(:)', 'arg3'), ('val', '4'), ('opt1', '1'), ('opt2', '2')
    ]

    for call in calls:
        for kernel_arg, caller_arg in call.arg_iter():
            assert kernel_arg.scope is callee
            assert isinstance(caller_arg, IntLiteral) or caller_arg.scope is caller
loki-ecmwf-0.3.6/loki/tests/test_cmake.py0000664000175000017500000002102115167130205020526 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Functional tests for cmake macros.
"""

import os
from pathlib import Path
import re
import shutil
from subprocess import CalledProcessError, run, PIPE, STDOUT
from contextlib import contextmanager
import pytest
import tomli_w

from loki import gettempdir, execute, graphviz_present


def check_cmake():
    """
    Check if CMake is available

    Returns
    -------
    bool
        ``True`` if ``cmake --version`` runs successfully. ``False`` if the
        command fails, or if the ``cmake`` executable cannot be launched at all.
    """
    # TODO: Check CMake version
    try:
        execute(['cmake', '--version'], silent=True)
    except (CalledProcessError, OSError):
        # A missing executable surfaces as FileNotFoundError (an OSError) from
        # the subprocess layer rather than as CalledProcessError; without this,
        # test collection would crash instead of skipping
        return False
    return True


# Skip every test in this module when CMake is unavailable. pytest only applies
# a module-wide marker when it is bound to the name ``pytestmark``; a bare
# ``pytest.mark.skipif(...)`` expression is created and immediately discarded.
pytestmark = pytest.mark.skipif(not check_cmake(), reason='CMake not available')

@pytest.fixture(scope='module', name='tmp_dir')
def fixture_tmp_dir():
    """
    Provide a scratch directory that lives for the whole test module.

    Leftovers from a previous run are wiped before the directory is created,
    and the directory is removed again on teardown if it still exists.
    """
    def _purge(path):
        # Remove the directory tree only if something is actually there
        if path.exists():
            shutil.rmtree(path)

    scratch = gettempdir() / 'test_cmake'
    _purge(scratch)
    scratch.mkdir()
    yield scratch
    _purge(scratch)


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test file"""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(scope='module', name='silent')
def fixture_silent(pytestconfig):
    """Suppress command output unless pytest runs at a non-default verbosity"""
    verbosity = pytestconfig.getoption("verbose")
    return verbosity == 0


@pytest.fixture(scope='module', name='srcdir')
def fixture_srcdir(here):
    """The ``sources`` directory next to this file, root of the CMake test sources"""
    sources_dir = here / 'sources'
    return sources_dir


@pytest.fixture(scope='module', name='config')
def fixture_config(tmp_dir):
    """
    Write a default Loki scheduler configuration as TOML into ``tmp_dir``.

    The configuration makes every routine a kernel by default and marks
    ``driverB`` as the driver. It registers the idempotence and file-write
    transformations and defines an ``idem`` pipeline. The file path is
    yielded, and the file is deleted on teardown.
    """
    transformations = {
        'IdemTrafo': {
            'classname': 'IdemTransformation',
            'module': 'loki.transformations.idempotence',
        },
        'FileWriteTransformation': {
            'classname': 'FileWriteTransformation',
            'module': 'loki.transformations.build_system',
            'options': {'include_module_var_imports': True},
        },
    }
    default_config = {
        'default': {
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True,
        },
        'routines': {'driverB': {'role': 'driver'}},
        'transformations': transformations,
        'pipelines': {'idem': {'transformations': ['IdemTrafo']}},
    }

    config_file = tmp_dir / 'test_cmake_loki.config'
    config_file.write_text(tomli_w.dumps(default_config))
    yield config_file
    config_file.unlink()


@pytest.fixture(scope='module', name='ecbuild')
def fixture_ecbuild(tmp_dir):
    """
    Clone ecbuild from GitHub into ``tmp_dir`` and yield the checkout path.

    Any previous checkout at that location is removed first. The checkout is
    deleted again on teardown.
    """
    repo_url = 'https://github.com/ecmwf/ecbuild.git'
    checkout = tmp_dir / 'ecbuild'
    if checkout.exists():
        shutil.rmtree(checkout)
    execute(['git', 'clone', repo_url, str(checkout)])
    yield checkout
    shutil.rmtree(checkout)


@pytest.fixture(scope='module', name='loki_artifacts_and_env', params=[True, False])
def fixture_loki_artifacts_and_env(here, tmp_dir, silent, request):
    """
    Provide extra CMake arguments and an environment for installing Loki.

    The fixture is parametrised over a flag:
    - ``True``: run ``./populate`` to download wheels into an artifacts
      directory. Yield a ``-DARTIFACTS_DIR=...`` argument and an environment
      whose HTTP(S) proxies point at an unreachable host, so pip cannot reach
      a package index during configuration.
    - ``False``: yield no extra arguments and a plain copy of ``os.environ``.

    Yields a ``(cmake_args, env)`` tuple. The artifacts directory is removed
    before and after use.
    """
    artifacts_dir = tmp_dir / 'artifacts'
    if artifacts_dir.exists():
        shutil.rmtree(artifacts_dir)

    env = os.environ.copy()
    cmake_args = []
    use_artifacts = request.param
    if use_artifacts:
        env.update({
            'ARTIFACTS_DIR': str(artifacts_dir),
            'LOKI_INSTALL_OPTIONS': '[tests]',
        })
        repo_root = here.parent.parent
        execute(['./populate'], silent=silent, cwd=str(repo_root), env=env)
        cmake_args.append(f'-DARTIFACTS_DIR={artifacts_dir}')
        # Proxies are set only after populate has run, so downloads still work
        # there but later pip invocations are cut off from package indices
        env.update({
            'http_proxy': 'http://foo.bar.baz',
            'https_proxy': 'http://foo.bar.baz',
        })

    yield cmake_args, env

    if artifacts_dir.exists():
        shutil.rmtree(artifacts_dir)


@pytest.fixture(scope='module', name='loki_install', params=['editable', 'relative_install', 'default'])
def fixture_loki_install(here, tmp_dir, ecbuild, loki_artifacts_and_env, silent, request):
    """
    Configure Loki with CMake and install it into ``tmp_dir/loki``.

    The fixture is parametrised over three install flavours:
    - ``editable``: configure with ``ENABLE_EDITABLE=ON`` and verify that
      the build-tree virtual environment lists the source checkout.
    - ``relative_install``: install using the relative prefix ``loki``.
    - ``default``: install using the absolute install directory.

    Yields ``(builddir, installdir)``.
    """
    artifacts_arg, env = loki_artifacts_and_env
    repo_root = here.parent.parent
    builddir = tmp_dir / 'loki_bootstrap'
    installdir = tmp_dir / 'loki'
    editable = request.param == 'editable'

    configure_cmd = [
        'cmake', f'-DCMAKE_MODULE_PATH={ecbuild}/cmake',
        '-S', str(repo_root),
        '-B', str(builddir),
        *artifacts_arg,
        f"-DENABLE_EDITABLE={'ON' if editable else 'OFF'}",
    ]
    execute(configure_cmd, silent=silent, cwd=tmp_dir, env=env)

    if editable:
        # An editable install should reference the source checkout in pip's listing
        pip_list = run(
            [str(builddir / 'loki_env/bin/pip'), 'list'],
            stdout=PIPE, stderr=STDOUT, check=True
        )
        assert str(repo_root) in pip_list.stdout.decode()

    # The relative prefix is given together with ``cwd=tmp_dir``, presumably so
    # that it also lands in ``installdir``
    prefix = 'loki' if request.param == 'relative_install' else installdir
    execute(
        ['cmake', '--install', str(builddir), '--prefix', str(prefix)],
        silent=True, cwd=tmp_dir, env=env
    )

    yield builddir, installdir


@contextmanager
def clean_builddir(builddir):
    """
    Provide a freshly emptied directory at ``builddir``.

    Any existing directory at that location is deleted with all its contents
    and recreated empty before being yielded as a :class:`pathlib.Path`. The
    directory is not removed when the ``with`` block exits.

    Parameters
    ----------
    builddir : str or Path
        Location of the build directory to (re)create.
    """
    target = Path(builddir)
    if target.exists():
        shutil.rmtree(target)
    target.mkdir()
    yield target


@pytest.fixture(scope='module', name='cmake_project')
def fixture_cmake_project(here, config, srcdir):
    """
    Create a CMake project and set-up paths

    Writes ``srcdir/CMakeLists.txt``, which runs ``loki_transform_plan`` in
    ``idem`` mode over ``projA`` and ``projB`` using ``config``. The fixture
    also creates a ``srcdir/loki`` symlink to the Loki test directory. It
    yields the path of the generated ``CMakeLists.txt``; both files are
    removed on teardown.
    """
    proj_a = '${CMAKE_CURRENT_SOURCE_DIR}/projA'
    proj_b = '${CMAKE_CURRENT_SOURCE_DIR}/projB'

    file_content = f"""
cmake_minimum_required( VERSION 3.19 FATAL_ERROR )
find_package( ecbuild REQUIRED )

project( cmake_test VERSION 1.0.0 LANGUAGES Fortran )

ecbuild_find_package( loki REQUIRED )

loki_transform_plan(
    MODE      idem
    CONFIG    {config}
    SOURCEDIR ${{CMAKE_CURRENT_SOURCE_DIR}}
    CALLGRAPH ${{CMAKE_CURRENT_BINARY_DIR}}/loki_callgraph
    PLAN      ${{CMAKE_CURRENT_BINARY_DIR}}/loki_plan.cmake
    SOURCES
        {proj_a}
        {proj_b}
)
    """
    filepath = srcdir/'CMakeLists.txt'
    filepath.write_text(file_content)

    # Create a symlink to loki. A link left behind by an interrupted earlier
    # run would make ``symlink_to`` raise FileExistsError, so drop it first.
    # ``is_symlink`` is also true for dangling links.
    loki_link = srcdir/'loki'
    if loki_link.is_symlink():
        loki_link.unlink()
    loki_link.symlink_to(here.parent)

    yield filepath

    filepath.unlink()
    loki_link.unlink()


def test_cmake_plan(srcdir, tmp_dir, config, cmake_project, loki_install, ecbuild, silent):
    """
    Test the `loki_transform_plan` CMake function with a single task
    graph spanning two projects

    projA: driverB -> kernelB -> compute_l1 -> compute_l2
                         |
    projB:          ext_driver -> ext_kernel

    The project is configured once against each directory yielded by
    ``loki_install`` (the build tree, then the install tree). Each generated
    plan file must list every file of the call graph as transformed, removed
    and (with an ``.idem`` suffix) appended. Each Loki directory is deleted
    after it has been tested.
    """
    # Captures ``set( NAME value... )`` blocks, possibly spanning several lines
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)

    expected_files = {
        'driverB_mod', 'kernelB_mod',
        'compute_l1_mod', 'compute_l2_mod',
        'ext_driver_mod', 'ext_kernel',
        'header_mod'
    }
    expected_lists = (
        ('LOKI_SOURCES_TO_TRANSFORM', expected_files),
        ('LOKI_SOURCES_TO_REMOVE', expected_files),
        ('LOKI_SOURCES_TO_APPEND', {f'{stem}.idem' for stem in expected_files}),
    )

    assert config.exists()
    assert cmake_project.exists()

    for loki_root in loki_install:
        with clean_builddir(tmp_dir/'test_cmake_plan') as builddir:
            execute(
                [f'{ecbuild}/bin/ecbuild', str(srcdir), f'-Dloki_ROOT={loki_root}'],
                cwd=builddir, silent=silent
            )

            # The plan (and, with graphviz, the callgraph) must have been generated
            plan_file = builddir/'loki_plan.cmake'
            assert plan_file.exists()
            if graphviz_present():
                assert (builddir/'loki_callgraph.pdf').exists()

            # Map each CMake list variable to the set of file stems it contains
            plan_dict = {
                var_name: {Path(entry).stem for entry in value.split()}
                for var_name, value in plan_pattern.findall(plan_file.read_text())
            }

            for var_name, expected in expected_lists:
                assert var_name in plan_dict
                assert plan_dict[var_name] == expected

        shutil.rmtree(loki_root)
loki-ecmwf-0.3.6/loki/tests/test_modules.py0000664000175000017500000013346615167130205021137 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Subroutine, fexprgen, fgen
from loki.jit_build import jit_compile_lib
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    nodes as ir, FindNodes, FindInlineCalls, FindTypedSymbols,
    FindVariables, SubstituteExpressions, Transformer
)
from loki.sourcefile import Sourcefile
from loki.types import BasicType, DerivedType, SymbolAttributes


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_from_source(frontend, tmp_path):
    """
    Test the creation of `Module` objects from raw source strings.

    Checks the spec contents (two declarations and one derived type) and the
    single contained routine. For all frontends except OMNI it also checks
    the recorded source string and line range.
    """
    fcode = """
module a_module
  integer, parameter :: x = 2
  integer, parameter :: y = 3

  type derived_type
    real :: array(x, y)
  end type derived_type
contains

  subroutine my_routine(pt)
    type(derived_type) :: pt
    pt%array(:,:) = 42.0
  end subroutine my_routine
end module a_module
""".strip()
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    spec_nodes = module.spec.body
    assert sum(1 for node in spec_nodes if isinstance(node, ir.VariableDeclaration)) == 2
    assert sum(1 for node in spec_nodes if isinstance(node, ir.TypeDef)) == 1
    assert 'derived_type' in module.typedef_map

    assert len(module.routines) == 1
    assert module.routines[0].name == 'my_routine'

    # The source-provenance check is not run for OMNI
    if frontend == OMNI:
        return
    assert module.source.string == fcode
    assert module.source.lines == (1, fcode.count('\n') + 1)


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_external_typedefs_subroutine(frontend, tmp_path):
    """
    Test that externally provided type information is correctly
    attached to a `Module` subroutine when supplied via the `definitions`
    parameter in the constructor.

    The derived type ``ext_type`` lives in a separate module and is imported
    inside the subroutine. Both the declared variable and the LHS of the
    assignment must carry the type and shape of the ``array`` component.
    """
    fcode_external = """
module external_mod
  integer, parameter :: x = 2
  integer, parameter :: y = 3

  type ext_type
    real :: array(x, y)
  end type ext_type
end module external_mod
"""

    fcode_module = """
module a_module
contains

  subroutine my_routine(pt_ext)
    use external_mod, only: ext_type
    implicit none

    type(ext_type) :: pt_ext
    pt_ext%array(:,:) = 42.0
  end subroutine my_routine
end module a_module
"""

    external = Module.from_source(fcode_external, frontend=frontend, xmods=[tmp_path])
    assert 'ext_type' in external.typedef_map

    module = Module.from_source(fcode_module, frontend=frontend, definitions=external, xmods=[tmp_path])
    routine = module.subroutines[0]
    pt_ext = routine.variables[0]

    # OMNI resolves explicit shape parameters in the frontend parser
    expected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'

    # Check that the `array` variable in the `ext` type is found and
    # has correct type and shape info
    assert 'array' in pt_ext.variable_map
    a = pt_ext.variable_map['array']
    assert a.type.dtype == BasicType.REAL
    assert fexprgen(a.shape) == expected_array_shape

    # Check the LHS of the assignment has correct meta-data
    stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
    pt_ext_arr = stmt.lhs
    assert pt_ext_arr.type.dtype == BasicType.REAL
    assert fexprgen(pt_ext_arr.shape) == expected_array_shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_external_typedefs_type(frontend, tmp_path):
    """
    Test that externally provided type information is correctly
    attached to a `Module` type and used in a contained subroutine
    when supplied via the `definitions` argument of :meth:`Module.from_source`.

    Without definitions (checked for non-OMNI frontends only), the imported
    ``ext_type`` stays deferred and ``other_type`` from the unrestricted
    ``use other_mod`` is unknown. With definitions, both types resolve to
    their :any:`TypeDef` in the module, the nested type and both routines.
    """
    fcode_external = """
module external_mod
  integer, parameter :: x = 2
  integer, parameter :: y = 3

  type ext_type
    real :: array(x, y)
  end type ext_type
end module external_mod
"""

    fcode_other = """
module other_mod
  integer, parameter :: z = 4

  type other_type
    real :: vector(z)
  end type other_type
end module other_mod
    """.strip()

    fcode_module = """
module a_module
  use external_mod, only: ext_type
  use other_mod
  implicit none

  type nested_type
    type(ext_type) :: ext
  end type nested_type
contains

  subroutine my_routine(pt)
    type(nested_type) :: pt
    pt%ext%array(:,:) = 42.0
  end subroutine my_routine

  subroutine other_routine(pt)
    type(other_type) :: pt
    pt%vector(:) = 13.37
  end subroutine other_routine
end module a_module
"""

    external = Module.from_source(fcode_external, frontend=frontend, xmods=[tmp_path])
    assert 'ext_type' in external.typedef_map

    other = Module.from_source(fcode_other, frontend=frontend, xmods=[tmp_path])
    assert 'other_type' in other.typedef_map

    if frontend != OMNI:  # OMNI needs to know imported modules
        # Without definitions, imported type information stays unresolved
        module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
        assert 'ext_type' in module.symbol_attrs
        assert module.symbol_attrs['ext_type'].dtype is BasicType.DEFERRED
        assert 'other_type' not in module.symbol_attrs
        assert 'other_type' not in module['other_routine'].symbol_attrs
        assert module['other_routine'].symbol_attrs['pt'].dtype.typedef is BasicType.DEFERRED

    module = Module.from_source(fcode_module, frontend=frontend, definitions=[external, other], xmods=[tmp_path])
    nested = module.typedef_map['nested_type']
    ext = nested.variables[0]

    # Verify correct attachment of type information
    assert 'ext_type' in module.symbol_attrs
    assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, ir.TypeDef)
    assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, ir.TypeDef)
    assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
    assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, ir.TypeDef)
    assert 'other_type' in module.symbol_attrs
    assert 'other_type' not in module['other_routine'].symbol_attrs
    assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, ir.TypeDef)
    assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)

    # OMNI resolves explicit shape parameters in the frontend parser
    expected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'

    # Check that the `array` variable in the `ext` type is found and
    # has correct type and shape info
    assert 'array' in ext.variable_map
    a = ext.variable_map['array']
    assert a.type.dtype == BasicType.REAL
    assert fexprgen(a.shape) == expected_array_shape

    # Check the routine has got type and shape info too
    routine = module['my_routine']
    pt = routine.variables[0]
    pt_ext = pt.variable_map['ext']
    assert 'array' in pt_ext.variable_map
    pt_ext_a = pt_ext.variable_map['array']
    assert pt_ext_a.type.dtype == BasicType.REAL
    assert fexprgen(pt_ext_a.shape) == expected_array_shape

    # Check the LHS of the assignment has correct meta-data
    stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
    pt_ext_arr = stmt.lhs
    assert pt_ext_arr.type.dtype == BasicType.REAL
    assert fexprgen(pt_ext_arr.shape) == expected_array_shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_nested_types(frontend, tmp_path):
    """
    Test that derived types nested inside another derived type of the same
    module are detected and connected. The component of the outer type must
    expose the dtype and shape of the inner type's ``array`` member.
    """

    fcode = """
module type_mod
  integer, parameter :: x = 2
  integer, parameter :: y = 3

  type sub_type
    real :: array(x, y)
  end type sub_type

  type parent_type
    type(sub_type) :: pt
  end type parent_type
end module type_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # OMNI resolves explicit shape parameters in the frontend parser
    expected_shape = '(2, 3)' if frontend == OMNI else '(x, y)'

    component = module.typedef_map['parent_type'].variables[0]
    assert 'array' in component.variable_map
    member = component.variable_map['array']
    assert member.type.dtype == BasicType.REAL
    assert fexprgen(member.shape) == expected_shape


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Loki annotation break parser')]))
def test_dimension_pragmas(frontend, tmp_path):
    """
    Check that a ``!$loki dimension(...)`` annotation placed before a
    deferred-shape derived-type component is used as that component's shape.
    """

    fcode = """
module type_mod
  implicit none
  type mytype
    !$loki dimension(size)
    integer, pointer :: x(:)
  end type mytype
end module type_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    pointer_component = module.typedef_map['mytype'].variables[0]
    assert fexprgen(pointer_component.shape) == '(size,)'


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Loki annotation break parser')]))
def test_nested_types_dimension_pragmas(frontend, tmp_path):
    """
    Check that the shape set by a ``!$loki dimension(...)`` annotation on an
    inner derived type is also seen when the component is accessed through
    an outer type that embeds the inner one.
    """

    fcode = """
module type_mod
  implicit none
  type sub_type
    !$loki dimension(size)
    integer, pointer :: x(:)
  end type sub_type

  type parent_type
    type(sub_type) :: pt
  end type parent_type
end module type_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    typedefs = module.typedef_map
    outer, inner = typedefs['parent_type'], typedefs['sub_type']

    # Shape on the annotated component of the inner type itself ...
    assert fexprgen(inner.variables[0].shape) == '(size,)'

    # ... and the same shape when reached through the embedding component
    nested_x = outer.variables[0].variable_map['x']
    assert fexprgen(nested_x.shape) == '(size,)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_internal_function_call(frontend, tmp_path):
    """
    Test the use of `InlineCall` symbols linked to a module function.

    The call ``util_fct(v2, v1)`` must be recognised as an inline call to the
    contained function. The module's symbol table must link ``util_fct`` to
    its :any:`Subroutine` object, flagged as a function.
    """
    # The dummy arguments of `util_fct` are declared to match the actual
    # arguments of `util_fct(v2, v1)` (real first, integer second), so that
    # the fixture is valid Fortran under the explicit module interface.
    fcode = """
module module_mod
  implicit none
  integer, parameter :: jprb = selected_real_kind(13,300)

contains

  subroutine test_inline_call(v1, v2, v3)
    implicit none

    integer, intent(in) :: v1
    real(kind=jprb), intent(in) :: v2
    real(kind=jprb), intent(out) :: v3

    v3 = util_fct(v2, v1)
  end subroutine test_inline_call

  function util_fct(var, mode)
    real(kind=jprb) :: util_fct
    real(kind=jprb), intent(in) :: var
    integer, intent(in) :: mode

    if (mode == 1) then
      util_fct = var + 2_jprb
    else
      util_fct = var + 3_jprb
    end if
  end function util_fct

end module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_inline_call']

    inline_calls = list(FindInlineCalls().visit(routine.body))
    assert len(inline_calls) == 1
    assert inline_calls[0].function.name == 'util_fct'
    assert inline_calls[0].parameters[0] == 'v2'
    assert inline_calls[0].parameters[1] == 'v1'

    assert isinstance(module.symbol_attrs['util_fct'].dtype.procedure, Subroutine)
    assert module.symbol_attrs['util_fct'].dtype.is_function


@pytest.mark.parametrize('frontend', available_frontends())
def test_external_function_call(frontend, tmp_path):
    """
    Test the use of `InlineCall` symbols linked to an external function definition.

    ``util_fct`` is imported from ``util_mod``. With that module supplied via
    ``definitions``, the call in the subroutine body must be recognised as a
    single inline call with the arguments in source order.
    """
    fcode = """
subroutine test_inline_call(v1, v2, v3)
  use util_mod, only: util_fct
  implicit none

  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: v1
  real(kind=jprb), intent(in) :: v2
  real(kind=jprb), intent(out) :: v3

  v3 = util_fct(v2, v1)
end subroutine test_inline_call
"""

    # The dummy arguments of `util_fct` are declared to match the actual
    # arguments of `util_fct(v2, v1)` (real first, integer second), so that
    # the call is valid against the module procedure's explicit interface.
    fcode_util = """
module util_mod
  integer, parameter :: jprb = selected_real_kind(13,300)

contains
  function util_fct(var, mode)
    real(kind=jprb) :: util_fct
    real(kind=jprb), intent(in) :: var
    integer, intent(in) :: mode

    if (mode == 1) then
      util_fct = var + 2_jprb
    else
      util_fct = var + 3_jprb
    end if
  end function util_fct
end module
"""
    module = Module.from_source(fcode_util, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend, xmods=[tmp_path])

    inline_calls = list(FindInlineCalls().visit(routine.body))
    assert len(inline_calls) == 1
    assert inline_calls[0].function.name == 'util_fct'
    assert inline_calls[0].parameters[0] == 'v2'
    assert inline_calls[0].parameters[1] == 'v1'


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_variables_add_remove(frontend, tmp_path):
    """
    Test addition and removal of module variables through the
    :attr:`Module.variables` property, checking both the resulting
    variable list and the generated Fortran for the module spec.
    """
    fcode = """
module module_variables_add_remove
  implicit none
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer :: x, y
  real(kind=jprb), allocatable :: vector(:)
end module module_variables_add_remove
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    module_vars = [str(arg) for arg in module.variables]
    assert module_vars == ['jprb', 'x', 'y', 'vector(:)']

    # Create a new set of variables to be added to the module
    x = module.variable_map['x']  # That's the symbol for variable 'x'
    real_type = SymbolAttributes('real', kind=module.variable_map['jprb'])
    int_type = SymbolAttributes('integer')
    a = sym.Variable(name='a', type=real_type, scope=module)
    b = sym.Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
    c = sym.Variable(name='c', type=int_type, scope=module)

    # Add new variables and check that they are all in the module spec
    module.variables += (a, b, c)
    if frontend == OMNI:
        # OMNI frontend inserts a few peculiarities
        assert fgen(module.spec).lower() == """
integer, parameter :: jprb = selected_real_kind(13, 300)
integer :: x
integer :: y
real(kind=selected_real_kind(13, 300)), allocatable :: vector(:)
real(kind=jprb) :: a
real(kind=jprb) :: b(x)
integer :: c
""".strip().lower()

    else:
        assert fgen(module.spec).lower() == """
implicit none
integer, parameter :: jprb = selected_real_kind(13, 300)
integer :: x, y
real(kind=jprb), allocatable :: vector(:)
real(kind=jprb) :: a
real(kind=jprb) :: b(x)
integer :: c
""".strip().lower()

    # Now remove the `vector` variable and make sure it's gone
    module.variables = [v for v in module.variables if v.name != 'vector']
    assert 'vector' not in fgen(module.spec).lower()
    module_vars = [str(arg) for arg in module.variables]
    assert module_vars == ['jprb', 'x', 'y', 'a', 'b(x)', 'c']


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parsing fails without dummy module provided')]))
def test_module_rescope_symbols(frontend, tmp_path):
    """
    Test the rescoping of variables when constructing a `Module` from a
    copied spec.

    With ``rescope_symbols=True``, every symbol in the new module must point
    to the new module as its scope. Without it, the copy's symbols still
    depend on the original module's symbol table, so clearing that table
    must leave them untyped and make code generation for the copy fail.
    """
    fcode = """
module test_module_rescope
  use some_mod, only: ext1
  implicit none
  integer :: a, b, c
end module test_module_rescope
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    ref_fgen = fgen(module)

    # Create a copy of the module with rescoping and make sure all symbols are in the right scope
    spec = Transformer().visit(module.spec)
    module_copy = Module(name=module.name, spec=spec, rescope_symbols=True)

    for var in FindTypedSymbols().visit(module_copy.spec):
        assert var.scope is module_copy

    # Create another copy of the module without rescoping
    spec = Transformer().visit(module.spec)
    other_module_copy = Module(name=module.name, spec=spec)

    # Explicitly throw away type information from original module
    module.symbol_attrs.clear()
    assert all(var.type is None for var in other_module_copy.variables)
    assert all(var.scope is not None for var in other_module_copy.variables)

    # fgen of the rescoped copy should work
    assert fgen(module_copy) == ref_fgen

    # fgen of the not rescoped copy should fail because the scope of the variables went away
    with pytest.raises(AttributeError):
        fgen(other_module_copy)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parsing fails without dummy module provided')]))
def test_module_rescope_clone(frontend, tmp_path):
    """
    Test the rescoping of variables in :meth:`Module.clone`.

    A default clone must rescope all symbols to the clone. A clone made with
    ``rescope_symbols=False`` and no symbol table still depends on the
    original module's symbol table, so clearing that table must leave its
    symbols untyped and make code generation for it fail.
    """
    fcode = """
module test_module_rescope_clone
  use some_mod, only: ext1
  implicit none
  integer :: a, b, c
end module test_module_rescope_clone
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    ref_fgen = fgen(module)

    # Create a copy of the module with rescoping and make sure all symbols are in the right scope
    module_copy = module.clone()

    for var in FindTypedSymbols().visit(module_copy.spec):
        assert var.scope is module_copy

    # Create another copy of the module without rescoping
    other_module_copy = module.clone(rescope_symbols=False, symbol_attrs=None)

    # Explicitly throw away type information from original module
    module.symbol_attrs.clear()
    assert all(var.type is None for var in other_module_copy.variables)
    assert all(var.scope is not None for var in other_module_copy.variables)

    # fgen of the rescoped copy should work
    assert fgen(module_copy) == ref_fgen

    # fgen of the not rescoped copy should fail because the scope of the variables went away
    with pytest.raises(AttributeError):
        fgen(other_module_copy)

@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'Parsing fails without dummy module provided')]
))
def test_module_deep_clone(frontend, tmp_path):
    """
    Check that :meth:`Module.clone` produces a deep copy that includes nested
    scopes such as derived-type definitions. Rewriting the clone's spec must
    be visible in the clone's type components and must not leak back into
    the original module.
    """
    fcode = """
module test_module_rescope_clone
  use parkind1, only : jpim, jprb
  implicit none

  integer :: n

  real :: array(n)

  type my_type
    real :: vector(n)
    real :: matrix(n, n)
  end type

end module test_module_rescope_clone
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    clone = module.clone()

    # In the clone only: drop the declaration of `n` and replace `n` by `3`
    n_var = [var for var in FindVariables().visit(clone.spec) if var.name == 'n'][0]
    n_decl = FindNodes(ir.VariableDeclaration).visit(clone.spec)[0]
    clone.spec = Transformer({n_decl: None}).visit(clone.spec)
    clone.spec = SubstituteExpressions({n_var: sym.Literal(3)}).visit(clone.spec)

    def check_declarations(mod, num_spec_decls, component_decls):
        # Compare declaration counts in the spec and the first symbol of each
        # declaration in the body of `my_type`
        assert len(FindNodes(ir.VariableDeclaration).visit(mod.spec)) == num_spec_decls
        typedef_decls = FindNodes(ir.VariableDeclaration).visit(mod['my_type'].body)
        assert len(typedef_decls) == len(component_decls)
        for decl, expected in zip(typedef_decls, component_decls):
            assert decl.symbols[0] == expected

    # The clone has been changed ...
    check_declarations(clone, 1, ('vector(3)', 'matrix(3, 3)'))

    # ... while the original module is untouched
    check_declarations(module, 2, ('vector(n)', 'matrix(n, n)'))


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_access_spec_none(frontend, tmp_path):
    """
    Test correct parsing without access-spec statements.

    The module must report no default, public or private access specs and
    must generate no PUBLIC/PRIVATE statements. Access specs supplied to
    ``clone`` must be normalised to lower case (and tuples).
    """
    # Note: an initialization in a type declaration requires the `::`
    # separator, hence `integer :: pub_var = 1`
    fcode = """
module test_access_spec_mod
    implicit none

    integer :: pub_var = 1
contains
    subroutine routine
        integer i
        i = pub_var
    end subroutine routine
end module test_access_spec_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Check module properties
    assert module.default_access_spec is None
    assert isinstance(module.public_access_spec, tuple) and not module.public_access_spec
    assert isinstance(module.private_access_spec, tuple) and not module.private_access_spec

    # Check backend output
    code = module.to_fortran().upper()
    assert 'PUBLIC' not in code
    assert 'PRIVATE' not in code

    # Check that property has not propagated to symbol type
    pub_var = module.variable_map['pub_var']
    assert pub_var.type.public is None
    assert pub_var.type.private is None

    # Check properties after clone
    new_module = module.clone(
        default_access_spec='PUBLIC', public_access_spec='PUB_VAR',
        private_access_spec='ROUTINE'
    )
    assert new_module.default_access_spec == 'public'
    assert new_module.public_access_spec == ('pub_var',)
    assert new_module.private_access_spec == ('routine',)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Inlines access-spec as declaration attr')]))
def test_module_access_spec_private(frontend, tmp_path):
    """
    Test correct parsing of access-spec statements with default private.

    The module-level access specs must be recorded on the module (lower-cased)
    and regenerated as statements. They must not be copied onto the symbol
    types of the affected variables.
    """
    # Note: an initialization in a type declaration requires the `::`
    # separator, hence `integer :: <name> = <value>`
    fcode = """
module test_access_spec_mod
    implicit none
    private
    public :: pub_var, routine
    PRIVATE OTHER_PRIVATE_VAR

    integer :: pub_var = 1
    integer :: private_var = 2
    integer :: other_private_var = 3
contains
    subroutine routine
        integer i
        i = pub_var
    end subroutine routine
end module test_access_spec_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Check module properties
    assert module.default_access_spec == 'private'
    assert module.public_access_spec == ('pub_var', 'routine')
    assert module.private_access_spec == ('other_private_var',)

    # Check backend output
    code = module.to_fortran().upper()
    assert 'PUBLIC\n' not in code
    assert 'PUBLIC :: PUB_VAR, ROUTINE' in code
    assert 'PRIVATE\n' in code
    assert 'PRIVATE :: OTHER_PRIVATE_VAR' in code

    # Check that property has not propagated to symbol type
    pub_var = module.variable_map['pub_var']
    assert pub_var.type.public is None
    assert pub_var.type.private is None

    # Check properties after clone
    new_module = module.clone(private_access_spec=None)
    assert new_module.default_access_spec == 'private'
    assert new_module.public_access_spec == ('pub_var', 'routine')
    assert new_module.private_access_spec == ()


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Inlines access-spec as declaration attr')]))
def test_module_access_spec_public(frontend, tmp_path):
    """
    Test correct parsing of access-spec statements with default public.

    The module-level access specs must be recorded on the module (lower-cased)
    and regenerated as statements. They must not be copied onto the symbol
    types of the affected variables. Specs passed to ``clone`` replace the
    corresponding ones and are normalised to lower case.
    """
    # Note: an initialization in a type declaration requires the `::`
    # separator, hence `integer :: <name> = <value>`
    fcode = """
module test_access_spec_mod
    implicit none
    PUBLIC
    PUBLIC ROUTINE
    private :: private_var, other_private_var

    integer :: pub_var = 1
    integer :: private_var = 2
    integer :: other_private_var = 3
contains
    subroutine routine
        integer i
        i = pub_var
    end subroutine routine
end module test_access_spec_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Check module properties
    assert module.default_access_spec == 'public'
    assert module.public_access_spec == ('routine', )
    assert module.private_access_spec == ('private_var', 'other_private_var')

    # Check backend output
    code = module.to_fortran().upper()
    assert 'PUBLIC\n' in code
    assert 'PUBLIC :: ROUTINE' in code
    assert 'PRIVATE\n' not in code
    assert 'PRIVATE :: PRIVATE_VAR, OTHER_PRIVATE_VAR' in code

    # Check that property has not propagated to symbol type
    pub_var = module.variable_map['pub_var']
    assert pub_var.type.public is None
    assert pub_var.type.private is None

    # Check properties after clone
    new_module = module.clone(
        default_access_spec='PRivate', public_access_spec=('ROUTINE', 'pub_var')
    )
    assert new_module.default_access_spec == 'private'
    assert new_module.public_access_spec == ('routine', 'pub_var')
    assert new_module.private_access_spec == ('private_var', 'other_private_var')


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_access_attr(frontend, tmp_path):
    """
    Check that ``public``/``private`` attributes written on individual
    declarations end up on the symbol types, while variables without such an
    attribute keep ``public`` unset. Their ``private`` flag and the number of
    generated PRIVATE keywords depend on the frontend.
    """
    fcode = """
module test_access_attr_mod
    implicit none
    private
    integer, public :: pub_var
    integer :: unspecified_var
    integer, private :: priv_var
    integer :: other_var
end module test_access_attr_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    code = module.to_fortran().upper()
    var_map = module.variable_map

    # Explicit attributes are reflected on the symbol type
    priv_type = var_map['priv_var'].type
    assert priv_type.private is True
    assert priv_type.public is None

    pub_type = var_map['pub_var'].type
    assert pub_type.public is True
    assert pub_type.private is None

    # Variables declared without an access attribute
    plain_types = [var_map['unspecified_var'].type, var_map['other_var'].type]
    assert all(t.public is None for t in plain_types)

    if frontend == OMNI:
        # OMNI applies access spec to each variable
        expected_private_count, expected_private_flag = 3, True
    else:
        expected_private_count, expected_private_flag = 2, None
    assert code.count('PRIVATE') == expected_private_count
    assert all(t.private is expected_private_flag for t in plain_types)
    assert code.count('PUBLIC') == 1


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_rename_imports_with_definitions(frontend, tmp_path):
    """
    Check ``USE`` statements with rename lists when the definitions of the
    used modules are provided. Renamed symbols must be registered under
    their local name, with ``use_name`` and the source module recorded, and
    with attributes otherwise equal to the original symbol.
    """
    fcode_mod1 = """
module test_rename_mod
    implicit none
    integer :: var1
    integer :: var2
    integer :: var3
end module test_rename_mod
    """.strip()

    fcode_mod2 = """
module test_other_rename_mod
    implicit none
    integer :: var1
    integer :: var2
    integer :: var3
end module test_other_rename_mod
    """.strip()

    fcode_mod3 = """
module some_mod
    use test_rename_mod, first_var1 => var1, first_var3 => var3
    use test_other_rename_mod, only: second_var1 => var1
    use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3
    implicit none
end module some_mod
    """.strip()

    mod1 = Module.from_source(fcode_mod1, frontend=frontend, xmods=[tmp_path])
    mod2 = Module.from_source(fcode_mod2, frontend=frontend, xmods=[tmp_path])
    mod3 = Module.from_source(fcode_mod3, frontend=frontend, xmods=[tmp_path], definitions=[mod1, mod2])

    # Local name -> name in the used module (None if not renamed)
    mod1_imports = {'first_var1': 'var1', 'var2': None, 'first_var3': 'var3'}
    mod2_imports = {'second_var1': 'var1', 'other_var2': 'var2', 'other_var3': 'var3'}
    imports_by_module = ((mod1, mod1_imports), (mod2, mod2_imports))

    # Every local name is registered in the symbol table ...
    for local_name in (*mod1_imports, *mod2_imports):
        assert local_name in mod3.symbol_attrs

    # ... but `var1` has not been imported under its original name
    assert 'var1' not in mod3.symbol_attrs

    # Imported symbols reference their source module and share its attributes
    ignored = ('imported', 'module', 'use_name')
    for source_mod, imports in imports_by_module:
        for local_name, use_name in imports.items():
            attrs = mod3.symbol_attrs[local_name]
            assert attrs.imported
            assert attrs.module is source_mod
            assert attrs.use_name == use_name
            assert attrs.compare(source_mod.symbol_attrs[use_name or local_name], ignore=ignored)

    # The unrestricted import keeps a rename list, the `only` imports keep symbols
    for imprt in FindNodes(ir.Import).visit(mod3.spec):
        if imprt.module != 'test_rename_mod':
            assert not imprt.rename_list
            assert imprt.symbols
            continue
        assert imprt.rename_list
        assert not imprt.symbols
        renames = dict(imprt.rename_list)
        assert 'var1' in renames
        assert 'var3' in renames

    # Every rename shows up in the generated code
    fcode = fgen(mod3)
    for _, imports in imports_by_module:
        for local_name, use_name in imports.items():
            assert use_name is None or f'{local_name} => {use_name}' in fcode


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_rename_imports_no_definitions(frontend, tmp_path):
    """
    Check ``USE`` statements with rename lists when the definitions of the
    used modules are not provided. Only explicitly renamed or ``only``-listed
    symbols become known, with their ``use_name`` recorded but without a
    reference to the source module.
    """
    fcode_mod1 = """
module test_rename_mod
    implicit none
    integer :: var1
    integer :: var2
    integer :: var3
end module test_rename_mod
    """.strip()

    fcode_mod2 = """
module test_other_rename_mod
    implicit none
    integer :: var1
    integer :: var2
    integer :: var3
end module test_other_rename_mod
    """.strip()

    # The used modules are parsed but deliberately not passed as definitions.
    # NOTE(review): presumably needed so that OMNI finds their module
    # information in tmp_path - confirm.
    _ = Module.from_source(fcode_mod1, frontend=frontend, xmods=[tmp_path])
    _ = Module.from_source(fcode_mod2, frontend=frontend, xmods=[tmp_path])

    fcode_mod3 = """
module some_mod
    use test_rename_mod, first_var1 => var1, first_var3 => var3
    use test_other_rename_mod, only: second_var1 => var1
    use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3
    implicit none
end module some_mod
    """.strip()

    mod3 = Module.from_source(fcode_mod3, frontend=frontend, xmods=[tmp_path])

    # Local name -> name in the used module
    mod1_imports = {'first_var1': 'var1', 'first_var3': 'var3'}
    mod2_imports = {'second_var1': 'var1', 'other_var2': 'var2', 'other_var3': 'var3'}
    all_imports = (mod1_imports, mod2_imports)

    # Every local name is registered in the symbol table ...
    for local_name in (*mod1_imports, *mod2_imports):
        assert local_name in mod3.symbol_attrs

    # ... but neither `var1` nor `var2` has been imported under its original name
    for original_name in ('var1', 'var2'):
        assert original_name not in mod3.symbol_attrs

    # Without definitions, imported symbols carry no module reference
    for imports in all_imports:
        for local_name, use_name in imports.items():
            attrs = mod3.symbol_attrs[local_name]
            assert attrs.imported
            assert attrs.module is None
            assert attrs.use_name == use_name

    # The unrestricted import keeps a rename list, the `only` imports keep symbols
    for imprt in FindNodes(ir.Import).visit(mod3.spec):
        if imprt.module != 'test_rename_mod':
            assert not imprt.rename_list
            assert imprt.symbols
            continue
        assert imprt.rename_list
        assert not imprt.symbols
        renames = dict(imprt.rename_list)
        assert 'var1' in renames
        assert 'var3' in renames

    # Every rename shows up in the generated code
    fcode = fgen(mod3)
    for imports in all_imports:
        for local_name, use_name in imports.items():
            assert use_name is None or f'{local_name} => {use_name}' in fcode


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_use_module_nature(frontend, tmp_path):
    """
    Test module-nature attributes in ``USE`` statements.

    A user-defined ``iso_fortran_env`` module is imported once with
    ``non_intrinsic`` nature and once with ``intrinsic`` nature. That module
    sets ``int8`` to the 16-bit C integer kind. Only the non-intrinsic import
    may be linked to the user module. After JIT compilation, the returned
    kinds must show which module each routine actually used.
    """
    mcode = """
module iso_fortran_env
    use, intrinsic :: iso_c_binding, only: int16 => c_int16_t
    implicit none
    integer, parameter :: int8 = int16
end module iso_fortran_env
    """.strip()

    fcode = """
module module_nature_mod
    implicit none
contains
    subroutine inquire_my_kinds(i8, i16)
        use, non_intrinsic :: iso_fortran_env, only: int8, int16
        integer, intent(out) :: i8, i16
        i8 = int8
        i16 = int16
    end subroutine inquire_my_kinds
    subroutine inquire_kinds(i8, i16)
        use, intrinsic :: iso_fortran_env, only: int8, int16
        integer, intent(out) :: i8, i16
        i8 = int8
        i16 = int16
    end subroutine inquire_kinds
end module module_nature_mod
    """.strip()

    ext_mod = Module.from_source(mcode, frontend=frontend, xmods=[tmp_path])

    # Check properties on the Import IR node in the external module
    assert ext_mod.imported_symbols == ('int16',)
    imprt = FindNodes(ir.Import).visit(ext_mod.spec)[0]
    assert imprt.nature.lower() == 'intrinsic'
    assert imprt.module.lower() == 'iso_c_binding'
    assert ext_mod.imported_symbol_map['int16'].type.imported is True
    assert ext_mod.imported_symbol_map['int16'].type.module is None

    if frontend == OMNI:
        # OMNI throws Syntax Error on NON_INTRINSIC...
        fcode = fcode.replace('use, non_intrinsic ::', 'use')

    mod = Module.from_source(fcode, frontend=frontend, definitions=[ext_mod], xmods=[tmp_path])

    # Check properties on the Import IR node in both routines
    my_kinds = mod['inquire_my_kinds']
    kinds = mod['inquire_kinds']

    assert set(my_kinds.imported_symbols) == {'int8', 'int16'}
    assert set(kinds.imported_symbols) == {'int8', 'int16'}

    my_import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(my_kinds.spec) for s in imprt.symbols}
    import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(kinds.spec) for s in imprt.symbols}

    assert my_import_map['int8'] is my_import_map['int16']
    assert import_map['int8'] is import_map['int16']

    if frontend == OMNI:
        assert my_import_map['int8'].nature is None
    else:
        assert my_import_map['int8'].nature.lower() == 'non_intrinsic'
    assert my_import_map['int8'].module.lower() == 'iso_fortran_env'
    assert import_map['int8'].nature.lower() == 'intrinsic'
    assert import_map['int8'].module.lower() == 'iso_fortran_env'

    # Check type annotations for imported symbols
    assert all(s.type.imported is True for s in my_kinds.imported_symbols)
    assert all(s.type.imported is True for s in kinds.imported_symbols)

    # Only the non-intrinsic import is linked to the user-provided module
    assert my_kinds.imported_symbol_map['int8'].type.module is ext_mod
    assert my_kinds.imported_symbol_map['int16'].type.module is ext_mod

    assert kinds.imported_symbol_map['int8'].type.module is None
    assert kinds.imported_symbol_map['int16'].type.module is None

    # Sanity check fgen
    assert 'use, intrinsic' in ext_mod.to_fortran().lower()
    if frontend != OMNI:
        assert 'use, non_intrinsic' in my_kinds.to_fortran().lower()
    assert 'use, intrinsic' in kinds.to_fortran().lower()

    # Verify JIT compile
    # NOTE(review): the generated sources are written explicitly, although
    # jit_compile_lib receives the Module objects; confirm whether these
    # writes are still required.
    for _mod in (ext_mod, mod):
        (tmp_path/f'{_mod.name}.f90').write_text(_mod.to_fortran())
    lib = jit_compile_lib([ext_mod, mod], path=tmp_path, name=mod.name)
    my_kinds_func = lib.module_nature_mod.inquire_my_kinds
    kinds_func = lib.module_nature_mod.inquire_kinds

    my_i8, my_i16 = my_kinds_func()
    i8, i16 = kinds_func()

    # The user module defines int8 = int16 (= c_int16_t). The intrinsic module
    # provides distinct kinds, and its int16 matches the user module's int8.
    assert my_i8 == my_i16
    assert i8 < i16
    assert my_i8 == i16
    assert my_i8 == lib.iso_fortran_env.int8


@pytest.mark.parametrize('spec,part_lengths', [
    ('', (0, 0, 0)),
    ("""
implicit none
integer :: var1
integer :: var2
integer :: var3
    """.strip(), (0, 1, 3)),
    ("""
use header_mod
implicit none
integer :: var1
    """.strip(), (1, 1, 1)),
    ("""
use header_mod
integer :: var1
    """.strip(), (1, 0, 1)),
])
@pytest.mark.parametrize('frontend', available_frontends())
def test_module_spec_parts(frontend, spec, part_lengths, tmp_path):
    """
    Check that :attr:`Module.spec_parts` splits the specification into a tuple
    of three tuples (imports, implicit statements, declarations) whose lengths
    match the expected ``part_lengths``.
    """

    header_source = """
module header_mod
    implicit none
    integer, parameter :: param1 = 1
end module header_mod
    """.strip()
    header_mod = Module.from_source(header_source, frontend=frontend, xmods=[tmp_path])

    is_omni = frontend == OMNI
    # OMNI moves use statements ahead of any leading comment, so the docstring
    # cannot be reliably recovered there; only other frontends get one.
    doc_prefix = '' if is_omni else '! This should become the doc string\n'
    fcode = f"""
module spec_parts
{doc_prefix}{spec}
end module spec_parts
    """.strip()

    module = Module.from_source(fcode, definitions=header_mod, frontend=frontend, xmods=[tmp_path])
    assert isinstance(module.spec_parts, tuple)
    assert all(isinstance(part, tuple) for part in module.spec_parts)

    expected_lengths = part_lengths
    if is_omni:
        # OMNI drops IMPLICIT statements, so the middle part is always empty
        expected_lengths = (part_lengths[0], 0, part_lengths[2])
    else:
        # The leading comment must have been picked up as a one-node docstring
        assert isinstance(module.docstring, tuple)
        assert len(module.docstring) == 1

    assert expected_lengths == tuple(len(part) for part in module.spec_parts)


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_comparison(frontend, tmp_path):
    """
    Test that string-equivalence works on relevant components.

    Two modules parsed separately from the same source must compare equal,
    both component-wise (symbol table, spec, contains) and as a whole.
    """

    fcode = """
module a_module
  integer, parameter :: x = 2
  integer, parameter :: y = 3

  type derived_type
    real :: array(x, y)
  end type derived_type
contains

  subroutine my_routine(pt)
    type(derived_type) :: pt
    pt%array(:,:) = 42.0
  end subroutine my_routine
end module a_module
"""

    # Two distinct, but string-equivalent module objects
    m1 = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    m2 = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    assert m1 is not m2

    # Compare components first so that a failure points at the culprit
    assert m1.symbol_attrs == m2.symbol_attrs
    assert m1.spec == m2.spec
    assert m1.contains == m2.contains
    assert m1 == m2


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_comparison_case_sensitive(frontend, tmp_path):
    """
    Test that semantically equivalent, but not string-equivalent modules
    (differing only in the case of an identifier) do not compare equal.
    """

    fcode = """
module a_module
  integer, parameter :: x = 2
  integer, parameter :: y = 3

  type derived_type
    real :: array(x, y)
  end type derived_type
contains

  subroutine my_routine(pt)
    type(derived_type) :: pt
    pt%array(:,:) = 42.0
  end subroutine my_routine
end module a_module
"""

    # Two module objects that differ only in the case of ``pt%array``
    m1 = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    m2 = Module.from_source(fcode.replace('pt%array', 'pT%aRrAy'), frontend=frontend, xmods=[tmp_path])

    assert 'pT%aRrAy' not in fgen(m1)
    if frontend != OMNI:  # OMNI always downcases!
        assert 'pT%aRrAy' in fgen(m2)

    # ``not ... == ...`` is used deliberately below to exercise ``__eq__``.
    # Since the routine is different, its procedure type will be too!
    assert not m1.symbol_attrs == m2.symbol_attrs
    # With OMNI, the source file paths attached to (and compared for) each
    # source node object are affected by the string change
    if frontend != OMNI:
        assert m1.spec == m2.spec
    assert not m1.contains == m2.contains
    assert not m1 == m2


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_contains_auto_insert(frontend, tmp_path):
    """
    Check that cloning a :any:`ProgramUnit` with a ``contains`` argument
    yields a ``contains`` :any:`Section` that opens with a ``CONTAINS``
    intrinsic, for both subroutines and modules.
    """

    def assert_contains_keyword(unit):
        # The section must exist and its first node must be the CONTAINS keyword
        section = unit.contains
        assert isinstance(section, ir.Section)
        leading = section.body[0]
        assert isinstance(leading, ir.Intrinsic)
        assert leading.text == 'CONTAINS'

    fcode_mod = """
module empty_mod
    implicit none
end module empty_mod
    """.strip()
    fcode_routine1 = """
subroutine routine1
end subroutine routine1
    """.strip()
    fcode_routine2 = """
subroutine routine2
end subroutine routine2
    """.strip()

    module = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    routine1 = Subroutine.from_source(fcode_routine1, frontend=frontend, xmods=[tmp_path])
    routine2 = Subroutine.from_source(fcode_routine2, frontend=frontend, xmods=[tmp_path])

    # Neither unit has a contains section to begin with
    assert module.contains is None
    assert routine1.contains is None

    # Nest routine2 inside routine1 ...
    routine1 = routine1.clone(contains=routine2)
    assert_contains_keyword(routine1)

    # ... and routine1 inside the module
    module = module.clone(contains=routine1)
    assert_contains_keyword(module)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('only_list', [True, False])
@pytest.mark.parametrize('complete_tree', [True, False])
def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_path):
    """
    Check symbols re-exported through an intermediate module (mod1 -> mod2 -> driver).

    When mod1 is available to mod2 (``complete_tree``), ``a`` and ``b`` resolve to
    integer scalars; otherwise they remain deferred-type symbols. In both cases they
    are flagged as imported and attributed to mod2.
    """
    fcode_mod1 = """
module mod1
    implicit none
    integer, parameter :: a = 1, b=2
end module mod1
    """.strip()

    only_clause = ', only: a, b' if only_list else ''
    fcode_mod2 = f"""
module mod2
    use mod1{only_clause}
    implicit none
end module mod2
    """.strip()

    fcode_driver = """
subroutine driver
    use mod2, only: a, b
    implicit none
    integer c
    c = a + b
end subroutine driver
    """.strip()

    # mod1 is always parsed, but only offered as a definition for the complete tree
    mod1 = Module.from_source(fcode_mod1, frontend=frontend, xmods=[tmp_path])
    modules = [mod1] if complete_tree else []
    mod2 = Module.from_source(fcode_mod2, frontend=frontend, definitions=modules, xmods=[tmp_path])
    modules.append(mod2)
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=modules, xmods=[tmp_path])

    if complete_tree:
        expected_class, expected_dtype = sym.Scalar, BasicType.INTEGER
    else:
        expected_class, expected_dtype = sym.DeferredTypeSymbol, BasicType.DEFERRED

    for name in ('a', 'b'):
        var = driver.symbol_map[name]
        assert isinstance(var, expected_class)
        assert var.type.dtype is expected_dtype
        assert var.type.imported
        # The symbol is attributed to the module it was directly imported from
        assert var.type.module is mod2


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_all_imports(frontend, tmp_path):
    """
    Check :attr:`all_imports` and import resolution for a routine whose local
    imports shadow (and rename) the imports of its enclosing module.
    """
    fcode_header_a = """
module module_all_imports_header_a_mod
implicit none

integer, parameter :: a = 1
integer, parameter :: b = 2
end module module_all_imports_header_a_mod
""".strip()

    fcode_header_b = """
module module_all_imports_header_b_mod
implicit none

integer, parameter :: a = 2
integer, parameter :: b = 1
end module module_all_imports_header_b_mod
""".strip()

    fcode_routine = """
module module_all_imports_routine_mod
    use module_all_imports_header_a_mod, only: a
    use module_all_imports_header_b_mod, only: b_b => b
    implicit none
contains
    subroutine routine
        use module_all_imports_header_a_mod, only: b
        use module_all_imports_header_b_mod, only: a
        implicit none
        integer val
        val = a + b + b_b
    end subroutine routine
end module module_all_imports_routine_mod
""".strip()

    header_a, header_b = (
        Module.from_source(header_source, frontend=frontend, xmods=[tmp_path])
        for header_source in (fcode_header_a, fcode_header_b)
    )
    routine_mod = Module.from_source(
        fcode_routine, definitions=(header_a, header_b), frontend=frontend, xmods=[tmp_path]
    )
    routine = routine_mod['routine']

    # The module is the outermost scope, the routine sits directly inside it
    assert routine_mod.parents == ()
    assert routine.parents == (routine_mod,)

    # all_imports lists the unit's own imports followed by those of its parents
    assert routine_mod.all_imports == routine_mod.imports
    assert routine.all_imports == routine.imports + routine_mod.imports

    # Routine-local imports take precedence over the module-level ones
    routine_symbols = routine.symbol_map
    module_symbols = routine_mod.symbol_map
    assert routine_symbols['a'].type.module is header_b
    assert module_symbols['a'].type.module is header_a
    assert routine_symbols['b'].type.module is header_a
    # Renamed import keeps track of the original name in the source module
    assert module_symbols['b_b'].type.module is header_b
    assert module_symbols['b_b'].type.use_name == 'b'


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_enrichment_within_file(frontend, tmp_path):
    """
    Check that a module defined earlier in the same source file enriches the
    imported function and parameter symbols used by a routine in a later module.
    """
    fcode = """
module foo
  implicit none
  integer, parameter :: j = 16

  contains
    integer function plus_one(v)
      implicit none
      integer, intent(in) :: v
      plus_one = v + 1
    end function plus_one
end module foo

module test
    use foo
    implicit none
    integer, parameter :: rk = selected_real_kind(12)
    integer, parameter :: ik = selected_int_kind(9)
contains
    subroutine calc (n, res)
        integer, intent(in) :: n
        real(kind=rk), intent(inout) :: res
        integer(kind=ik) :: i
        do i = 1, n
            res = res + plus_one(j)
        end do
    end subroutine calc
end module test
"""

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = source['calc']

    calls = list(FindInlineCalls().visit(routine.body))
    assert len(calls) == 1

    # The called function is resolved to the procedure defined in module foo
    function = calls[0].function
    assert function == 'plus_one'
    assert function.type.imported
    assert function.type.module is source['foo']
    assert function.type.dtype.procedure is source['plus_one']

    if frontend == OMNI:
        # OMNI inlines parameters, so there is no symbol ``j`` to inspect
        return

    # The argument ``j`` carries the full parameter declaration from module foo
    argument_type = calls[0].arguments[0].type
    assert argument_type.dtype == BasicType.INTEGER
    assert argument_type.imported
    assert argument_type.parameter
    assert argument_type.initial == 16
    assert argument_type.module is source['foo']


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_enrichment_typedefs(frontend, tmp_path):
    """
    Check that enriching a module with the definition of an imported derived
    type propagates the typedef into the module's member routines.
    """

    fcode_state_mod = """
module state_type_mod
  implicit none

  type state_type
    real, pointer, dimension(:,:) :: a
  end type state_type

end module state_type_mod
"""

    fcode_driver_mod = """
module driver_mod
  use state_type_mod, only: state_type
  implicit none

contains
  subroutine driver_routine(state)
    type(state_type), intent(inout) :: state

    state%a = 1

  end subroutine driver_routine
end module driver_mod
"""

    def parse_module(fcode, name):
        # Each module is parsed from its own source file, without definitions
        return Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])[name]

    state_mod = parse_module(fcode_state_mod, 'state_type_mod')
    driver_mod = parse_module(fcode_driver_mod, 'driver_mod')
    driver = driver_mod['driver_routine']

    # Before enrichment, the derived type has no definition attached
    state_dtype = driver.variable_map['state'].type.dtype
    assert isinstance(state_dtype, DerivedType)
    assert state_dtype.typedef == BasicType.DEFERRED

    # Enrich the outer module's import and recurse into its member routines
    driver_mod.enrich([state_mod], recurse=True)

    # The inner routine now sees the full type definition
    state_dtype = driver.variable_map['state'].type.dtype
    assert isinstance(state_dtype, DerivedType)
    assert isinstance(state_dtype.typedef, ir.TypeDef)

    # The component access in the body carries the component's type information
    assignments = FindNodes(ir.Assignment).visit(driver.body)
    assert len(assignments) == 1
    lhs = assignments[0].lhs
    assert lhs.type.dtype == BasicType.REAL
    assert lhs.type.shape == (':', ':')
    assert isinstance(lhs, sym.Array)
    assert lhs.parent.type.dtype.typedef == state_mod['state_type']
loki-ecmwf-0.3.6/loki/tests/test_pickle.py0000664000175000017500000002304615167130205020726 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
A set of tests that ensures that serialisation/deserialisation via
pickle works and creates equivalent objects of various types.
"""
from pathlib import Path
from pickle import dumps, loads
import pytest

from loki import (
    Subroutine, Module, Sourcefile, SymbolAttributes, BasicType,
    Scope, AttachScopes
)
from loki.batch import Item
from loki.expression import symbols
from loki.frontend import available_frontends, OMNI


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory containing this test file, used to locate test source files."""
    test_dir = Path(__file__).parent
    return test_dir


def test_pickle_expression():
    """
    Ensure pickle-replication of Pymbolic-backed expressions.

    Covers plain variables, variables bound to a :any:`Scope` (re-attached via
    :any:`AttachScopes` after unpickling), array variables and literals.
    """
    # pylint: disable=no-member

    # Ensure basic variable components are picklable
    t = SymbolAttributes(BasicType.INTEGER)
    v1 = symbols.Variable(name='v1', type=t)
    assert v1.symbol == loads(dumps(v1.symbol))
    assert v1.type == loads(dumps(v1.type))
    assert v1 == loads(dumps(v1))

    # Now we add a scope to the expression and replicate both
    scope = Scope()
    v2 = symbols.Variable(name='v2', scope=scope, type=t)
    scope_new = loads(dumps(scope))
    v2_new = loads(dumps(v2))

    # Re-attach the new expression to the new scope
    v2_new = AttachScopes().visit(v2_new, scope=scope_new)

    assert len(scope_new.symbol_attrs) == 1
    assert 'v2' in scope_new.symbol_attrs
    assert scope_new.symbol_attrs['v2'] == t
    assert v2_new == v2

    # And now, one more time but with arrays!
    scope = Scope()
    i = symbols.Variable(name='i', scope=scope, type=t)
    v3 = symbols.Variable(name='v3', dimensions=(i,), scope=scope, type=t)
    scope_new = loads(dumps(scope))
    v3_new = loads(dumps(v3))
    v3_new = AttachScopes().visit(v3_new, scope=scope_new)

    assert len(scope_new.symbol_attrs) == 2
    assert 'v3' in scope_new.symbol_attrs
    assert 'i' in scope_new.symbol_attrs
    assert scope_new.symbol_attrs['v3'] == t
    assert v3_new == v3

    # Check that literals are trivially replicated; an integer literal is
    # built from a genuine integer value rather than the float ``1.``
    literal = symbols.IntLiteral(value=1, kind='jpim')
    assert loads(dumps(literal)) == literal


@pytest.mark.parametrize('frontend', available_frontends())
def test_pickle_subroutine(frontend):
    """
    Check that a :any:`Subroutine` and its individual components (symbol table,
    spec and body) survive a pickle round-trip unchanged.
    """

    fcode = """
subroutine my_routine(n, a, b, d)
  integer, intent(in) :: n
  real, intent(in) :: a(n), b(n)
  real, intent(out) :: d(n)
  integer :: i

  do i=1, n
    d(i) = a(i) + b(i)
  end do
end subroutine my_routine
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def roundtrip(obj):
        # Serialise and immediately deserialise an object
        return loads(dumps(obj))

    # Rebuild the symbol table in a fresh scope from a pickled copy
    scope_new = Scope()
    scope_new.symbol_attrs.update(roundtrip(routine.symbol_attrs))

    # Round-trip spec and body separately, re-attaching each to the fresh scope
    for section in (routine.spec, routine.body):
        section_new = AttachScopes().visit(roundtrip(section), scope=scope_new)
        assert section_new == section

    # Finally, the whole subroutine must survive a pickle round-trip
    assert routine == roundtrip(routine)


@pytest.mark.parametrize('frontend', available_frontends())
def test_pickle_module(frontend, tmp_path):
    """
    Check that a simple :any:`Module` and its components (symbol table, spec,
    contains) survive a pickle round-trip unchanged.
    """

    fcode = """
module my_type_mod

  real(8) :: a, b
  integer :: s

end module my_type_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Components first, so a failure points at the culprit, then the module itself
    for component in (module.symbol_attrs, module.spec, module.contains, module):
        assert component == loads(dumps(component))


@pytest.mark.parametrize('frontend', available_frontends())
def test_pickle_module_with_typedef(frontend, tmp_path):
    """
    Check that a :any:`Module` containing a derived type definition and a
    variable of that type survives a pickle round-trip, both piecewise and
    as a whole.
    """

    fcode = """
module my_type_mod

  type a_type
    real(kind=8) :: scalar
    real(kind=8) :: vector(3)
  end type a_type

  type(a_type) :: some_numbers

end module my_type_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    def roundtrip(obj):
        # Serialise and immediately deserialise an object
        return loads(dumps(obj))

    # The TypeDef node on its own
    typedef = module['a_type']
    assert roundtrip(typedef) == typedef

    # The symbol table, copied into a fresh scope
    scope_new = Scope()
    scope_new.symbol_attrs.update(roundtrip(module.symbol_attrs))

    # The spec and contains sections, each re-attached to the fresh scope
    for section in (module.spec, module.contains):
        section_new = AttachScopes().visit(roundtrip(section), scope=scope_new)
        assert section_new == section

    # Plain round-trips of the components and of the module itself
    for component in (module.symbol_attrs, module.spec, module.contains, module):
        assert component == roundtrip(component)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'No external module available')]))
def test_pickle_subroutine_with_member(frontend):
    """
    Ensure that a :any:`Subroutine` with an import and a member routine,
    as well as its individual components, is picklable.
    """

    fcode = """
subroutine my_routine(n, a, b, d)
  use another_module, only: some_routine

  integer, intent(in) :: n
  real, intent(in) :: a(n), b(n)
  real, intent(out) :: d(n)
  integer :: i

  call member_routine(a, b)

  contains

  subroutine member_routine(n, a, b)
    integer, intent(in) :: n
    real, intent(in) :: a(n), b(n)

  end subroutine member_routine
end subroutine my_routine
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # First, replicate the scope individually, ...
    scope_new = Scope()
    scope_new.symbol_attrs.update(loads(dumps(routine.symbol_attrs)))

    # Replicate spec and body independently...
    spec_new = loads(dumps(routine.spec))
    spec_new = AttachScopes().visit(spec_new, scope=scope_new)
    assert spec_new == routine.spec

    body_new = loads(dumps(routine.body))
    body_new = AttachScopes().visit(body_new, scope=scope_new)
    assert body_new == routine.body

    # Replicate the member routine; keep the result of AttachScopes, as for
    # spec and body above, so that the re-attached section is what gets compared
    contains_new = loads(dumps(routine.contains))
    contains_new = AttachScopes().visit(contains_new, scope=scope_new)
    assert contains_new == routine.contains

    # Ensure equivalence after pickle-cycle with scope-level replication
    routine_new = loads(dumps(routine))
    assert routine_new.spec == routine.spec
    assert routine_new.body == routine.body
    assert routine_new.contains == routine.contains
    assert routine_new.symbol_attrs == routine.symbol_attrs
    assert routine_new == routine


@pytest.mark.parametrize('frontend', available_frontends())
def test_pickle_module_with_routines(frontend, tmp_path):
    """
    Ensure that a :any:`Module` object with cross-calling subroutines
    pickles cleanly, including the procedure type symbols.

    The symbol table, spec and contains sections are first replicated
    individually and re-attached to a fresh :any:`Scope`; afterwards the
    whole module is round-tripped and compared component-wise.
    """

    fcode = """
module my_module
  implicit none

  contains
  subroutine my_routine(n, a, b, d)
    integer, intent(in) :: n
    real, intent(in) :: a(n), b(n)
    real, intent(out) :: d(n)
    integer :: i

    call other_routine(a, b)
  end subroutine my_routine

  subroutine other_routine(n, a, b)
    integer, intent(in) :: n
    real, intent(in) :: a(n), b(n)

  end subroutine other_routine
end module my_module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # First, replicate the scope individually, ...
    scope_new = Scope()
    scope_new.symbol_attrs.update(loads(dumps(module.symbol_attrs)))

    # Replicate the spec section independently...
    spec_new = loads(dumps(module.spec))
    spec_new = AttachScopes().visit(spec_new, scope=scope_new)
    assert spec_new == module.spec

    contains_new = loads(dumps(module.contains))
    # We need to attach the parent here first, so that the deferred
    # procedure type symbol in the call can be resolved; this must happen
    # before AttachScopes is applied below.
    # NOTE(review): body[1] and body[-1] are presumably the two member
    # routines (my_routine, other_routine), with body[0] being the CONTAINS
    # statement — confirm.
    contains_new.body[1]._reset_parent(scope_new)
    contains_new.body[-1]._reset_parent(scope_new)
    contains_new = AttachScopes().visit(contains_new, scope=scope_new)
    assert contains_new == module.contains

    # Ensure equivalence after a full pickle cycle of the module
    module_new = loads(dumps(module))
    assert module_new.spec == module.spec
    assert module_new.contains == module.contains
    assert module_new.symbol_attrs == module.symbol_attrs
    assert module_new == module


@pytest.mark.parametrize('frontend', available_frontends())
def test_pickle_scheduler_item(here, frontend, tmp_path):
    """
    Check that scheduler :any:`Item` objects (and the source file they wrap)
    survive a pickle round-trip, so that they can be shipped to parallel
    worker processes.
    """
    source_path = here/'sources/sourcefile_item.f90'
    source = Sourcefile.from_file(filename=source_path, frontend=frontend, xmods=[tmp_path])
    item_a = Item(name='#routine_a', source=source)

    def roundtrip(obj):
        # Serialise and immediately deserialise an object
        return loads(dumps(obj))

    # Every top-level node (routine or module) in the parsed file individually
    for node in item_a.source.ir.body:
        assert roundtrip(node) == node

    # Then progressively larger containers: the IR, the source file, the item
    for container in (item_a.source.ir, item_a.source, item_a):
        assert roundtrip(container) == container
loki-ecmwf-0.3.6/loki/tests/test_source_identity.py0000664000175000017500000002710315167130205022666 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Test identity of source-to-source translation.

The tests in here rarely verify the internal representation; they mostly
check that what comes out at the end matches what went in at the beginning.

"""
import pytest

from loki import Sourcefile, Subroutine
from loki.backend import fgen, FortranStyle
from loki.jit_build import clean_test
from loki.ir import nodes as ir, FindNodes
from loki.frontend import available_frontends, OMNI


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI stores no source.string')]))
def test_raw_source_loop(tmp_path, frontend):
    """Verify that the raw_source property is correctly used to annotate
    AST nodes with source strings for loops.

    The routine contains a named ``do`` construct (``outer``) enclosing a
    labelled ``do while`` loop (label ``6``). Line numbers used below are
    1-based line numbers within the stripped ``fcode``.
    """
    fcode = """
subroutine routine_raw_source_loop (ia, ib, ic)
integer, intent(in) :: ia, ib, ic

outer: do ia=1,10
  ib = ia
  do 6 while (ib .lt. 20)
    ic = ib
    if (ic .gt. 10) then
      print *, ic
    else
      print *, ib
    end if
6 end do
end do outer
end subroutine routine_raw_source_loop
    """.strip()
    filename = tmp_path / (f'routine_raw_source_loop_{frontend}.f90')
    Sourcefile.to_file(fcode, filename)

    # Parse from file so that the full source string is attached
    source = Sourcefile.from_file(filename, frontend=frontend)
    routine = source['routine_raw_source_loop']
    assert source.source.string.strip() == fcode
    assert routine.source.string.strip() == fcode

    # From here on, fcode is a list of lines for line-based lookups
    fcode = fcode.splitlines()
    # NOTE(review): the file's line range presumably extends one past the last
    # code line because of a trailing newline in the written file — confirm
    assert source.source.lines == (1, len(fcode) + 1)
    assert routine.source.lines == (1, len(fcode))

    # Check the intrinsics: the two PRINT statements on lines 9 and 11
    intrinsic_lines = (9, 11)
    for node in FindNodes(ir.Intrinsic).visit(routine.body):
        # Verify that source string is subset of the relevant line in the original source
        assert node.source is not None
        assert node.source.lines in ((l, l) for l in intrinsic_lines)
        assert node.source.string in fcode[node.source.lines[0]-1]

    # Check the do loops: the named outer loop (lines 4-14) and the labelled
    # inner while loop (lines 6-13)
    do_construct_name_found = 0  # Note: this is the construct name 'outer'
    loop_label_found = 0  # Note: this is the do label '6'
    do_lines = ((4, 14), (6, 13))
    for node in FindNodes((ir.Loop, ir.WhileLoop)).visit(routine.ir):
        # Verify that source string is subset of the relevant line in the original source
        assert node.source is not None
        assert node.source.lines in do_lines
        assert node.source.string in ('\n'.join(fcode[start-1:end]) for start, end in do_lines)
        # Make sure the labels and names are correctly identified and contained
        if node.name:
            do_construct_name_found += 1
            assert node.name == 'outer'
        if node.loop_label:
            loop_label_found += 1
            assert node.loop_label == '6'
    assert do_construct_name_found == 1
    assert loop_label_found == 1

    # Assert output of body (lines 4-14) matches the original string, except for
    # case and the relational operators, which are expected as '<' and '>'
    ref = '\n'.join(fcode[3:-1]).replace('.lt.', '<').replace('.gt.', '>')
    assert fgen(routine.body).strip().lower() == ref

    clean_test(filename)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI stores no source.string')]))
def test_raw_source_conditional(tmp_path, frontend):
    """Verify that the raw_source property is correctly used to annotate
    AST nodes with source strings for conditionals.

    The routine contains a named IF/ELSE IF/ELSE construct (``check``) and a
    single-line IF statement. Line numbers used below are 1-based line
    numbers within the stripped ``fcode``.
    """
    fcode = """
subroutine routine_raw_source_cond (ia, ib, ic)
integer, intent(in) :: ia, ib, ic

check: if (ib > 0) then
  print *, ia
else if (ib == 0) then check
  print *, ib
else check
  print *, ic
end if check
if (ic == 1) print *, ic
end subroutine routine_raw_source_cond
    """.strip()
    filename = tmp_path / (f'routine_raw_source_cond_{frontend}.f90')
    Sourcefile.to_file(fcode, filename)

    # Parse from file so that the full source string is attached
    source = Sourcefile.from_file(filename, frontend=frontend)
    routine = source['routine_raw_source_cond']
    assert source.source.string.strip() == fcode
    assert routine.source.string.strip() == fcode

    # From here on, fcode is a list of lines for line-based lookups
    fcode = fcode.splitlines()
    assert source.source.lines == (1, len(fcode) + 1)
    assert routine.source.lines == (1, len(fcode))

    # Check the intrinsics: the PRINT statements on lines 5, 7, 9 and 11
    intrinsic_lines = (5, 7, 9, 11)
    for node in FindNodes(ir.Intrinsic).visit(routine.body):
        # Verify that source string is subset of the relevant line in the original source
        assert node.source is not None
        assert node.source.lines in ((l, l) for l in intrinsic_lines)
        assert node.source.string in fcode[node.source.lines[0]-1]

    # Check the conditionals: the named construct (lines 4-10), a second
    # conditional spanning lines 6-10 (presumably the ELSE IF branch) and the
    # single-line IF statement on line 11
    cond_name_found = 0
    cond_lines = ((4, 10), (6, 10), (11, 11))
    for node in FindNodes(ir.Conditional).visit(routine.ir):
        assert node.source is not None
        assert node.source.lines in cond_lines
        # Verify that source string is subset of the relevant lines in the original source
        assert node.source.string in ('\n'.join(fcode[start-1:end]) for start, end in cond_lines)
        if node.name:
            cond_name_found += 1
            assert node.name == 'check'
    assert cond_name_found == 1

    # Assert output of body (lines 4-11) matches original string (except for case)
    assert fgen(routine.body).strip().lower() == '\n'.join(fcode[3:-1])

    clean_test(filename)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI stores no source.string')]))
def test_raw_source_multicond(tmp_path, frontend):
    """Verify that the raw_source property is correctly used to annotate
    AST nodes with source strings for multi conditionals.

    The routine contains a named SELECT CASE construct (``multicond``) with
    three branches. Line numbers used below are 1-based line numbers within
    the stripped ``fcode``.
    """
    fcode = """
subroutine routine_raw_source_multicond (ia, ib, ic)
integer, intent(in) :: ia, ib, ic

multicond: select case (ic)
case (10) multicond
  print *, ic
case (ia) multicond
  print *, ia
case default multicond
  print *, ib
end select multicond
end subroutine routine_raw_source_multicond
    """.strip()
    filename = tmp_path / (f'routine_raw_source_multicond_{frontend}.f90')
    Sourcefile.to_file(fcode, filename)

    # Parse from file so that the full source string is attached
    source = Sourcefile.from_file(filename, frontend=frontend)
    routine = source['routine_raw_source_multicond']
    assert source.source.string.strip() == fcode
    assert routine.source.string.strip() == fcode

    # From here on, fcode is a list of lines for line-based lookups
    fcode = fcode.splitlines()
    assert source.source.lines == (1, len(fcode) + 1)
    assert routine.source.lines == (1, len(fcode))

    # Check the intrinsics: the PRINT statements on lines 6, 8 and 10
    intrinsic_lines = (6, 8, 10)
    for node in FindNodes(ir.Intrinsic).visit(routine.body):
        # Verify that source string is subset of the relevant line in the original source
        assert node.source is not None
        assert node.source.lines in ((l, l) for l in intrinsic_lines)
        assert node.source.string in fcode[node.source.lines[0]-1]

    # Check the conditional: the whole SELECT CASE construct spans lines 4-11
    cond_name_found = 0
    cond_lines = ((4, 11),)
    for node in FindNodes(ir.MultiConditional).visit(routine.ir):
        assert node.source is not None
        assert node.source.lines in cond_lines
        # Verify that source string is subset of the relevant lines in the original source
        assert node.source.string in ('\n'.join(fcode[start-1:end]) for start, end in cond_lines)
        if node.name:
            cond_name_found += 1
            assert node.name == 'multicond'
    assert cond_name_found == 1

    # Assert output of body (lines 4-11) matches original string (except for case)
    assert fgen(routine.body).strip().lower() == '\n'.join(fcode[3:-1])

    clean_test(filename)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'This is outright impossible')]))
def test_subroutine_conservative(frontend):
    """
    Check that ``fgen`` in conservative mode reproduces the original source
    string of a simple subroutine verbatim, including a comment and a pragma
    that exceed the configured line width.

    This has limitations, in particular with respect to the signature of the
    subroutine.
    """
    fcode = """
SUBROUTINE CONSERVATIVE (X, Y, SCALAR, VECTOR, MATRIX)
  IMPLICIT NONE
  INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13, 300)
  INTEGER, INTENT(IN) :: X, Y
  REAL(KIND=JPRB), INTENT(IN) :: SCALAR
  REAL(KIND=JPRB), INTENT(INOUT) :: VECTOR(X)
  REAL(KIND=JPRB), DIMENSION(X, Y), INTENT(OUT) :: MATRIX
  INTEGER :: I, SOME_VERY_LONG_INTEGERS, TO_SEE_IF_LINE_CONTUATION, IS_NOT_ENFORCED_IN_THIS_CASE
  ! Some comment that is very very long and exceeds the line width but should not be wrapped
  DO I=1, X
    VECTOR(I) = VECTOR(I) + SCALAR
!$LOKI SOME NONSENSE PRAGMA WITH VERY LONG TEXT THAT EXCEEDS THE LINE WIDTH LIMIT OF OUTPUT
    MATRIX(I, :) = I * VECTOR(I)
  ENDDO
END SUBROUTINE CONSERVATIVE
    """.strip()

    # Round-trip: parse, then regenerate conservatively with a 90-column limit
    parsed = Subroutine.from_source(fcode, frontend=frontend)
    regenerated = fgen(parsed, style=FortranStyle(linewidth=90), conservative=True)
    assert regenerated == fcode


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'This is outright impossible')]))
def test_subroutine_simple_fgen(frontend):
    """
    Check that regular (non-conservative) ``fgen`` output reproduces the
    original source string of a simple subroutine, including line
    continuation in an item list and an over-long inline comment.

    This has limitations, in particular for the formatting of expressions.
    """
    fcode = """
SUBROUTINE SIMPLE_FGEN (X, Y, SCALAR, VECTOR, MATRIX)
  IMPLICIT NONE
  INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13, 300)
  INTEGER, INTENT(IN) :: X, Y
  REAL(KIND=JPRB), INTENT(IN) :: SCALAR  ! A very long inline comment that should not be wrapped but simply appended
  REAL(KIND=JPRB), INTENT(INOUT) :: VECTOR(X)
  REAL(KIND=JPRB), INTENT(OUT), DIMENSION(X, Y) :: MATRIX
  INTEGER :: I, SOME_VERY_LONG_INTEGERS, TO_SEE_IF_LINE_CONTINUATION,  &
  & WORKS_AS_EXPECTED_IN_ITEM_LISTS
  ! Some comment that is very very long and exceeds the line width but should not be wrapped
  DO I=1,X
    VECTOR(I) = VECTOR(I) + SCALAR
!$LOKI SOME PRAGMA WITH VERY LONG TEXT JUST UNDER THE LINE WIDTH LIMIT OF OUTPUT
    MATRIX(I, :) = I*VECTOR(I)
    IF (SOME_VERY_LONG_INTEGERS > X) THEN
      ! Some comment to have more than one line
      ! in the body of the condtional
      IF (TO_SEE_IF_LINE_CONTINUATION > Y) THEN
        PRINT *, 'Intrinsic statement'
      END IF
    END IF
  END DO
END SUBROUTINE SIMPLE_FGEN
    """.strip()

    # Round-trip: parse, then regenerate with a 90-column line width
    parsed = Subroutine.from_source(fcode, frontend=frontend)
    regenerated = fgen(parsed, style=FortranStyle(linewidth=90))
    assert regenerated == fcode


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'OMNI does it for you BUT WITHOUT DELETING THE KEYWORD!!!')])
)
def test_multiline_pragma(frontend):
    """
    Check that pragmas continued over several lines with ``&`` are merged.

    Each logical pragma must become a single :any:`Pragma` node. Its content
    has the keyword repeated on continuation lines removed and the spacing
    around the continuation markers collapsed.
    """
    fcode = """
subroutine multiline_pragma
  implicit none
  integer :: dummy
!$foo some very long pragma &
!$foo with a line break
  dummy = 1
!$bar some pragma         &
!$bar with more than      &
!$bar one line break
!$bar followed by    &
!$bar    another multiline pragma &
!$bar    with same keyword
  dummy = dummy + 1
!$foobar and yet &
!$foobar another multiline pragma
end subroutine multiline_pragma
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Expected merged content for each pragma keyword; 'bar' has two separate pragmas
    expected_content = {
        'foo': ['some very long pragma with a line break'],
        'bar': [
            'some pragma with more than one line break',
            'followed by another multiline pragma with same keyword',
        ],
        'foobar': ['and yet another multiline pragma'],
    }

    found = FindNodes(ir.Pragma).visit(routine.body)
    assert len(found) == 4
    for pragma in found:
        assert pragma.content in expected_content[pragma.keyword]
loki-ecmwf-0.3.6/loki/tests/kind_map0000664000175000017500000000160115167130205017544 0ustar  alastairalastair{
     'real': {
        '': 'float',
        '4': 'float',
        'real32': 'float',
        'REAL32': 'float',
        'c_float': 'float',
        '8': 'double',
        'jprb': 'double',
        'JPRB': 'double',
        'selected_real_kind(6,37)': 'float',
        'selected_real_kind6,37': 'float',
        'selected_real_kind6, 37': 'float',
        'real32': 'float',
        'REAL32': 'float',
        'selected_real_kind(13,300)': 'double',
        'selected_real_kind13,300': 'double',
        'selected_real_kind13, 300': 'double',
        'real64': 'double',
        'REAL64': 'double',
        'c_double': 'double',
    },
    'integer': {
        '': 'int',
        '4': 'int',
        'int8': 'char',
        'INT8': 'char',
        'int32': 'int',
        'INT32': 'int',
        'selected_int_kind(9)': 'int',
        'selected_int_kind9': 'int',
        'jpim': 'int',
    },
}
loki-ecmwf-0.3.6/loki/tests/test_subroutine.py0000664000175000017500000025272715167130205021670 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# pylint: disable=too-many-lines
from pathlib import Path
import pytest

from loki import Sourcefile, Module, Subroutine, Function, fgen, fexprgen
from loki.jit_build import jit_compile, jit_compile_lib, clean_test
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI, REGEX
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, FindTypedSymbols,
    FindInlineCalls, Transformer
)
from loki.types import (
    BasicType, DerivedType, ProcedureType, SymbolAttributes
)


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test module."""
    return Path(__file__).parents[0]


@pytest.fixture(scope='module', name='header_path')
def fixture_header_path(here):
    """Path of the ``sources/header.f90`` file next to this test module."""
    return here / 'sources' / 'header.f90'


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_simple(frontend):
    """
    A simple standard looking routine to test argument declarations.

    Checks the following:
    - the argument and variable lists;
    - the docstring, whose grouping of blank lines differs between frontends;
    - the declarations in the spec;
    - the loop/assignment structure of the body.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  ! This is the docstring ...

  ! It spans multiple intersected lines ...
  ! ... and is followed by a ...

  !$loki routine fun

  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_simple
"""

    # Test the internals of the subroutine
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert routine.arguments == ('x', 'y', 'scalar', 'vector(x)', 'matrix(x, y)')
    assert routine.variables == ('jprb', 'x', 'y', 'scalar', 'vector(x)', 'matrix(x, y)', 'i')

    # Check the docstring
    assert len(routine.docstring) == 1
    assert isinstance(routine.docstring[0], ir.CommentBlock)
    comments = routine.docstring[0].comments
    if frontend == OMNI:
        # OMNI yields only the three comment lines
        assert len(comments) == 3
        assert comments[0].text == '! This is the docstring ...'
        assert comments[1].text == '! It spans multiple intersected lines ...'
        assert comments[2].text == '! ... and is followed by a ...'
    else:
        # Other frontends keep two extra entries (presumably the blank lines)
        assert len(comments) == 5
        assert comments[0].text == '! This is the docstring ...'
        assert comments[2].text == '! It spans multiple intersected lines ...'
        assert comments[3].text == '! ... and is followed by a ...'
    assert routine.definitions == ()

    # Check the spec
    assert isinstance(routine.spec, ir.Section)
    spec = routine.spec.body
    if frontend == OMNI:
        # OMNI emits a leading intrinsic statement and one declaration per symbol
        assert len(spec) == 9
        assert isinstance(spec[0], ir.Intrinsic)
        assert isinstance(spec[1], ir.Pragma)
        assert all(isinstance(n, ir.VariableDeclaration) for n in spec[2:])
        expected_symbols = (
            ('jprb',), ('x',), ('y',), ('scalar',), ('vector(x)',), ('matrix(x, y)',), ('i',)
        )
    else:
        assert len(spec) == 7
        assert isinstance(spec[0], ir.Pragma)
        assert isinstance(spec[1], ir.Comment)
        assert all(isinstance(n, ir.VariableDeclaration) for n in spec[2:])
        expected_symbols = (
            ('jprb',), ('x', 'y'), ('scalar',), ('vector(x)', 'matrix(x, y)'), ('i',)
        )
    for decl, symbols in zip(spec[2:], expected_symbols):
        assert decl.symbols == symbols

    # Check the routine body
    assert isinstance(routine.body, ir.Section)
    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1 and loops[0].variable == 'i'
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 2
    # Both assignments must be nested inside the single loop
    assert all(assign in loops[0].body for assign in assigns)


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_prefix(frontend):
    """
    Check that subroutine prefix attributes are captured in order as
    upper-case strings, and that the body following them is parsed.
    """
    fcode = """
pure elemental subroutine my_routine(x, y)
  implicit none
  integer(kind=8), intent(inout) :: x, y

  x = x + y
end subroutine my_routine
"""
    # NB: the (legal) 'recursive' prefix is avoided because fparser fails on it
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert routine.name == 'my_routine'
    prefix = routine.prefix
    assert len(prefix) == 2
    assert prefix == ('PURE', 'ELEMENTAL')

    # The routine must be parsed completely: its final node is the assignment
    final_node = routine.body.body[-1]
    assert isinstance(final_node, ir.Assignment)
    assert final_node.lhs == 'x'
    assert final_node.rhs == 'x + y'


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'OMNI frontend interface does not provide interfaces')]
))
def test_routine_bind(frontend, tmp_path):
    """
    Test matching of the 'bind' suffix for subroutines in interfaces.

    The interface declares ``my_routine`` with ``bind(C, name='my_routine_c')``.
    The binding name must be captured as a string literal and be emitted again
    by ``fgen``.
    """
    fcode = """
module my_module
  implicit none

  interface
    subroutine my_routine(x, y) bind(C, name='my_routine_c')
      use, intrinsic :: iso_c_binding
      integer(kind=c_int), intent(inout) :: x, y
    end subroutine my_routine
  end interface

contains

  subroutine my_routine(x, y)
    integer(kind=4), intent(inout) :: x, y

    x = x + y
  end subroutine my_routine
end module my_module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    routine = module['my_routine']

    intf_routine = module.interface_map['my_routine'].body[0]
    # The bind attribute is named, check the name
    assert isinstance(intf_routine.bind, sym.StringLiteral)
    assert intf_routine.bind == 'my_routine_c'
    assert "BIND(c, name='my_routine_c')" in fgen(intf_routine)
    # TODO: bind(C) is not honoured atm

    # TODO: Interface definition and module routine alias, is that intended?
    assert intf_routine == routine

    # Check that module routine was parsed completely
    assert isinstance(routine.body.body[-1], ir.Assignment)
    assert routine.body.body[-1].lhs == 'x'
    assert routine.body.body[-1].rhs == 'x + y'


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_arguments(frontend):
    """
    A set of tests for the internalisation and handling of arguments.

    The dummy-argument list is split over continuation lines interleaved with
    comments. The declarations appear in a different order from the dummy
    arguments, so argument order (from the signature) and variable order
    (from the declarations) can be told apart.
    """

    fcode = """
subroutine routine_arguments &
 ! Test multiline dummy arguments with comments
 & (x, y, scalar, &
 ! Of course, not one...
 ! but two comment lines
 & vector, matrix)
  implicit none
  integer, parameter :: jprb = selected_real_kind(13,300)
  ! The order below is intentioanlly inverted
  real(kind=jprb), intent(inout) :: matrix(x, y)
  real(kind=jprb), intent(in)    :: scalar
  real(kind=jprb), dimension(x)  :: local_vector
  real(kind=jprb), dimension(x), intent(out) :: vector
  integer, intent(in) :: x, y

  integer :: i, j
  real(kind=jprb) :: local_matrix(x, y)

  do i=1, x
     local_vector(i) = i * 10.
     do j=1, y
        local_matrix(i, j) = local_vector(i) + j * scalar
     end do
  end do

  vector(:) = local_vector(:)
  matrix(:, :) = local_matrix(:, :)

end subroutine routine_arguments
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # The comments that break up the dummy-argument list are attributed to the docstring.
    # This behaviour is not ideal, but it is the current status quo!
    if frontend == OMNI:
        assert not routine.docstring
    else:
        assert len(routine.docstring) == 1
        comments = routine.docstring[0].comments
        assert len(comments) == 3
        assert comments[0].text == '! Test multiline dummy arguments with comments'
        assert comments[1].text == '! Of course, not one...'
        assert comments[2].text == '! but two comment lines'

    # Argument order is determined by the dummies in the signature
    arguments = routine.arguments
    assert arguments == ('x', 'y', 'scalar', 'vector(x)', 'matrix(x, y)')
    assert all(isinstance(a, sym.Scalar) for a in arguments[0:3])
    assert all(a.type.intent == 'in' for a in arguments[0:3])
    assert all(isinstance(a, sym.Array) for a in arguments[3:])
    assert all(a.type.dtype == BasicType.INTEGER for a in arguments[0:2])
    assert all(a.type.dtype == BasicType.REAL for a in arguments[2:5])
    if frontend == OMNI:
        # OMNI replaces the named kind with the intrinsic call itself
        assert all(isinstance(a.type.kind, sym.InlineCall) for a in arguments[2:5])
    else:
        assert all(a.type.kind == 'jprb' for a in arguments[2:5])
    assert arguments[3].shape == ('x',)
    assert arguments[4].shape == ('x', 'y')
    assert arguments[3].type.intent == 'out'
    assert arguments[4].type.intent == 'inout'

    # Local variable order is determined by the order of the declarations
    variables = routine.variables
    assert variables == (
        'jprb', 'matrix(x, y)', 'scalar', 'local_vector(x)',
        'vector(x)', 'x', 'y', 'i', 'j', 'local_matrix(x, y)'
    )
    # 'jprb' is a parameter initialised from an intrinsic call
    assert variables[0].type.parameter
    assert isinstance(variables[0].type.initial, sym.InlineCall)
    assert variables[0].type.initial.function == 'selected_real_kind'

    # Data types of variables[1:] in declaration order
    real, integer = BasicType.REAL, BasicType.INTEGER
    expected_dtypes = (real, real, real, real, integer, integer, integer, integer, real)
    for var, dtype in zip(variables[1:], expected_dtypes):
        assert var.type.dtype == dtype

    # Shapes of the array variables, keyed by position in the variable list
    expected_shapes = {1: ('x', 'y'), 3: ('x',), 4: ('x',), 9: ('x', 'y')}
    for idx, shape in expected_shapes.items():
        assert variables[idx].shape == shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_arguments_add_remove(frontend):
    """
    Test addition and removal of subroutine arguments.

    Arguments appended to ``routine.arguments`` must be declared in the spec.
    Arguments removed from the list must still exist as variables of the
    routine.
    """
    fcode = """
subroutine routine_arguments_add_remove(x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13, 300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
end subroutine routine_arguments_add_remove
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert routine.arguments == ('x', 'y', 'scalar', 'vector(x)', 'matrix(x, y)')

    # Create a new set of symbols to add as arguments
    x = routine.variable_map['x']  # Symbol for dummy argument 'x' (looked up by name, not position)
    real_type = routine.symbol_attrs['scalar']  # Type of dummy argument 'scalar'
    a = sym.Scalar(name='a', type=real_type, scope=routine)
    b = sym.Array(name='b', dimensions=(x, ), type=real_type, scope=routine)
    c = sym.Variable(name='c', type=x.type, scope=routine)

    # Add new arguments and check that they are all in the routine spec
    routine.arguments += (a, b, c)
    routine_args = [str(arg) for arg in routine.arguments]
    assert routine_args == ['x', 'y', 'scalar', 'vector(x)', 'matrix(x, y)', 'a', 'b(x)', 'c']
    if frontend == OMNI:
        assert fgen(routine.spec).lower() == """
implicit none
integer, parameter :: jprb = selected_real_kind(13, 300)
integer, intent(in) :: x
integer, intent(in) :: y
real(kind=selected_real_kind(13, 300)), intent(in) :: scalar
real(kind=selected_real_kind(13, 300)), intent(inout) :: vector(x)
real(kind=selected_real_kind(13, 300)), intent(inout) :: matrix(x, y)
real(kind=selected_real_kind(13, 300)), intent(in) :: a
real(kind=selected_real_kind(13, 300)), intent(in) :: b(x)
integer, intent(in) :: c
""".strip().lower()
    else:
        assert fgen(routine.spec).lower() == """
integer, parameter :: jprb = selected_real_kind(13, 300)
integer, intent(in) :: x, y
real(kind=jprb), intent(in) :: scalar
real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
real(kind=jprb), intent(in) :: a
real(kind=jprb), intent(in) :: b(x)
integer, intent(in) :: c
""".strip().lower()

    # Remove every argument whose string form mentions 'x' (x, vector, matrix and b)
    routine.arguments = [arg for arg in routine.arguments if 'x' not in str(arg)]
    assert routine.arguments == ('y', 'scalar', 'a', 'c')

    # Check that removed args still exist as variables
    routine_vars = [str(var) for var in routine.variables]
    assert 'vector(x)' in routine_vars
    assert 'matrix(x, y)' in routine_vars
    assert 'b(x)' in routine_vars


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_variable_caching(frontend):
    """
    Parse two routines that reuse argument names with different types and
    shapes, one after the other. Check that the second routine picks up its
    own declarations rather than anything cached from the first.
    """
    fcode_real = """
subroutine routine_real (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_real
"""

    fcode_int = """
subroutine routine_simple_caching (x, y, scalar, vector, matrix)
  ! A simple standard looking routine to test variable caching.
  integer, parameter :: jpim = selected_int_kind(9)
  integer, intent(in) :: x, y
  ! The next two share names with `routine_simple`, but have different
  ! dimensions or types, so that we can test variable caching.
  integer(kind=jpim), intent(in) :: scalar
  integer(kind=jpim), intent(inout) :: vector(y), matrix(x, y)
  integer :: i

  do i=1, y
     vector(i) = vector(i) + scalar
     matrix(:, i) = i * vector(i)
  end do
end subroutine routine_simple_caching
"""

    # Parse sequentially (order matters for caching); each case lists the
    # source, the expected arguments and the dtype of 'scalar' and 'vector'
    cases = (
        (fcode_real, ('x', 'y', 'scalar', 'vector(x)', 'matrix(x, y)'), BasicType.REAL),
        (fcode_int, ('x', 'y', 'scalar', 'vector(y)', 'matrix(x, y)'), BasicType.INTEGER),
    )
    for source, expected_args, expected_dtype in cases:
        routine = Subroutine.from_source(source, frontend=frontend)
        assert routine.arguments == expected_args
        assert routine.arguments[2].type.dtype == expected_dtype
        assert routine.arguments[3].type.dtype == expected_dtype


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_variables_add_remove(frontend):
    """
    Exercise adding and removing variables through ``routine.variables``.

    New symbols must appear as declarations in the spec. Removing a variable
    that is also a dummy argument must drop it from the spec and from the
    argument list, while the other arguments stay untouched.
    """
    fcode = """
subroutine routine_variables_add_remove(x, y, maximum, vector)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(out) :: maximum
  real(kind=jprb), intent(inout) :: vector(x)
  real(kind=jprb) :: matrix(x, y)
end subroutine routine_variables_add_remove
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert routine.variables == ('jprb', 'x', 'y', 'maximum', 'vector(x)', 'matrix(x, y)')

    # Build three new locals: a real scalar, a real array sized by 'x', an integer scalar
    dim_x = routine.variable_map['x']
    real_attrs = SymbolAttributes('real', kind=routine.variable_map['jprb'])
    int_attrs = SymbolAttributes('integer')
    new_locals = (
        sym.Scalar(name='a', type=real_attrs, scope=routine),
        sym.Array(name='b', dimensions=(dim_x, ), type=real_attrs, scope=routine),
        sym.Variable(name='c', type=int_attrs, scope=routine),
    )

    # After appending, every new local must be declared in the spec
    routine.variables += new_locals
    if frontend == OMNI:
        # OMNI rewrites the parsed declarations (inlined kind, one symbol per
        # line, explicit 'implicit none'); the added ones keep 'jprb'
        expected_spec = """
implicit none
integer, parameter :: jprb = selected_real_kind(13, 300)
integer, intent(in) :: x
integer, intent(in) :: y
real(kind=selected_real_kind(13, 300)), intent(out) :: maximum
real(kind=selected_real_kind(13, 300)), intent(inout) :: vector(x)
real(kind=selected_real_kind(13, 300)) :: matrix(x, y)
real(kind=jprb) :: a
real(kind=jprb) :: b(x)
integer :: c
""".strip().lower()
    else:
        expected_spec = """
integer, parameter :: jprb = selected_real_kind(13, 300)
integer, intent(in) :: x, y
real(kind=jprb), intent(out) :: maximum
real(kind=jprb), intent(inout) :: vector(x)
real(kind=jprb) :: matrix(x, y)
real(kind=jprb) :: a
real(kind=jprb) :: b(x)
integer :: c
""".strip().lower()
    assert fgen(routine.spec).lower() == expected_spec

    # Drop 'maximum' and confirm it vanishes from the generated spec
    routine.variables = [var for var in routine.variables if var.name != 'maximum']
    assert 'maximum' not in fgen(routine.spec).lower()
    assert routine.variables == (
        'jprb', 'x', 'y', 'vector(x)', 'matrix(x, y)', 'a', 'b(x)', 'c'
    )
    # 'maximum' was a dummy argument: gone from the arguments, the rest intact
    assert routine.arguments == ('x', 'y', 'vector(x)')


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_variables_find(frontend):
    """
    Tests the `FindVariables` utility (not the best place to put this).

    With ``unique=False`` every occurrence in the body is reported, so the
    counts below reflect how often each symbol appears. With ``unique=True``
    each distinct symbol must be reported exactly once.
    """
    fcode = """
subroutine routine_variables_find (x, y, maximum)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(out) :: maximum
  integer :: i, j
  real(kind=jprb), dimension(x) :: vector
  real(kind=jprb) :: matrix(x, y)

  do i=1, x
     vector(i) = i * 10.
  end do
  do i=1, x
     do j=1, y
        matrix(i, j) = vector(i) + j * 2.
     end do
  end do
  maximum = matrix(x, y)
end subroutine routine_variables_find
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Expected number of occurrences of each symbol (by string form) in the body
    expected_counts = {
        'i': 6,
        'j': 3,
        'matrix(i, j)': 1,
        'matrix(x, y)': 1,
        'maximum': 1,
        'vector(i)': 2,
        'x': 3,
        'y': 2,
    }

    # Note: only the body is searched, so declarations are not counted here
    vars_all = FindVariables(unique=False).visit(routine.body)
    for name, count in expected_counts.items():
        assert sum(1 for var in vars_all if str(var) == name) == count, name

    # Searching the full IR with uniqueness enabled yields each symbol once
    vars_unique = FindVariables(unique=True).visit(routine.ir)
    for name in expected_counts:
        assert sum(1 for var in vars_unique if str(var) == name) == 1, name


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_variables_dim_shapes(frontend):
    """
    Check that matching non-trivial dimension and shape expressions against
    strings works. Covers deferred shape, explicit bounds and arithmetic in
    bounds, both on the declared arguments and on their uses in the body.
    """
    fcode = """
subroutine routine_dim_shapes(v1, v2, v3, v4, v5)
  ! Simple variable assignments with non-trivial sizes and indices
  integer, parameter :: jprb = selected_real_kind(13,300)
  real(kind=jprb), allocatable, intent(out) :: v3(:)
  real(kind=jprb), intent(out) :: v4(v1,v2), v5(0:v1,v2-1)
  integer, intent(in) :: v1, v2

  allocate(v3(v1))
  v3(v1-v2+1) = 1.
  v4(3:v1,1:v2-3) = 2.
  v5(:,:) = 3.

end subroutine routine_dim_shapes
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert routine.arguments == ('v1', 'v2', 'v3(:)', 'v4(v1, v2)', 'v5(0:v1, v2 - 1)')

    # Shapes of the array dummy arguments, as rendered by fexprgen
    arg_shapes = [
        fexprgen(arg.shape) for arg in routine.arguments
        if isinstance(arg, sym.Array)
    ]
    assert arg_shapes == ['(:,)', '(v1, v2)', '(0:v1, v2 - 1)']

    # Every spec symbol (bar the kind intrinsic), dimension symbols included,
    # must be a scalar or array scoped to this routine
    spec_vars = [
        var for var in FindVariables(unique=False).visit(routine.spec)
        if var.name.lower() != 'selected_real_kind'
    ]
    assert all(var.scope == routine for var in spec_vars)
    assert all(isinstance(var, (sym.Scalar, sym.Array)) for var in spec_vars)

    # Array uses in the body carry the declared shapes
    body_arrays = [
        var for var in FindVariables(unique=False).visit(routine.body)
        if isinstance(var, sym.Array)
    ]
    assert [fexprgen(var.shape) for var in body_arrays] == ['(:,)', '(:,)', '(v1, v2)', '(0:v1, v2 - 1)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_variables_shape_propagation(tmp_path, header_path, frontend):
    """
    Test for the correct identification and forward propagation of variable shapes
    from the subroutine declaration.

    Covers two cases:
    - plain array arguments declared in the routine itself;
    - members of a derived type imported from the external ``header`` module.
    """

    # Parse simple kernel routine to check plain array arguments
    routine = Subroutine.from_source(frontend=frontend, source="""
subroutine routine_shape(x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_shape
""")

    # Check shapes on the internalized argument list.
    # TODO: Shapes are compared via their string form here because they are actually
    # `RangeIndex(upper=Scalar)` objects, instead of the raw dimension variables.
    # This needs some more thorough conceptualisation of dimensions and indices!
    assert fexprgen(routine.arguments[3].shape) == '(x,)'
    assert fexprgen(routine.arguments[4].shape) == '(x, y)'

    # Verify that all variable instances have type and shape information
    variables = FindVariables().visit(routine.body)
    assert all(v.shape is not None for v in variables if isinstance(v, sym.Array))

    vmap = {v.name: v for v in variables}
    assert fexprgen(vmap['vector'].shape) == '(x,)'
    assert fexprgen(vmap['matrix'].shape) == '(x, y)'

    # Parse kernel with external typedefs to test shape inferred from
    # external derived type definition
    fcode = """
subroutine routine_typedefs_simple(item)
  ! simple vector/matrix arithmetic with a derived type
  ! imported from an external header module
  use header, only: derived_type
  implicit none

  type(derived_type), intent(inout) :: item
  integer :: i, j, n

  n = 3
  do i=1, n
    item%vector(i) = item%vector(i) + item%scalar
  end do

  do j=1, n
    do i=1, n
      item%matrix(i, j) = item%matrix(i, j) + item%scalar
    end do
  end do

end subroutine routine_typedefs_simple
"""
    header = Sourcefile.from_file(header_path, frontend=frontend, xmods=[tmp_path])['header']
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=header, xmods=[tmp_path])

    # Verify that all derived type variables have shape info
    variables = FindVariables().visit(routine.body)
    assert all(v.shape is not None for v in variables if isinstance(v, sym.Array))

    # Verify shape info from imported derived type is propagated
    vmap = {v.name: v for v in variables}
    assert fexprgen(vmap['item%vector'].shape) == '(3,)'
    assert fexprgen(vmap['item%matrix'].shape) == '(3, 3)'


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI does not like Loki pragmas, yet!')]))
def test_routine_variables_dimension_pragmas(frontend):
    """
    Check that a ``!$loki dimension`` pragma overrides the conceptual
    ``.shape`` of the variables declared directly after it. This applies to
    dummy arguments as well as locals, and to deferred-shape, pointer and
    allocatable declarations.
    """
    fcode = """
subroutine routine_variables_dimensions(x, y, v1, v2, v3, v4)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  !$loki dimension(x,:)
  real(kind=jprb), intent(inout) :: v1(:,:)
  !$loki dimension(x,y,:)
  real(kind=jprb), dimension(:,:,:), intent(inout) :: v2, v3
  !$loki dimension(x,y)
  real(kind=jprb), pointer, intent(inout) :: v4(:,:)
  !$loki dimension(y,:)
  real(kind=jprb), allocatable :: v5(:,:)
  !$loki dimension(x+y)
  real(kind=jprb), dimension(:), pointer :: v6

end subroutine routine_variables_dimensions
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Shape taken from the pragma preceding each declaration
    expected_shapes = {
        'v1': '(x, :)',
        'v2': '(x, y, :)',
        'v3': '(x, y, :)',
        'v4': '(x, y)',
        'v5': '(y, :)',
        'v6': '(x + y,)',
    }
    for name, shape in expected_shapes.items():
        assert fexprgen(routine.variable_map[name].shape) == shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_type_propagation(header_path, frontend, tmp_path):
    """
    Check that type information reaches every variable use in a routine body.

    This covers locally declared symbols, and derived-type members whose type
    definition is supplied from a standalone module via ``definitions``.
    """
    # TODO: testing the reference solution with typedefs would need
    # compile_and_load to accept multiple source files/paths, so that the
    # header could be compiled alongside the subroutine in one f90wrap run.

    # Either spelling of the real kind is accepted, depending on the frontend
    real_kinds = ('jprb', 'selected_real_kind(13, 300)')

    # Plain kernel routine with locally declared array arguments
    routine = Subroutine.from_source(frontend=frontend, source="""
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_simple
""")

    # Types on the internalized argument list: two integers, then three reals
    for idx in range(2):
        assert routine.arguments[idx].type.dtype == BasicType.INTEGER
    for idx in range(2, 5):
        arg = routine.arguments[idx]
        assert arg.type.dtype == BasicType.REAL
        assert str(arg.type.kind) in real_kinds

    # Every scalar/array use in the body must carry type information
    body_vars = FindVariables().visit(routine.body)
    assert all(v.type is not None for v in body_vars if isinstance(v, (sym.Scalar, sym.Array)))

    by_name = {v.name: v for v in body_vars}
    assert by_name['x'].type.dtype == BasicType.INTEGER
    for name in ('scalar', 'vector', 'matrix'):
        assert by_name[name].type.dtype == BasicType.REAL
        assert str(by_name[name].type.kind) in real_kinds

    # Kernel using a derived type whose definition comes from the header module
    fcode = """
subroutine routine_typedefs_simple(item)
  ! simple vector/matrix arithmetic with a derived type
  ! imported from an external header module
  use header, only: derived_type
  implicit none

  type(derived_type), intent(inout) :: item
  integer :: i, j, n

  n = 3
  do i=1, n
    item%vector(i) = item%vector(i) + item%scalar
  end do

  do j=1, n
    do i=1, n
      item%matrix(i, j) = item%matrix(i, j) + item%scalar
    end do
  end do

end subroutine routine_typedefs_simple
"""
    header = Sourcefile.from_file(header_path, frontend=frontend, xmods=[tmp_path])['header']
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=header, xmods=[tmp_path])

    # The dummy argument must resolve to the imported derived type
    assert routine.arguments[0].name == 'item'
    assert routine.arguments[0].type.dtype.name == 'derived_type'

    # Every variable use in the body, derived-type members included, is typed
    body_vars = FindVariables().visit(routine.body)
    assert all(v.type is not None for v in body_vars)

    by_name = {v.name: v for v in body_vars}
    for name in ('item%scalar', 'item%vector', 'item%matrix'):
        assert by_name[name].type.dtype == BasicType.REAL
        assert str(by_name[name].type.kind) in real_kinds


@pytest.mark.parametrize('frontend', available_frontends())
def test_routine_call_arrays(header_path, frontend, tmp_path):
    """
    Check that arrays handed on through a call statement are represented as
    :any:`Array` symbols (carrying their shape) rather than as scalars.
    """
    fcode = """
subroutine routine_call_caller(x, y, vector, matrix, item)
  ! Simple routine calling another routine
  use header, only: derived_type
  implicit none

  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  type(derived_type), intent(inout) :: item

  ! To a parser, these arrays look like scalarst!
  call routine_call_callee(x, y, vector, matrix, item%matrix)

end subroutine routine_call_caller
"""
    header = Sourcefile.from_file(header_path, frontend=frontend, xmods=[tmp_path])['header']
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=header, xmods=[tmp_path])
    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    args = call.arguments

    # Scalars stay scalars, while local arrays and derived-type member arrays become arrays
    expected = (
        ('x', sym.Scalar),
        ('y', sym.Scalar),
        ('vector', sym.Array),
        ('matrix', sym.Array),
        ('item%matrix', sym.Array),
    )
    for idx, (text, symbol_cls) in enumerate(expected):
        assert str(args[idx]) == text
        assert isinstance(args[idx], symbol_cls)

    # Shapes of the local arrays come from their declarations
    assert fexprgen(args[2].shape) == '(x,)'
    assert fexprgen(args[3].shape) == '(x, y)'
    # NOTE: the shape of the derived-type member ``item%matrix`` is not checked here

    assert fgen(call) == 'CALL routine_call_callee(x, y, vector, matrix, item%matrix)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_call_no_arg(frontend):
    """
    A ``CALL`` statement without an argument list yields empty
    positional and keyword argument tuples.
    """
    fcode = """
subroutine routine_call_no_arg()
  implicit none

  call abort
end subroutine routine_call_no_arg
"""
    routine = Subroutine.from_source(frontend=frontend, source=fcode)
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    (call,) = calls
    assert call.arguments == ()
    assert call.kwarguments == ()


@pytest.mark.parametrize('frontend', available_frontends())
def test_call_kwargs(frontend):
    """
    A call using only keyword arguments stores them as ``(name, value)``
    pairs in ``kwarguments`` and leaves ``arguments`` empty.
    """
    fcode = """
subroutine routine_call_kwargs()
  implicit none
  integer :: kprocs

  call mpl_init(kprocs=kprocs, cdstring='routine_call_kwargs')
end subroutine routine_call_kwargs
"""
    routine = Subroutine.from_source(frontend=frontend, source=fcode)
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    call = calls[0]
    assert call.name == 'mpl_init'

    # No positional arguments, two keyword arguments as 2-tuples
    assert call.arguments == ()
    kwargs = call.kwarguments
    assert len(kwargs) == 2
    for kwarg in kwargs:
        assert isinstance(kwarg, tuple) and len(kwarg) == 2

    kprocs_key, kprocs_value = kwargs[0]
    assert kprocs_key == 'kprocs'
    assert isinstance(kprocs_value, sym.Scalar) and kprocs_value.name == 'kprocs'

    assert kwargs[1] == ('cdstring', sym.StringLiteral('routine_call_kwargs'))


@pytest.mark.parametrize('frontend', available_frontends())
def test_call_args_kwargs(frontend):
    """
    A call mixing positional and keyword arguments keeps both lists apart.
    """
    fcode = """
subroutine routine_call_args_kwargs(pbuf, ktag, kdest)
  implicit none
  integer, intent(in) :: pbuf(:), ktag, kdest

  call mpl_send(pbuf, ktag, kdest, cdstring='routine_call_args_kwargs')
end subroutine routine_call_args_kwargs
"""
    routine = Subroutine.from_source(frontend=frontend, source=fcode)
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    call = calls[0]
    assert call.name == 'mpl_send'

    # Positional arguments are passed on in the order of the routine's dummies
    assert len(call.arguments) == 3
    for passed, dummy in zip(call.arguments, routine.arguments):
        assert passed.name == dummy.name

    assert call.kwarguments == (('cdstring', sym.StringLiteral('routine_call_args_kwargs')),)


@pytest.mark.parametrize('frontend', available_frontends())
def test_convert_endian(tmp_path, frontend):
    """
    Test that ``OPEN`` statements with a ``CONVERT`` specifier survive a
    round trip through the frontend and the Fortran backend.
    """
    pre = """
SUBROUTINE ROUTINE_CONVERT_ENDIAN()
  INTEGER :: IUNIT
  CHARACTER(LEN=100) :: CL_CFILE
"""
    body = """
IUNIT = 61
OPEN(IUNIT, FILE=TRIM(CL_CFILE), FORM="UNFORMATTED", CONVERT='BIG_ENDIAN')
IUNIT = 62
OPEN(IUNIT, FILE=TRIM(CL_CFILE), CONVERT="LITTLE_ENDIAN", &
  & FORM="UNFORMATTED")
"""
    post = """
END SUBROUTINE ROUTINE_CONVERT_ENDIAN
"""
    filepath = tmp_path/f'routine_convert_endian_{frontend}.f90'
    Sourcefile.to_file(''.join((pre, body, post)), filepath)
    source = Sourcefile.from_file(filepath, frontend=frontend, preprocess=True)
    routine = source['routine_convert_endian']

    if frontend == OMNI:
        # OMNI spells out the UNIT keyword, turns double quotes into single quotes
        # and joins continuation lines, so adjust the reference text to match
        for old, new in (('OPEN(IUNIT', 'OPEN(UNIT=IUNIT'), ('"', "'"), ('&\n  & ', '')):
            body = body.replace(old, new)

    # TODO: This is hacky as the fgen backend is still pretty much WIP
    generated = fgen(routine.body).upper().strip()
    assert generated == body.strip()
    filepath.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_open_newunit(tmp_path, frontend):
    """
    Test that ``OPEN`` statements using ``NEWUNIT`` in various positions and
    with line continuations survive a frontend/backend round trip.
    """
    pre = """
SUBROUTINE ROUTINE_OPEN_NEWUNIT()
  INTEGER :: IUNIT
  CHARACTER(LEN=100) :: CL_CFILE
"""
    body = """
OPEN(NEWUNIT=IUNIT, FILE=TRIM(CL_CFILE), FORM="UNFORMATTED")
OPEN(FILE=TRIM(CL_CFILE), FORM="UNFORMATTED", NEWUNIT=IUNIT)
OPEN(FILE=TRIM(CL_CFILE), NEWUNIT=IUNIT, &
  & FORM="UNFORMATTED")
OPEN(FILE=TRIM(CL_CFILE), NEWUNIT=IUNIT&
  & , FORM="UNFORMATTED")
"""
    post = """
END SUBROUTINE ROUTINE_OPEN_NEWUNIT
"""
    filepath = tmp_path/f'routine_open_newunit_{frontend}.f90'
    Sourcefile.to_file(''.join((pre, body, post)), filepath)
    source = Sourcefile.from_file(filepath, frontend=frontend, preprocess=True)
    routine = source['routine_open_newunit']

    if frontend == OMNI:
        # OMNI turns double quotes into single quotes and joins continuation lines
        for old, new in (('"', "'"), ('&\n  & ', '')):
            body = body.replace(old, new)

    # TODO: This is hacky as the fgen backend is still pretty much WIP
    generated = fgen(routine.body).upper().strip()
    assert generated == body.strip()
    filepath.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_empty_spec(frontend):
    """
    A routine without any declarations has an empty spec (apart from what
    the frontend itself injects) and a single-statement body.
    """
    fcode = """
subroutine routine_empty_spec
write(*,*) 'Hello world!'
end subroutine routine_empty_spec
"""
    routine = Subroutine.from_source(frontend=frontend, source=fcode)
    spec_nodes = routine.spec.body
    if frontend != OMNI:
        assert not spec_nodes
    else:
        # OMNI inserts IMPLICIT NONE into spec
        assert len(spec_nodes) == 1
    assert len(routine.body.body) == 1


@pytest.mark.parametrize('frontend', available_frontends())
def test_member_procedures(tmp_path, frontend):
    """
    Test member subroutine and function.

    Checks symbol lookup across the host/member boundary (shadowed and
    inherited names), the linking of an inline call to the member function,
    and the numerical result of the compiled code.
    """
    fcode = """
subroutine routine_member_procedures(in1, in2, out1, out2)
  ! Test member subroutine and function
  implicit none
  integer, intent(in) :: in1, in2
  integer, intent(out) :: out1, out2
  integer :: localvar

  localvar = in2

  call member_procedure(in1, out1)
  out2 = member_function(out1)
contains
  subroutine member_procedure(in1, out1)
    ! This member procedure shadows some variables and uses
    ! a variable from the parent scope
    implicit none
    integer, intent(in) :: in1
    integer, intent(out) :: out1

    out1 = 5 * in1 + localvar + member_function(1)
  end subroutine member_procedure

  ! Below is disabled because f90wrap (wrongly) exhibits that
  ! symbol to the public, which causes double defined symbols
  ! upon compilation.

  function member_function(in2)
    ! This function is just included to test that functions
    ! are also possible
    implicit none
    integer, intent(in) :: in2
    integer :: member_function

    member_function = 3 * in2 + 2
  end function member_function
end subroutine routine_member_procedures
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(routine.members) == 2
    member_proc, member_func = routine.members

    # The member subroutine inherits ``localvar`` from the host but shadows ``in1``
    assert member_proc.name == 'member_procedure'
    assert member_proc.symbol_attrs.lookup('localvar', recursive=False) is None
    assert member_proc.symbol_attrs.lookup('localvar') is not None
    assert member_proc.get_symbol_scope('localvar') is routine
    assert member_proc.symbol_attrs.lookup('in1') is not None
    assert routine.symbol_attrs.lookup('in1') is not None
    assert member_proc.get_symbol_scope('in1') is member_proc

    # The call to the member function is an inline call linked to that member
    inline_calls = list(FindInlineCalls().visit(member_proc.body))
    assert len(inline_calls) == 1
    called = inline_calls[0].function
    assert called.name == 'member_function'
    assert called.type.dtype.procedure == member_func

    # The member function shadows ``in2`` of the host
    assert member_func.name == 'member_function'
    assert member_func.symbol_attrs.lookup('in2') is not None
    assert member_func.get_symbol_scope('in2') is member_func
    assert routine.symbol_attrs.lookup('in2') is not None
    assert routine.get_symbol_scope('in2') is routine

    # Generate code, compile and load
    filepath = tmp_path/f'routine_member_procedures_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='routine_member_procedures')

    # in1=1, in2=2: out1 = 5*1 + 2 + (3*1 + 2) = 12, out2 = 3*12 + 2 = 38
    out1, out2 = function(1, 2)
    assert out1 == 12
    assert out2 == 38
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_member_routine_clone(frontend):
    """
    Test that member subroutine scopes get cloned correctly.

    Cloning the host must produce distinct member objects that generate the
    same code, with each cloned member attached to the cloned host and all
    variables scoped to the new objects.
    """
    fcode = """
subroutine member_routine_clone(in1, in2, out1, out2)
  ! Test member subroutine and function
  implicit none
  integer, intent(in) :: in1, in2
  integer, intent(out) :: out1, out2
  integer :: localvar

  localvar = in2

  call member_procedure(in1, out1)
  out2 = 3 * out1 + 2

contains
  subroutine member_procedure(in1, out1)
    ! This member procedure shadows some variables and uses
    ! a variable from the parent scope
    implicit none
    integer, intent(in) :: in1
    integer, intent(out) :: out1

    out1 = 5 * in1 + localvar
  end subroutine member_procedure
end subroutine
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    new_routine = routine.clone()
    member = routine.members[0]
    new_member = new_routine.members[0]

    # Distinct objects, identical generated code
    assert routine is not new_routine
    assert member is not new_member
    assert fgen(routine) == fgen(new_routine)
    assert fgen(member) == fgen(new_member)

    # Each member hangs off its own host, and every variable lives either in
    # that host or in the member itself
    for host, child in ((routine, member), (new_routine, new_member)):
        assert child.parent is host
        assert all(v.scope is host for v in FindVariables().visit(host.ir))
        assert all(v.scope in (host, child) for v in FindVariables().visit(child.ir))


@pytest.mark.parametrize('frontend', available_frontends())
def test_member_routine_clone_inplace(frontend):
    """
    Test that member routines can be detached from their host by cloning.

    Both member routines use ``localvar`` from the host. After declaring it
    locally in each member, the members are cloned with ``parent=None``: once
    as a clean clone, once preserving the existing symbol table. The host is
    then cloned without its ``CONTAINS`` section. Afterwards every clone must
    be self-contained: its variables are scoped to itself and no parent scope
    or parent symbol table is referenced anymore.
    """
    fcode = """
subroutine member_routine_clone(in1, in2, out1, out2)
  ! Test member subroutine and function
  implicit none
  integer, intent(in) :: in1, in2
  integer, intent(out) :: out1, out2
  integer :: localvar

  localvar = in2

  call member_procedure(in1, out1)
  out2 = 3 * out1 + 2

contains
  subroutine member_procedure(in1, out1)
    ! This member procedure shadows some variables and uses
    ! a variable from the parent scope
    implicit none
    integer, intent(in) :: in1
    integer, intent(out) :: out1

    out1 = 5 * in1 + localvar
  end subroutine member_procedure

  subroutine other_member(inout1)
    ! Another member that uses a parent symbol
    implicit none
    integer, intent(inout) :: inout1

    inout1 = 2 * inout1 + localvar
  end subroutine other_member
end subroutine
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Make sure the initial state is as expected
    member = routine['member_procedure']
    assert member.parent is routine
    assert member.symbol_attrs.parent is routine.symbol_attrs
    other_member = routine['other_member']
    assert other_member.parent is routine
    assert other_member.symbol_attrs.parent is routine.symbol_attrs

    # Put the inherited symbol in the local scope, first with a clean clone...
    member.variables += (routine.variable_map['localvar'].clone(scope=member),)
    member = member.clone(parent=None)
    # ...and then with a clone that preserves the symbol table
    other_member.variables += (routine.variable_map['localvar'].clone(scope=other_member),)
    other_member = other_member.clone(parent=None, symbol_attrs=other_member.symbol_attrs)
    # Ultimately, remove the member routines
    routine = routine.clone(contains=None)

    # Check that variables are in the right scope everywhere, including the
    # member that was cloned with its original symbol table
    assert all(v.scope is routine for v in FindVariables().visit(routine.ir))
    assert all(v.scope is member for v in FindVariables().visit(member.ir))
    assert all(v.scope is other_member for v in FindVariables().visit(other_member.ir))

    # Check that we aren't looking somewhere above anymore
    assert member.parent is None
    assert member.symbol_attrs.parent is None
    assert member.symbol_attrs._parent is None
    assert other_member.parent is None
    assert other_member.symbol_attrs.parent is None


@pytest.mark.parametrize('frontend', available_frontends())
def test_external_stmt(tmp_path, frontend):
    """
    Tests procedures passed as dummy arguments and declared as EXTERNAL.

    ``routine_external_stmt`` declares six procedure dummy arguments using
    different spellings of ``EXTERNAL`` (statement, attribute, with and without
    a type). The test checks the :any:`ProcedureDeclaration` nodes and the
    resulting procedure types. It then compiles the code together with a
    separate file of external procedures and checks the computed result.
    """
    # External procedures, written to their own file further below and
    # compiled together with the main source
    fcode_external = """
! This should be tested as well with interface statements in the caller
! routine, and the subprogram definitions outside (to have "truly external"
! procedures, however, we need to make the INTERFACE support more robust first

subroutine other_external_subroutine(outvar)
  implicit none
  integer, intent(out) :: outvar
  outvar = 4
end subroutine other_external_subroutine

function other_external_function() result(outvar)
  implicit none
  integer :: outvar
  outvar = 6
end function other_external_function
    """.strip()

    # Kernel with procedure dummies, plus a driver passing member and external
    # procedures into it
    fcode = """
subroutine routine_external_stmt(invar, sub1, sub2, sub3, outvar, func1, func2, func3)
  implicit none
  integer, intent(in) :: invar
  external sub1
  external :: sub2, sub3
  integer, intent(out) :: outvar
  integer, external :: func1, func2
  integer, external :: func3
  integer tmp

  call sub1(tmp)
  outvar = invar + tmp  ! invar + 1
  call sub2(tmp)
  outvar = outvar + tmp + func1()  ! (invar + 1) + 1 + 6
  call sub3(tmp)
  outvar = outvar + tmp + func2()  ! (invar + 8) + 4 + 2
  tmp = func3()
  outvar = outvar + tmp  ! (invar + 14) + 2
end subroutine routine_external_stmt

subroutine routine_call_external_stmt(invar, outvar)
  implicit none
  integer, intent(in) :: invar
  integer, intent(out) :: outvar

  interface
    subroutine other_external_subroutine(outvar)
      integer, intent(out) :: outvar
    end subroutine other_external_subroutine
  end interface

  interface
    function other_external_function()
      integer :: other_external_function
    end function other_external_function
  end interface

  call routine_external_stmt(invar, external_subroutine, external_subroutine, other_external_subroutine, &
                            &outvar, other_external_function, external_function, external_function)

contains

  subroutine external_subroutine(outvar)
    implicit none
    integer, intent(out) :: outvar
    outvar = 1
  end subroutine external_subroutine

  function external_function()
    implicit none
    integer :: external_function
    external_function = 2
  end function external_function

end subroutine routine_call_external_stmt
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend)
    routine = source['routine_external_stmt']
    # invar, outvar and the six procedure dummies
    assert len(routine.arguments) == 8

    for decl in FindNodes(ir.ProcedureDeclaration).visit(routine.spec):
        # Is the EXTERNAL attribute set?
        assert decl.external
        for v in decl.symbols:
            # Procedure names must be represented as ProcedureSymbol objects
            assert isinstance(v, sym.ProcedureSymbol)
            assert isinstance(v.type.dtype, ProcedureType)
            assert v.type.external is True
            # No procedure definition is attached to a dummy procedure
            assert v.type.dtype.procedure == BasicType.DEFERRED
            # Subroutine dummies are named 'sub*', function dummies 'func*'
            if 'sub' in v.name:
                assert not v.type.dtype.is_function
                assert v.type.dtype.return_type is None
            else:
                assert v.type.dtype.is_function
                assert v.type.dtype.return_type.compare(SymbolAttributes(BasicType.INTEGER))

    # Generate code, compile and load
    extpath = tmp_path/(f'subroutine_routine_external_{frontend}.f90')
    with extpath.open('w') as f:
        f.write(fcode_external)
    filepath = tmp_path/(f'subroutine_routine_external_stmt_{frontend}.f90')
    source.path = filepath
    lib = jit_compile_lib([source, extpath], path=tmp_path, name='subroutine_external')
    function = lib.routine_call_external_stmt

    # 7 + 1 + (1 + 6) + (4 + 2) + 2 = 23, following the comments in the kernel
    outvar = function(7)
    assert outvar == 23
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_interface(tmp_path, frontend, header_path):
    """
    Test auto-generation of an interface block for a given subroutine.

    The generated interface keeps the dummy argument declarations and the
    imports they depend on, but drops local variables and the body.
    """
    fcode = """
subroutine test_subroutine_interface (in1, in2, in3, out1, out2)
  use header, only: jprb
  IMPLICIT NONE
  integer, intent(in) :: in1, in2
  real(kind=jprb), intent(in) :: in3(in1, in2)
  real(kind=jprb), intent(out) :: out1, out2
  integer :: localvar
  localvar = in1 + in2
  out1 = real(localvar, kind=jprb)
  out2 = out1 + 2.
end subroutine
"""
    is_omni = frontend == OMNI
    if is_omni:
        # Generate xmod
        Sourcefile.from_file(header_path, frontend=frontend, xmods=[tmp_path])

    routine = Subroutine.from_source(fcode, xmods=[tmp_path], frontend=frontend)
    generated = fgen(routine.interface).strip()

    if is_omni:
        # OMNI splits multi-variable declarations and resolves the kind parameter
        expected = """
INTERFACE
  SUBROUTINE test_subroutine_interface (in1, in2, in3, out1, out2)
    USE header, ONLY: jprb
    IMPLICIT NONE
    INTEGER, INTENT(IN) :: in1
    INTEGER, INTENT(IN) :: in2
    REAL(KIND=selected_real_kind(13, 300)), INTENT(IN) :: in3(in1, in2)
    REAL(KIND=selected_real_kind(13, 300)), INTENT(OUT) :: out1
    REAL(KIND=selected_real_kind(13, 300)), INTENT(OUT) :: out2
  END SUBROUTINE test_subroutine_interface
END INTERFACE
"""
    else:
        expected = """
INTERFACE
  SUBROUTINE test_subroutine_interface (in1, in2, in3, out1, out2)
    USE header, ONLY: jprb
    IMPLICIT NONE
    INTEGER, INTENT(IN) :: in1, in2
    REAL(KIND=jprb), INTENT(IN) :: in3(in1, in2)
    REAL(KIND=jprb), INTENT(OUT) :: out1, out2
  END SUBROUTINE test_subroutine_interface
END INTERFACE
"""
    assert generated == expected.strip()


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_rescope_symbols(tmp_path, frontend):
    """ Test the rescoping of variables.

    Two copies of the member routine are built from transformed copies of its
    spec and body: one with ``rescope_symbols=True`` and one without. The
    original member's symbol table is then wiped and the member is replaced by
    the rescoped copy. The host must still generate the original code, while
    the non-rescoped copy must have lost its type information.
    """

    fcode_module = """
module some_mod
implicit none
contains
  subroutine ext1(a)
    integer, intent(inout) :: a(:)
  end subroutine ext1

  subroutine ext2(a)
    integer, intent(inout) :: a(:)
  end subroutine ext2
end module some_mod
    """

    fcode = """
subroutine test_subroutine_rescope(a, b, n)
  use some_mod, only: ext1
  implicit none
  integer, intent(in) :: a(n)
  integer, intent(out) :: b(n)
  integer, intent(in) :: n
  integer :: j

  b(:) = 0

  do j=1,n
    b(j) = a(j)
  end do

  call nested_routine(b, n)
contains

  subroutine nested_routine(a, n)
    use some_mod, only: ext2
    integer, parameter :: jpim = selected_int_kind(4)
    integer, intent(inout) :: a(n)
    integer, intent(in) :: n
    integer(kind=jpim) :: j

    do j=1,n
      a(j) = a(j) + 1
    end do

    call ext1(a)
    call ext2(a)
  end subroutine nested_routine
end subroutine test_subroutine_rescope
    """.strip()

    # Parse the module first so that its module information is available in
    # tmp_path (presumably required by OMNI to resolve the USE statements)
    Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    ref_fgen = fgen(routine)

    # Create a copy of the nested subroutine with rescoping and
    # make sure all symbols are in the right scope
    nested_spec = Transformer().visit(routine.members[0].spec)
    nested_body = Transformer().visit(routine.members[0].body)
    nested_routine = Subroutine(name=routine.members[0].name, args=routine.members[0]._dummies,
                                spec=nested_spec, body=nested_body, parent=routine,
                                rescope_symbols=True)

    for var in FindTypedSymbols().visit(nested_routine.ir):
        # ext1 is imported in the host routine and must remain scoped there
        if var.name == 'ext1':
            assert var.scope is routine
        else:
            # The intrinsic selected_int_kind is excluded from the scope check
            if var.name.lower() == 'selected_int_kind':
                continue
            assert var.scope is nested_routine

    # Make sure the KIND parameter symbol in the variable's type is also correctly rescoped
    if frontend == OMNI:  # OMNI resolves parameter kind symbols
        assert routine.members[0].variable_map['j'].type.kind == 4
        assert nested_routine.variable_map['j'].type.kind == 4
    else:
        assert routine.members[0].variable_map['j'].type.kind.scope is routine.members[0]
        assert nested_routine.variable_map['j'].type.kind.scope is nested_routine

    # Create another copy of the nested subroutine without rescoping
    nested_spec = Transformer().visit(routine.members[0].spec)
    nested_body = Transformer().visit(routine.members[0].body)
    other_routine = Subroutine(name=routine.members[0].name, args=routine.members[0].argnames,
                               spec=nested_spec, body=nested_body, parent=routine)

    # Save the kind symbol for later
    other_kind_var = other_routine.variable_map['j'].type.kind
    if frontend == OMNI:
        assert other_kind_var == 4
    else:
        # Without rescoping, the kind symbol still points to the original member
        assert other_kind_var.scope is routine.members[0]

    # Explicitly throw away type information from original nested routine
    routine.members[0]._parent = None
    routine.members[0].symbol_attrs.clear()
    routine.members[0].symbol_attrs._parent = None
    # The non-rescoped copy still refers to the wiped member, so its types are gone
    assert all(var.type is None for var in other_routine.variables)
    assert all(var.scope is not None for var in other_routine.variables)

    # Replace member routine by copied routine
    contains = [nested_routine if isinstance(c, Subroutine) else c for c in routine.contains.body]
    routine.contains = routine.contains.clone(body=contains)

    # Now, all variables should still be well-defined and fgen should produce the same string
    assert all(var.scope is not None for var in nested_routine.variables)
    assert fgen(routine) == ref_fgen

    # accessing any local type information should fail because either the scope got garbage
    # collected or its types are gone
    assert all(var.scope is None or var.type is None for var in other_routine.variables)

    # Make sure changes apply also to the KIND attribute
    if frontend == OMNI:
        assert routine.members[0].variable_map['j'].type.kind == 4
    else:
        assert routine.members[0].variable_map['j'].type.kind.scope is routine.members[0]

        # This points (weakly) to an entry in routine.members[0].symbols which may or may not
        # have been garbage collected at this point
        assert other_kind_var.scope is not other_routine

    # fgen of the not rescoped routine should lack some type information and thus either fail or
    # produce a different output, depending on whether GC has already happened
    try:
        other_fgen = fgen(other_routine)
        assert other_fgen != ref_fgen
        assert len(other_fgen) < len(ref_fgen)
    except AttributeError as e:
        # Accepted failure modes when type information has vanished
        assert str(e) in (
            "'NoneType' object has no attribute 'compare'",
            "'NoneType' object has no attribute 'dtype'",
            "'NoneType' object has no attribute 'use_name'"
        )


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_rescope_clone(tmp_path, frontend):
    """ Test the rescoping of variables in clone.

    ``clone()`` rescopes symbols by default, so a clone of the member routine
    must remain valid after the original member's symbol table is wiped. A
    clone made with ``rescope_symbols=False`` must lose its type information.
    """
    fcode_module = """
module some_mod
implicit none
contains
  subroutine ext1(a)
    integer, intent(inout) :: a(:)
  end subroutine ext1

  subroutine ext2(a)
    integer, intent(inout) :: a(:)
  end subroutine ext2
end module some_mod
    """

    fcode = """
subroutine test_subroutine_rescope_clone(a, b, n)
  use some_mod, only: ext1
  implicit none
  integer, intent(in) :: a(n)
  integer, intent(out) :: b(n)
  integer, intent(in) :: n
  integer :: j

  b(:) = 0

  do j=1,n
    b(j) = a(j)
  end do

  call nested_routine(b, n)
contains

  subroutine nested_routine(a, n)
    use some_mod, only: ext2
    integer, intent(inout) :: a(n)
    integer, intent(in) :: n
    integer :: j

    do j=1,n
      a(j) = a(j) + 1
    end do

    call ext1(a)
    call ext2(a)
  end subroutine nested_routine
end subroutine test_subroutine_rescope_clone
    """.strip()

    # Parse the module first so that its module information is available in
    # tmp_path (presumably required by OMNI to resolve the USE statements)
    Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    ref_fgen = fgen(routine)

    # Create a copy of the nested subroutine with rescoping and
    # make sure all symbols are in the right scope
    nested_routine = routine.members[0].clone()

    for var in FindTypedSymbols().visit(nested_routine.ir):
        # ext1 is imported in the host routine and must remain scoped there
        if var.name == 'ext1':
            assert var.scope is routine
        else:
            assert var.scope is nested_routine

    # Create another copy of the nested subroutine without rescoping (this deliberately breaks
    # the copy and should never be done in practice; it shows what goes wrong without rescoping)
    other_routine = routine.members[0].clone(symbol_attrs=routine.symbol_attrs.clone(), rescope_symbols=False)

    # Explicitly throw away type information from original nested routine
    routine.members[0]._parent = None
    routine.members[0].symbol_attrs.clear()
    routine.members[0].symbol_attrs._parent = None
    # The non-rescoped copy still refers to the wiped member, so its types are gone
    assert all(var.type is None for var in other_routine.variables)
    assert all(var.scope is not None for var in other_routine.variables)

    # Replace member routine by copied routine
    contains = [nested_routine if isinstance(c, Subroutine) else c for c in routine.contains.body]
    routine.contains = routine.contains.clone(body=contains)

    # Now, all variables should still be well-defined and fgen should produce the same string
    assert all(var.scope is not None for var in nested_routine.variables)
    assert fgen(routine) == ref_fgen

    # accessing any local type information should fail because either the scope got garbage
    # collected or its types are gone
    assert all(var.scope is None or var.type is None for var in other_routine.variables)

    # fgen of the not rescoped routine should lack some type information and thus either fail or
    # produce a different output, depending on whether GC has already happened
    try:
        other_fgen = fgen(other_routine)
        assert other_fgen != ref_fgen
        assert len(other_fgen) < len(ref_fgen)
    except AttributeError as e:
        # Accepted failure modes when type information has vanished
        assert str(e) in (
            "'NoneType' object has no attribute 'compare'",
            "'NoneType' object has no attribute 'dtype'",
            "'NoneType' object has no attribute 'use_name'"
        )


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_stmt_func(tmp_path, frontend):
    """
    Test the correct identification of statement functions.

    For frontends that keep statement functions (all but OMNI), each one must
    be a :any:`StatementFunction` node linked from a :any:`ProcedureSymbol`.
    For every frontend, the compiled routine must compute the expected result.
    """
    fcode = """
subroutine subroutine_stmt_func(a, b)
    implicit none
    integer, intent(in) :: a
    integer, intent(out) :: b
    integer :: array(a)
    integer :: i, j, plus, minus
    plus(i, j) = i + j
    minus(i, j) = i - j
    integer :: mult
    integer :: tmp
    mult(i, j) = i * j

    array(a) = a
    tmp = plus(a, 5)
    tmp = minus(tmp, 1)
    b = mult(2, tmp)
end subroutine subroutine_stmt_func
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    routine.name += f'_{frontend!s}'

    # Injecting statement functions must not drop the source of assignments
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert all(assignment.source is not None for assignment in assignments)

    # OMNI inlines statement functions, so the IR representation can only be
    # checked for the other frontends
    if frontend != OMNI:
        decl_by_symbol = {
            decl.variable: decl
            for decl in FindNodes(ir.StatementFunction).visit(routine.spec)
        }
        assert len(decl_by_symbol) == 3

        for proc_name in ('plus', 'minus', 'mult'):
            proc = routine.variable_map[proc_name]
            assert isinstance(proc, sym.ProcedureSymbol)
            assert isinstance(proc.type.dtype, ProcedureType)
            decl = decl_by_symbol[proc]
            assert proc.type.dtype.procedure is decl
            assert decl.source is not None

    # a=3: b = mult(2, minus(plus(3, 5), 1)) = 2 * 7 = 14
    filepath = tmp_path/f'{routine.name}.f90'
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    assert function(3) == 14
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_mixed_declaration_interface(frontend):
    """
    Test that generating the interface of a routine fails with an informative
    error when a declaration mixes dummy arguments and local variables (and
    the dummy arguments carry no ``INTENT``).

    Parsing itself must succeed; only the access to ``routine.interface`` is
    expected to raise.
    """
    fcode = """
subroutine valid_fortran(i, m)
   integer :: i, j, m
   integer :: k,l
end subroutine valid_fortran
"""

    # Parse and sanity-check outside the ``raises`` block so that unrelated
    # assertion failures are not mistaken for the expected error
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert isinstance(routine.body, ir.Section)
    assert isinstance(routine.spec, ir.Section)

    with pytest.raises(AssertionError) as error:
        _ = routine.interface

    assert "Declarations must have intents" in str(error.value)


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_comparison(frontend):
    """
    Test that string-equivalence works on relevant components.

    Two routines parsed from the same source compare equal component by
    component, while a semantic change confined to the body makes the body
    and the routine compare unequal.
    """

    fcode = """
subroutine my_routine(n, a, b, d)
  integer, intent(in) :: n
  real, intent(in) :: a(n), b(n)
  real, intent(out) :: d(n)
  integer :: i

  do i=1, n
    d(i) = a(i) + b(i)
  end do
end subroutine my_routine
"""
    # Two distinct string-equivalent subroutine objects
    first = Subroutine.from_source(fcode, frontend=frontend)
    second = Subroutine.from_source(fcode, frontend=frontend)

    for component in ('symbol_attrs', 'spec', 'body'):
        assert getattr(first, component) == getattr(second, component)
    assert first == second

    # Counter example: shift an index in the body only, so that the symbol
    # table and the declaration spec stay identical while the semantics change
    shifted = Subroutine.from_source(fcode.replace('d(i)', 'd(i+1)'), frontend=frontend)
    assert first.symbol_attrs == shifted.symbol_attrs
    # With OMNI, the source file paths attached to each node (and checked in
    # the comparison) change with the source string, so skip the spec check
    if frontend != OMNI:
        assert first.spec == shifted.spec
    assert not first.body == shifted.body
    assert not first == shifted


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_comparison_case_sensitive(frontend):
    """
    Test that semantically equivalent but not string-equivalent routines
    evaluate as not equal.
    """

    fcode = """
subroutine my_routine(n, a, b, d)
  integer, intent(in) :: n
  real, intent(in) :: a(n), b(n)
  real, intent(out) :: d(n)
  integer :: i

  do i=1, n
    d(i) = a(i) + b(i)
  end do
end subroutine my_routine
"""
    # Create two subroutine objects, but capitalize a variable in one
    lower = Subroutine.from_source(fcode, frontend=frontend)
    upper = Subroutine.from_source(fcode.replace('d(i)', 'D(I)'), frontend=frontend)

    assert 'D(I)' not in fgen(lower)
    if frontend != OMNI:  # OMNI always downcases!
        assert 'D(I)' in fgen(upper)

    # The symbol table (and, outside OMNI, the spec) still match, but body and routine do not
    assert lower.symbol_attrs == upper.symbol_attrs
    # With OMNI, the source file paths attached to each node (and checked in
    # the comparison) change with the source string, so skip the spec check
    if frontend != OMNI:
        assert lower.spec == upper.spec
    assert not lower.body == upper.body
    assert not lower == upper


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_lazy_arguments_incomplete1(frontend):
    """
    Test that argument lists for subroutines are correctly captured when the object is made
    complete.

    The rationale for this test is that for dummy argument lists with interleaved comments and line
    breaks, matching is non-trivial and, since we don't currently need the argument list
    in the incomplete REGEX-parsed IR, we accept that this information is incomplete initially.
    Here, we make sure this information is captured correctly after completing the full frontend
    parse.
    """
    fcode = """
subroutine my_routine(n, a, b, d)
    integer, intent(in) :: n
    real, intent(in) :: a(n), b(n)
    real, intent(out) :: d(n)
    integer :: i

    do i=1, n
        d(i) = a(i) + b(i)
    end do
end subroutine my_routine
    """.strip()

    # The REGEX frontend leaves the routine incomplete, without argument information
    routine = Subroutine.from_source(fcode, frontend=REGEX)
    assert routine._incomplete
    assert routine.arguments == ()
    assert routine.argnames == []
    assert routine._dummies == ()
    assert all(isinstance(arg, sym.DeferredTypeSymbol) for arg in routine.arguments)

    # Completing with a full frontend must populate the argument list
    routine.make_complete(frontend=frontend)
    assert not routine._incomplete
    assert routine.arguments == ('n', 'a(n)', 'b(n)', 'd(n)')
    assert routine.argnames == ['n', 'a', 'b', 'd']
    assert routine._dummies == ('n', 'a', 'b', 'd')
    assert isinstance(routine.arguments[0], sym.Scalar)
    assert all(isinstance(arg, sym.Array) for arg in routine.arguments[1:])


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_lazy_arguments_incomplete2(frontend):
    """
    Test that argument lists for subroutines are correctly captured when the object is made
    complete.

    The rationale for this test is that for dummy argument lists with interleaved comments and line
    breaks, matching is non-trivial and, since we don't currently need the argument list
    in the incomplete REGEX-parsed IR, we accept that this information is not available initially.
    Here, we make sure this information is captured correctly after completing the full frontend
    parse.
    """
    fcode = """
SUBROUTINE CLOUDSC &
 !---input
 & (KIDIA,    KFDIA,    KLON,    KLEV,&
 & PT, PQ, &
 !---prognostic fields
 & PA,&
 & PCLV,  &
 & PSUPSAT,&
!-- arrays for aerosol-cloud interactions
!!! & PQAER,    KAER, &
 & PRE_ICE,&
 & PCCN,     PNICE,&
 !---diagnostic output
 & PCOVPTOT, PRAINFRAC_TOPRFZ,&
 !---resulting fluxes
 & PFSQLF,   PFSQIF ,  PFCQNNG,  PFCQLNG&
 & )
IMPLICIT NONE
INTEGER, PARAMETER :: JPIM = SELECTED_INT_KIND(9)
INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
INTEGER(KIND=JPIM),PARAMETER :: NCLV=5      ! number of microphysics variables
INTEGER(KIND=JPIM),INTENT(IN)    :: KLON             ! Number of grid points
INTEGER(KIND=JPIM),INTENT(IN)    :: KLEV             ! Number of levels
INTEGER(KIND=JPIM),INTENT(IN)    :: KIDIA
INTEGER(KIND=JPIM),INTENT(IN)    :: KFDIA
REAL(KIND=JPRB)   ,INTENT(IN)    :: PT(KLON,KLEV)    ! T at start of callpar
REAL(KIND=JPRB)   ,INTENT(IN)    :: PQ(KLON,KLEV)    ! Q at start of callpar
REAL(KIND=JPRB)   ,INTENT(IN)    :: PA(KLON,KLEV)    ! Original Cloud fraction (t)
REAL(KIND=JPRB)   ,INTENT(IN)    :: PCLV(KLON,KLEV,NCLV)
REAL(KIND=JPRB)   ,INTENT(IN)    :: PSUPSAT(KLON,KLEV)
REAL(KIND=JPRB)   ,INTENT(IN)    :: PRE_ICE(KLON,KLEV)
REAL(KIND=JPRB)   ,INTENT(IN)    :: PCCN(KLON,KLEV)     ! liquid cloud condensation nuclei
REAL(KIND=JPRB)   ,INTENT(IN)    :: PNICE(KLON,KLEV)    ! ice number concentration (cf. CCN)
REAL(KIND=JPRB)   ,INTENT(OUT)   :: PCOVPTOT(KLON,KLEV) ! Precip fraction
REAL(KIND=JPRB)   ,INTENT(OUT)   :: PRAINFRAC_TOPRFZ(KLON)
REAL(KIND=JPRB)   ,INTENT(OUT)   :: PFSQLF(KLON,KLEV+1)  ! Flux of liquid
REAL(KIND=JPRB)   ,INTENT(OUT)   :: PFSQIF(KLON,KLEV+1)  ! Flux of ice
REAL(KIND=JPRB)   ,INTENT(OUT)   :: PFCQLNG(KLON,KLEV+1) ! -ve corr for liq
REAL(KIND=JPRB)   ,INTENT(OUT)   :: PFCQNNG(KLON,KLEV+1) ! -ve corr for ice
END SUBROUTINE CLOUDSC
    """.strip()

    # Expected dummy argument names, in declaration order of the argument list
    argnames = (
        'kidia', 'kfdia', 'klon', 'klev', 'pt', 'pq',
        'pa', 'pclv', 'psupsat',
        'pre_ice', 'pccn', 'pnice',
        'pcovptot', 'prainfrac_toprfz',
        'pfsqlf', 'pfsqif', 'pfcqnng', 'pfcqlng'
    )
    # Expected string representation of the arguments, including array dimensions
    argnames_with_dim = (
        'kidia', 'kfdia', 'klon', 'klev', 'pt(klon, klev)', 'pq(klon, klev)',
        'pa(klon, klev)', 'pclv(klon, klev, nclv)', 'psupsat(klon, klev)',
        'pre_ice(klon, klev)', 'pccn(klon, klev)', 'pnice(klon, klev)',
        'pcovptot(klon, klev)', 'prainfrac_toprfz(klon)',
        'pfsqlf(klon, klev + 1)', 'pfsqif(klon, klev + 1)', 'pfcqnng(klon, klev + 1)', 'pfcqlng(klon, klev + 1)'
    )

    routine = Subroutine.from_source(fcode, frontend=REGEX)
    assert routine._incomplete
    # NOTE: This represents the current capabilities of the REGEX frontend. If this test
    # suddenly fails because the argument list happens to be captured correctly:
    # Nice one! Go ahead and change the test.
    assert routine.arguments == ()
    assert routine.argnames == []
    assert routine._dummies == ()
    assert all(isinstance(arg, sym.DeferredTypeSymbol) for arg in routine.arguments)

    # Completing with a full frontend must populate the argument list
    routine.make_complete(frontend=frontend)
    assert not routine._incomplete
    assert routine.arguments == argnames_with_dim
    # Compare case-insensitively, since frontends differ in how they preserve case
    assert [arg.upper() for arg in routine.argnames] == [arg.upper() for arg in argnames]
    assert routine._dummies == argnames
    assert all(isinstance(arg, sym.Scalar) for arg in routine.arguments[:4])
    assert all(isinstance(arg, sym.Array) for arg in routine.arguments[4:])


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_clone_contained(frontend):
    """
    Test cloning a routine while supplying replacement contained routines.

    The clone must adopt the supplied kernels as-is, re-parent them to the
    cloned driver, and keep call enrichment intact for both the original and
    the cloned driver.
    """
    fcode = """
subroutine driver(n, a)
    implicit none
    integer, intent(in) :: n
    integer, intent(out), allocatable :: a(:)
    integer, allocatable :: b(:)
    integer :: index

    allocate(a(n))
    allocate(b(n))
    a(:) = 1
    call kernel1(a, b, index)
    call kernel2(index, b)
    a(:) = b(:)
    deallocate(b)
contains
    subroutine kernel1(a, b, index)
        integer, intent(in) :: a(:)
        integer, intent(inout) :: b(:)
        integer, intent(in) :: index
        b(:) = a(:)
    end subroutine kernel1

    subroutine kernel2(index, a)
        integer, intent(in) :: index
        integer, intent(inout) :: a(:)
        a(:) = a(:) + 1
    end subroutine kernel2
end subroutine driver
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend)
    driver = source['driver']
    kernels = driver.subroutines

    def _check_calls_enriched(host, members):
        # Every call in the host body must resolve to one of its contained kernels
        call_nodes = FindNodes(ir.CallStatement).visit(host.body)
        assert len(call_nodes) == 2

        for node in call_nodes:
            assert node.name in ('kernel1', 'kernel2')
            assert isinstance(node.routine, Subroutine)
            assert node.routine in members
            assert node.routine in host.subroutines

        # The procedure type of each contained routine must point back at the kernel object
        for member in members:
            proc_types = [r.procedure_type for r in host.subroutines if r.name == member.name]
            assert proc_types[0].procedure is member

    _check_calls_enriched(driver, kernels)

    # !!! Note: the clone() calls below are not strictly necessary, but they expose a certain edge case !!!

    # Stand-in for new contained kernels, e.g. produced by some transformation or hoisting ...
    cloned_kernels = tuple(kernel.clone() for kernel in kernels)
    # ... and a new, separate driver object that contains them
    cloned_driver = driver.clone(contains=cloned_kernels)
    assert cloned_driver is not driver

    # The supplied contained subroutines must be adopted directly, not cloned again
    assert all(supplied is adopted for supplied, adopted in zip(cloned_kernels, cloned_driver.subroutines))

    # Cloned kernels are distinct from the originals, and each points at its own parent
    for new_kernel, old_kernel in zip(cloned_kernels, kernels):
        assert new_kernel.name == old_kernel.name
        assert new_kernel.parent is cloned_driver
        assert old_kernel.parent is driver
        assert new_kernel is not old_kernel

    _check_calls_enriched(driver, kernels)
    _check_calls_enriched(cloned_driver, cloned_kernels)

    # Query the cloned driver's argument names, then re-check enrichment afterwards
    driver_args = [arg.name.lower() for arg in cloned_driver.arguments]
    assert driver_args == ['n', 'a']

    _check_calls_enriched(driver, kernels)
    _check_calls_enriched(cloned_driver, cloned_kernels)


@pytest.mark.parametrize('frontend', available_frontends())
def test_enrich_explicit_interface(frontend):
    """
    Test that enrichment links calls to the actual routine object rather than
    to the symbol declared in an explicit interface block.
    """

    fcode_kernel = """
    subroutine kernel(a,b)
    implicit none
    integer, intent(inout) :: a
    integer, intent(out) :: b


    a = a + 1
    b = a

    end subroutine kernel
    """

    fcode_driver = """
    subroutine driver()
    implicit none

    interface
    subroutine kernel(a,b)
    integer, intent(inout) :: a
    integer, intent(out) :: b
    end subroutine kernel
    end interface

    integer :: a = 0
    integer :: b

    call kernel(a,b)

    end subroutine driver
    """

    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)

    driver.enrich(kernel)

    # The call must be linked to the kernel routine itself
    call_nodes = FindNodes(ir.CallStatement).visit(driver.body)
    assert call_nodes[0].routine is kernel

    # The procedure declared in the interface block must have been detached
    # from the driver's scope (it has no parent any more)
    interfaces = FindNodes(ir.Interface).visit(driver.spec)
    assert not interfaces[0].body[0].parent

    # Accessing the interface symbols must not redirect the call
    _ = [s for intf in interfaces for s in intf.symbols]
    assert call_nodes[0].routine is kernel

    # Rescoping the driver's symbols must not redirect the call either
    driver.rescope_symbols()
    assert call_nodes[0].routine is kernel


@pytest.mark.parametrize('frontend', available_frontends())
def test_enrich_derived_types(tmp_path, frontend):
    """
    Test that enriching a routine with the module defining an imported derived
    type propagates the type definition to the imported type symbol, to
    variables of that type, and to their members in both spec and body.
    """
    fcode = """
subroutine enrich_derived_types_routine(yda_array)
use field_array_module, only : field_3rb_array
implicit none
type(field_3rb_array), intent(inout) :: yda_array
yda_array%p = 0.
end subroutine enrich_derived_types_routine
    """.strip()

    fcode_module = """
module field_array_module
implicit none
type field_3rb_array
    real, pointer :: p(:,:,:)
end type field_3rb_array
end module field_array_module
    """.strip()

    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Before enrichment, the derived type is a dangling import ...
    type_symbol = routine.symbol_map['field_3rb_array']
    assert type_symbol.type.imported
    assert type_symbol.type.module is None
    assert type_symbol.type.dtype is BasicType.DEFERRED

    # ... the variable is recognised as a derived type, but without a typedef ...
    yda_array = routine.variable_map['yda_array']
    assert isinstance(yda_array.type.dtype, DerivedType)
    assert routine.variable_map['yda_array'].type.dtype.typedef is BasicType.DEFERRED

    # ... and the pointer member carries no type information
    member_p = routine.resolve_typebound_var('yda_array%p')
    assert member_p.type.dtype is BasicType.DEFERRED
    assert member_p.type.shape is None

    # Grab the typedef before enrichment so object identity can be checked afterwards
    tdef = module['field_3rb_array']
    assert isinstance(tdef, ir.TypeDef)

    # Enrich the routine with module definitions
    routine.enrich(module)

    type_symbol = routine.symbol_map['field_3rb_array']
    member_p = routine.resolve_typebound_var('yda_array%p')

    # The imported type symbol now refers to the module and a derived type
    assert type_symbol.type.imported
    assert type_symbol.type.module is module
    assert isinstance(type_symbol.type.dtype, DerivedType)

    # The typedef is visible through the pre-existing variable object and its member
    assert isinstance(yda_array.type.dtype, DerivedType)
    assert yda_array.type.dtype.typedef is tdef
    assert member_p.type.dtype is BasicType.REAL
    assert member_p.type.shape == (':', ':', ':')
    assert isinstance(member_p, sym.Array)

    # Cross-check the expression nodes in the spec ...
    declarations = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    assert len(declarations) == 1
    decl_symbols = declarations[0].symbols
    assert len(decl_symbols) == 1
    assert isinstance(decl_symbols[0], sym.Scalar)
    assert decl_symbols[0].type.dtype.typedef == tdef

    # ... and in the body
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 1
    lhs = assignments[0].lhs
    assert isinstance(lhs, sym.Array)
    assert lhs.type.dtype == BasicType.REAL
    assert lhs.type.shape == (':', ':', ':')
    assert lhs.parent.type.dtype.typedef == tdef


@pytest.mark.parametrize('frontend', available_frontends())
def test_subroutine_deep_clone(frontend, tmp_path):
    """
    Test that deep-cloning a subroutine actually ensures clean scope separation.

    Rewriting the body of the clone must leave the body of the original routine unchanged.
    """
    fcode_module = """
module my_types
  implicit none
  integer, parameter :: jprb=4

  type nothing
    logical :: different
  end type nothing

  type that_thing
    integer :: n
    integer :: else
    type(nothing) :: entirely
  end type that_thing
end module my_types
"""

    fcode = """
subroutine myroutine(something)
  use my_types, only : jprb, that_thing
  implicit none

  type(that_thing), intent(inout) :: something
  real(kind=jprb) :: foo(something%n)

  foo(:)=0.0_jprb

  associate(thing=>something%else)
    if (something%entirely%different) then
      foo(:)=42.0_jprb
    else
      foo(:)=66.6_jprb
    end if
  end associate
end subroutine myroutine
"""
    # Parse the type module first; the returned object itself is not used here
    Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Create a deep-copy of the routine
    new_routine = routine.clone()

    # Swap every assignment in the copy for a dummy call
    call_map = {
        assign: ir.CallStatement(
            name=sym.DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine
        )
        for assign in FindNodes(ir.Assignment).visit(new_routine.body)
    }
    new_routine.body = Transformer(call_map).visit(new_routine.body)

    # The original keeps all three assignments, the copy has none left
    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
    assert len(FindNodes(ir.Assignment).visit(new_routine.body)) == 0


@pytest.mark.parametrize('frontend', available_frontends())
def test_call_args_kwargs_conversion(frontend):
    """
    Test keyword-argument sorting and keyword-to-positional conversion on
    enriched :any:`CallStatement` nodes.

    The driver calls the kernel four times, mixing positional and keyword
    arguments in different orders. After ``sort_kwarguments`` every call must
    list its arguments in the kernel's dummy-argument order. After
    ``convert_kwargs_to_args`` every argument must be positional.
    """

    fcode_kernel = """
    subroutine kernel(a,b,c,d,e,f,g)
    implicit none
    integer, intent(inout) :: a
    integer, intent(out) :: b
    integer, intent(in) :: c, d, e, f, g


    a = a + 1
    b = a + c + d + e + f + g

    end subroutine kernel
    """

    fcode_driver = """
    subroutine driver()
    implicit none

    integer :: a
    integer :: b
    integer :: driver_c
    integer :: driver_d
    integer :: driver_ze
    integer :: driver_f
    integer :: driver_g

    a = 0

    call kernel(a, b, driver_c, driver_d, driver_ze, driver_f, driver_g)
    call kernel(a=a, b=b, c=driver_c, d=driver_d, e=driver_ze, f=driver_f, g=driver_g)
    call kernel(b=b, e=driver_ze, c=driver_c, d=driver_d, f=driver_f, g=driver_g, a=a)
    ! this is NOT allowed in Fortran
    ! call kernel(driver_c, driver_d, driver_ze, driver_f, driver_g, a=a, b=b)
    call kernel(a,b,driver_c, driver_d, driver_ze, g=driver_g, f=driver_f)

    end subroutine driver
    """

    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    driver.enrich(kernel)

    # already correct ordered kwarguments?
    kwargs_in_order = (True, True, False, False)
    # expected (kw)arguments in calls, 'driver_ze' to break alphabetical order
    call_args = ('a', 'b', 'driver_c', 'driver_d', 'driver_ze', 'driver_f', 'driver_g')
    # expected amount of kwargs for the corresponding calls
    len_kwargs = (0, 7, 7, 2)

    # sort kwargs; first make sure no call was lost, otherwise the per-call
    # expectations below would silently be skipped
    calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(calls) == len(kwargs_in_order) == len(len_kwargs)
    for call, in_order in zip(calls, kwargs_in_order):
        assert call.is_kwargs_order_correct() == in_order
        call.sort_kwarguments()

    # check calls with sorted kwargs
    sorted_calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(sorted_calls) == len(len_kwargs)
    for call, n_kwargs in zip(sorted_calls, len_kwargs):
        assert tuple(arg[1].name for arg in call.arg_iter()) == call_args
        assert len(call.kwarguments) == n_kwargs

    # kwarg to arg conversion
    for call in FindNodes(ir.CallStatement).visit(driver.body):
        call.convert_kwargs_to_args()

    # check calls with kwargs converted to args
    for call in FindNodes(ir.CallStatement).visit(driver.body):
        assert tuple(arg.name for arg in call.arguments) == call_args
        assert call.kwarguments == ()


@pytest.mark.parametrize('frontend', available_frontends())
def test_resolve_typebound_var(frontend, tmp_path):
    """
    Test correct behaviour of :any:`Scope.resolve_typebound_var` utility
    when the full type definitions are available.
    """
    fcode = """
module header_mod
    implicit none
    type some_type
        integer :: ival
    end type some_type

    type other_type
        type(some_type) :: other
    end type other_type

    type third_type
        type(other_type) :: some
    end type third_type
end module header_mod

subroutine some_routine
    use header_mod, only: third_type
    implicit none
    type(third_type) :: tt
end subroutine
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = source['some_routine']

    # One level of nesting
    tt_some = routine.resolve_typebound_var('tt%some')
    assert tt_some == 'tt%some'
    assert tt_some.type.dtype.name == 'other_type'
    assert tt_some.type.dtype.typedef is source['header_mod']['other_type']

    # Multiple levels of nesting, down to an intrinsic-typed member
    tt_some_other_ival = routine.resolve_typebound_var('tt%some%other%ival')
    assert tt_some_other_ival == 'tt%some%other%ival'
    assert tt_some_other_ival.type.dtype == BasicType.INTEGER
    assert tt_some_other_ival.parent.type.dtype.name == 'some_type'
    assert tt_some_other_ival.parent.type.dtype.typedef is source['header_mod']['some_type']

    # No nesting at all
    tt = routine.resolve_typebound_var('tt')
    assert tt == 'tt'
    assert tt.type.dtype.name == 'third_type'
    assert tt.type.dtype.typedef is source['header_mod']['third_type']

    # This throws an error as the type definition is available and therefore
    # the invalid member can be deduced
    with pytest.raises(KeyError):
        routine.resolve_typebound_var('tt%invalid%val')

    with pytest.raises(KeyError):
        routine.resolve_typebound_var('tt%some%invalid')

    # This throws errors as resolving derived type members for
    # non-declared derived types should not be possible
    with pytest.raises(KeyError):
        routine.resolve_typebound_var('not_tt%invalid')

    with pytest.raises(KeyError):
        routine.resolve_typebound_var('not_a_var')

    # Instead, we can create a deferred type variable in the scope and
    # resolve members relative to it
    not_tt = sym.Variable(name='not_tt', scope=routine)
    assert not_tt.type.dtype == BasicType.DEFERRED  # pylint: disable=no-member
    not_tt_invalid = not_tt.get_derived_type_member('invalid')  # pylint: disable=no-member
    assert not_tt_invalid == 'not_tt%invalid'
    assert not_tt_invalid.type.dtype == BasicType.DEFERRED


@pytest.mark.parametrize('frontend', available_frontends())
def test_resolve_typebound_var_missing_definition(frontend, tmp_path):
    """
    Test correct behaviour of :any:`Scope.resolve_typebound_var` utility
    when the type definition is not available to the routine.

    Unknown members must resolve to deferred-type symbols instead of raising.
    """
    fcode_module = """
module header_mod
    implicit none
    type some_type
        integer :: ival
    end type some_type

    type other_type
        type(some_type) :: other
    end type other_type

    type third_type
        type(other_type) :: some
    end type third_type
end module header_mod
"""

    fcode = """
subroutine some_routine
    use header_mod, only: third_type
    implicit none
    type(third_type) :: tt
end subroutine
    """.strip()

    # The header module is parsed on its own and is not passed as a definition
    # when parsing the routine's source file
    Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['some_routine']

    # No error here: working with incomplete type definitions is a legitimate use case
    resolved = routine.resolve_typebound_var('tt%invalid%val')
    assert resolved == 'tt%invalid%val'
    assert resolved.type.dtype == BasicType.DEFERRED
    assert resolved.parent.type.dtype == BasicType.DEFERRED


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('dim_decl', [':: add_to_a(n)', ', DIMENSION(n) :: add_to_a'])
def test_function_array_return_type(frontend, dim_decl):
    """
    Verify array return types are correctly represented with all frontends.

    ``dim_decl`` selects whether the result shape is declared on the entity
    or via a ``DIMENSION`` attribute.
    """
    fcode = f"""
subroutine member_functions
    implicit none
    integer :: i
    real(kind=8) :: a(3)
    contains
    function add_to_a(b, n)
      integer, intent(in) :: n
      real(kind=8), intent(in) :: b(n)
      real(kind=8) {dim_decl}

      do i = 1, n
        add_to_a(i) = a(i) + b(i)
      end do
    end function
end subroutine member_functions
    """.strip()
    outer = Function.from_source(fcode, frontend=frontend)
    add_to_a = outer['add_to_a']

    # The procedure type carries the array-valued return type ...
    ret_type = add_to_a.procedure_type.return_type
    assert ret_type.dtype == BasicType.REAL
    assert ret_type.shape == ('n',)

    # ... and so does the result variable inside the function
    result_var = add_to_a.variable_map['add_to_a']
    assert result_var.type.dtype == BasicType.REAL
    assert result_var.type.shape == ('n',)
    assert result_var.dimensions == ('n',)

    # OMNI frontend puts the shape declaration always on the variable
    expected_decl = ':: add_to_a(n)' if frontend == OMNI else dim_decl
    assert expected_decl in outer.to_fortran()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('ext2_import', [
    ('use ext2, only: ext2_var1, ext2_var2', ('ext2_var1', 'ext2_var2')),
    ('use ext2', ()),
    ('use ext2, ext2_var => ext2_var1', ('ext2_var',)),
    ('use ext2, only: ext2_var2', ('ext2_var2',))
    ])
def test_subroutine_imported_symbols(tmp_path, frontend, ext2_import):
    """
    Test return of imported symbols.

    ``imported_symbols`` covers the imports of a program unit itself, while
    ``all_imported_symbols`` also includes those of the parent module.
    ``ext2_import`` holds the module-level ``use`` statement for ``ext2`` and
    the symbol names it is expected to import.
    """
    fcode_ext1_mod = """
    module ext1
      implicit none
      integer :: ext1_var1, ext1_var2, ext1_var3
    end module ext1
    """

    fcode_ext2_mod = """
    module ext2
      implicit none
      integer :: ext2_var1, ext2_var2
    end module ext2
    """

    fcode_ext3_mod = """
    module ext3
      implicit none
      integer :: ext3_var1, ext3_var2
    end module ext3
    """

    fcode_module = f"""
module parent_mod
  use ext1, only: ext1_var1
  {ext2_import[0]}
  implicit none
  contains
   subroutine routine1(a)
     use ext1, only: ext1_var2, ext1_var3
     use ext3
     integer, intent(inout) :: a(:)
   end subroutine routine1

   subroutine routine2(b)
     use ext3, only: ext3_var1, ext3_var2
     integer, intent(inout) :: b(:)
   end subroutine routine2
end module parent_mod
    """

    ext_mods = [
        Module.from_source(code, frontend=frontend, xmods=[tmp_path])
        for code in (fcode_ext1_mod, fcode_ext2_mod, fcode_ext3_mod)
    ]
    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path], definitions=ext_mods)
    routine1 = module.subroutines[0]
    routine2 = module.subroutines[1]

    # Collect imported_symbols and all_imported_symbols for each unit
    mod_imp = set(module.imported_symbols)
    mod_all_imp = set(module.all_imported_symbols)
    r1_imp = set(routine1.imported_symbols)
    r1_all_imp = set(routine1.all_imported_symbols)
    r2_imp = set(routine2.imported_symbols)
    r2_all_imp = set(routine2.all_imported_symbols)

    # Expected symbols imported directly in each unit
    exp_mod = {'ext1_var1', *ext2_import[1]}
    exp_r1 = {'ext1_var2', 'ext1_var3'}
    exp_r2 = {'ext3_var1', 'ext3_var2'}

    assert mod_imp == exp_mod
    for name in exp_mod:
        assert name in module.all_imported_symbol_map
    assert r1_imp == exp_r1
    assert r2_imp == exp_r2
    # For the top-level module both views are expected to coincide
    assert mod_imp == mod_all_imp

    # Member routines additionally see the parent module's imports
    for routine, own, all_imp in ((routine1, exp_r1, r1_all_imp), (routine2, exp_r2, r2_all_imp)):
        expected = own | exp_mod
        assert all_imp == expected
        for name in expected:
            assert name in routine.all_imported_symbol_map
loki-ecmwf-0.3.6/loki/tests/test_function.py0000664000175000017500000002071615167130205021305 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Function, fgen
from loki.frontend import available_frontends, OMNI, REGEX
from loki.ir import nodes as ir, FindNodes
from loki.types import BasicType


@pytest.mark.parametrize('frontend', available_frontends())
def test_function_return_type(tmp_path, frontend):
    """
    Test various ways to define the return type of a function: a type prefix,
    an explicit declaration of the function name, a renamed ``result``
    variable, and a type prefix combined with a renamed ``result``.
    """
    fcode = """
module my_funcs
implicit none
contains

  real(kind=8) function funca(a)
    real, intent(in) :: a

    funca = a
  end function funca

  function funcb(a)
    real(kind=8), intent(in) :: a
    real(kind=8) :: funcb

    funcb = a
  end function funcb

  function funcky(a) result(fun)
    real, intent(in) :: a
    real(kind=8) :: fun

    fun = a
  end function funcky

  real function square(a) result(b)
    implicit none
    real, intent(in) :: a
    b = a * a
  end function square
end module my_funcs
    """
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Implicit return type definition
    assert isinstance(module['funca'], Function)
    assert module['funca'].result_name == 'funca'
    assert module['funca'].return_type.dtype == BasicType.REAL
    assert module['funca'].return_type.kind == 8
    # Parenthesise the conditional: without it, ``assert x == 2 if c else 1``
    # asserts the constant ``1`` (always true) for non-OMNI frontends
    assert len(FindNodes(ir.VariableDeclaration).visit(module['funca'].spec)) == (2 if frontend == OMNI else 1)
    assert len(FindNodes(ir.ProcedureDeclaration).visit(module['funca'].spec)) == 0

    if frontend == OMNI:
        # Ensure return type is declared (OMNI always inserts declaration)
        fdecl = tuple(
            d for d in FindNodes(ir.VariableDeclaration).visit(module['funca'].spec)
            if 'funca' in d.symbols
        )
        assert len(fdecl) == 1 and fdecl[0].symbols[0] == 'funca'
        assert fdecl[0].symbols[0].type.dtype == BasicType.REAL
        assert fdecl[0].symbols[0].type.kind == 8
    else:
        # Check for implicit return value in `fgen`
        fstr_header = fgen(module['funca']).splitlines()[0]
        assert 'real(kind=8) function funca (a)' == fstr_header.lower()

    # Explicit return type declaration
    assert isinstance(module['funcb'], Function)
    assert module['funcb'].result_name == 'funcb'
    assert module['funcb'].return_type.dtype == BasicType.REAL
    assert module['funcb'].return_type.kind == 8
    assert len(FindNodes(ir.VariableDeclaration).visit(module['funcb'].spec)) == 2
    assert len(FindNodes(ir.ProcedureDeclaration).visit(module['funcb'].spec)) == 0

    # Re-named return type declaration
    assert isinstance(module['funcky'], Function)
    assert module['funcky'].result_name == 'fun'
    assert module['funcky'].return_type.dtype == BasicType.REAL
    assert module['funcky'].return_type.kind == 8
    assert len(FindNodes(ir.VariableDeclaration).visit(module['funcky'].spec)) == 2
    assert len(FindNodes(ir.ProcedureDeclaration).visit(module['funcky'].spec)) == 0

    # Implicit return type and renamed result name
    assert isinstance(module['square'], Function)
    assert module['square'].result_name == 'b'
    assert module['square'].return_type.dtype == BasicType.REAL
    assert len(FindNodes(ir.VariableDeclaration).visit(module['square'].spec)) == (2 if frontend == OMNI else 1)
    assert len(FindNodes(ir.ProcedureDeclaration).visit(module['square'].spec)) == 0


@pytest.mark.parametrize('frontend', available_frontends())
def test_function_prefix(frontend):
    """
    Test that ``pure``/``elemental`` prefixes and the prefixed return type are
    captured on the function, its procedure type and its procedure symbol,
    and are reproduced by ``fgen``.
    """
    fcode = """
pure elemental real function f_elem(a)
    real, intent(in) :: a
    f_elem = a
end function f_elem
    """.strip()

    routine = Function.from_source(fcode, frontend=frontend)
    for keyword in ('PURE', 'ELEMENTAL'):
        assert keyword in routine.prefix
    assert isinstance(routine, Function)
    assert routine.return_type.dtype is BasicType.REAL

    def _check_proc_type(proc_type):
        # A procedure type must describe a REAL-valued function bound to this routine
        assert proc_type.is_function is True
        assert proc_type.return_type.dtype is BasicType.REAL
        assert proc_type.procedure is routine

    _check_proc_type(routine.procedure_type)
    _check_proc_type(routine.procedure_symbol.type.dtype)

    generated = fgen(routine)
    for keyword in ('PURE', 'ELEMENTAL'):
        assert keyword in generated


@pytest.mark.parametrize('frontend', available_frontends())
def test_function_suffix(frontend, tmp_path):
    """
    Test that function suffixes (``result(...)`` and ``bind(C, ...)``) are
    supported and correctly reproduced.
    """
    fcode = """
module subroutine_suffix_mod
    implicit none

    interface
        function check_value(value) bind(C, name='check_value')
            use, intrinsic :: iso_c_binding
            real(c_float), value :: value
            integer(c_int) :: check_value
        end function check_value
    end interface

    interface
        function fix_value(value) result(fixed) bind(C, name='fix_value')
            use, intrinsic :: iso_c_binding
            real(c_float), value :: value
            real(c_float) :: fixed
        end function fix_value
    end interface
contains
    function out_of_physical_bounds(field, istartcol, iendcol, do_fix) result(is_bad)
        real, intent(inout) :: field(:)
        integer, intent(in) :: istartcol, iendcol
        logical, intent(in) :: do_fix
        logical :: is_bad

        integer :: jcol
        logical :: bad_value

        is_bad = .false.
        do jcol=istartcol,iendcol
            bad_value = check_value(field(jcol)) > 0
            is_bad = is_bad .or. bad_value
            if (do_fix .and. bad_value) field(jcol) = fix_value(field(jcol))
        end do
    end function out_of_physical_bounds
end module subroutine_suffix_mod
    """.strip()
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # bind(C) suffix only
    check_value = module.interface_map['check_value'].body[0]
    assert check_value.is_function
    assert check_value.result_name == 'check_value'
    assert check_value.return_type.dtype is BasicType.INTEGER
    assert check_value.return_type.kind == 'c_int'
    if frontend != OMNI:
        assert check_value.bind == 'check_value'
        assert "bind(c, name='check_value')" in fgen(check_value).lower()

    # result(...) and bind(C) suffixes combined
    fix_value = module.interface_map['fix_value'].body[0]
    assert fix_value.is_function
    assert fix_value.result_name == 'fixed'
    assert fix_value.return_type.dtype is BasicType.REAL
    assert fix_value.return_type.kind == 'c_float'
    if frontend == OMNI:
        assert "result(fixed)" in fgen(fix_value).lower()
    else:
        assert fix_value.bind == 'fix_value'
        assert "result(fixed) bind(c, name='fix_value')" in fgen(fix_value).lower()

    # result(...) suffix only
    routine = module['out_of_physical_bounds']
    assert routine.is_function
    assert routine.result_name == 'is_bad'
    assert routine.bind is None
    assert routine.return_type.dtype is BasicType.LOGICAL
    assert "result(is_bad)" in fgen(routine).lower()


@pytest.mark.parametrize('frontend', available_frontends())
def test_function_lazy_prefix(frontend):
    """
    Test that prefixes for functions are correctly captured when the object is made
    complete.

    This test represents a case where the REGEX frontend fails to capture these attributes correctly.

    The rationale for this test is that we don't currently need these attributes
    in the incomplete REGEX-parsed IR and we accept that this information is incomplete initially.
    Here, we make sure this information is captured correctly after completing the full frontend
    parse.
    """
    fcode = """
pure elemental real function f_elem(a)
    real, intent(in) :: a
    f_elem = a
end function f_elem
    """.strip()

    # REGEX parse: the whole prefix (including the return type) is kept as a single
    # raw string, and neither arguments nor the return type are available yet
    routine = Function.from_source(fcode, frontend=REGEX)
    assert routine._incomplete
    assert routine.prefix == ('pure elemental real',)
    assert routine.arguments == ()
    assert routine.is_function is True
    assert routine.return_type is None

    # Full parse: the individual prefix attributes, arguments and return type must now be present
    routine.make_complete(frontend=frontend)
    assert not routine._incomplete
    assert 'PURE' in routine.prefix
    assert 'ELEMENTAL' in routine.prefix
    assert routine.arguments == ('a',)
    assert routine.is_function is True
    assert routine.return_type.dtype is BasicType.REAL
loki-ecmwf-0.3.6/loki/tests/test_nested_types/0000775000175000017500000000000015167130205021606 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/test_nested_types/kernel.f900000664000175000017500000000063015167130205023405 0ustar  alastairalastairmodule kernel

    use types, only: parent_type

    implicit none

    public kernel_routine

    contains

    subroutine kernel_routine(size, pt)
        integer, intent(in) :: size
        type(parent_type), intent(inout) :: pt

        integer :: i

        do i=1,size
            pt%type_member%x(i) = pt%member*pt%type_member%x(i)
        end do

    end subroutine kernel_routine

end module kernel
loki-ecmwf-0.3.6/loki/tests/test_nested_types/test_nested_types.py0000664000175000017500000000407415167130205025732 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki import Sourcefile, fexprgen
from loki.frontend import available_frontends, OMNI


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Loki annotations break frontend parser')]))
def test_nested_types(frontend, tmp_path):
    """
    Regression test for nested derived types.

    The ``!$loki dimension`` annotation on ``sub_type`` must be honoured, and the
    resulting shape must propagate through ``parent_type`` into the routines of
    the ``driver`` and ``kernel`` modules when they are parsed with the type
    definitions supplied manually via ``definitions=``.
    """
    srcdir = Path(__file__).parent

    def parse(filename, modname, definitions=None):
        # Parse a fixture file next to this test and return the requested module
        options = {'frontend': frontend, 'xmods': [tmp_path]}
        if definitions is not None:
            options['definitions'] = definitions
        return Sourcefile.from_file(srcdir/filename, **options)[modname]

    # The dimension annotation on sub_type%x is honoured
    subtypes = parse('sub_types.f90', 'sub_types')
    annotated = subtypes.typedef_map['sub_type'].variables[0]
    assert fexprgen(annotated.shape) == '(size,)'

    # ...and that shape is visible through parent_type%type_member%x
    types = parse('types.f90', 'types', definitions=subtypes)
    type_member = types.typedef_map['parent_type'].variables[1]
    assert fexprgen(type_member.variable_map['x'].shape) == '(size,)'

    # Both the driver (local `pt`) and the kernel (dummy argument `pt`) see the shape
    for filename, modname, position in (('driver.f90', 'driver', 0), ('kernel.f90', 'kernel', 1)):
        pt = parse(filename, modname, definitions=types).routines[0].variables[position]
        nested_x = pt.variable_map['type_member'].variable_map['x']
        assert fexprgen(nested_x.shape) == '(size,)'
loki-ecmwf-0.3.6/loki/tests/test_nested_types/driver.f900000664000175000017500000000112015167130205023413 0ustar  alastairalastairmodule driver

    use kernel, only: kernel_routine
    use types, only: parent_type

    contains

    subroutine driver_routine()
        type(parent_type) :: pt

        integer :: summed, i, size

        size = 100

        pt%member = 12
        ALLOCATE(pt%type_member%x(size))
        do i=1,size
            pt%type_member%x(i) = 1
        end do

        call kernel_routine(size, pt)

        summed = 0
        do i=1,size
            summed = summed + pt%type_member%x(i)
        end do

        print*, "the sum is", summed

    end subroutine driver_routine

end module driver
loki-ecmwf-0.3.6/loki/tests/test_nested_types/types.f900000664000175000017500000000033015167130205023266 0ustar  alastairalastairmodule types
    use sub_types, only: sub_type

    implicit none

    public parent_type
    
    type parent_type
        integer :: member
        type(sub_type) :: type_member
    end type parent_type

end moduleloki-ecmwf-0.3.6/loki/tests/test_nested_types/sub_types.f900000664000175000017500000000025615167130205024146 0ustar  alastairalastairmodule sub_types
    implicit none

    public sub_type

    type sub_type
    !$loki dimension(size)
    integer, pointer :: x(:)
    end type sub_type

end module sub_typesloki-ecmwf-0.3.6/loki/tests/test_examples.py0000664000175000017500000000322315167130205021270 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Automatically test the provided examples
"""
from pathlib import Path

import pytest
import nbformat
from jupyter_client.kernelspec import find_kernel_specs
from nbconvert.preprocessors import ExecutePreprocessor

example_path = Path(__file__).parent.parent/'example'

def is_ipython_available():
    """
    Check whether the IPython package can be imported.

    This does not check for an installed Jupyter kernel; that is done
    separately via ``find_kernel_specs()``.

    Returns
    -------
    bool
        ``True`` if ``import IPython`` succeeds, ``False`` if it raises ``ImportError``.
    """
    try:
        import IPython  # pylint: disable=import-outside-toplevel,unused-import
    except ImportError:
        return False
    return True

# Skip every test in this module unless IPython is importable AND a 'python3'
# Jupyter kernel spec is registered (the kernel lookup only runs if IPython is present)
pytestmark = pytest.mark.skipif(
    not (is_ipython_available() and 'python3' in find_kernel_specs()),
    reason='IPython or Jupyter kernel are not available'
)

@pytest.mark.parametrize("notebook", example_path.glob('*.ipynb'))
def test_notebooks(notebook, monkeypatch):
    """
    Execute each example Jupyter notebook in the example directory with the
    ``python3`` kernel and make sure it runs through without any problems.

    The working directory is switched to the example directory so that relative
    paths used inside the notebooks resolve. Cell errors are raised by
    nbconvert's ``ExecutePreprocessor`` (its default behaviour) and fail the test.
    """
    monkeypatch.chdir(example_path)

    # Notebook files are UTF-8 encoded JSON; don't depend on the locale's default encoding
    with notebook.open(encoding='utf-8') as f:
        nb = nbformat.read(f, as_version=4)

    ep = ExecutePreprocessor(timeout=60, kernel_name='python3')
    # preprocess() returns an ``(executed_notebook, resources)`` tuple; check the
    # notebook itself rather than the (always non-None) tuple
    executed_nb, _ = ep.preprocess(nb)
    assert executed_nb is not None, f"Got empty notebook for {notebook}"
loki-ecmwf-0.3.6/loki/tests/sources/0000775000175000017500000000000015167130205017524 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/sourcefile_pp_directives.F900000664000175000017500000000033715167130205025067 0ustar  alastairalastairsubroutine routine_pp_directives
  print *,"Compiled ",__FILENAME__," on ",__DATE__
#define __FILENAME__ __FILE__
  print *,"This is ",__FILE__,__VERSION__
  y = __LINE__ * 5 + __LINE__
end subroutine routine_pp_directives
loki-ecmwf-0.3.6/loki/tests/sources/trivial_fortran_files/0000775000175000017500000000000015167130205024113 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/trivial_fortran_files/case_statement_subroutine.f900000664000175000017500000000054115167130205031711 0ustar  alastairalastairsubroutine check_grade(score)
  integer, intent(in) :: score

  select case (score)
    case (90:100)
      print *, "A"
    case (80:89)
      print *, "B"
    case (70:79)
      print *, "C"
    case (60:69)
      print *, "D"
    case (0:59)
      print *, "F"
    case default
      print *, "Invalid score"
  end select

end subroutine check_grade
loki-ecmwf-0.3.6/loki/tests/sources/trivial_fortran_files/module_with_subroutines.f900000664000175000017500000000120515167130205031413 0ustar  alastairalastairmodule math_operations
  implicit none

  public :: add, subtract, multiply

contains

  ! Subroutine to add two numbers
  subroutine add(x, y, result)
    real, intent(in) :: x, y
    real, intent(out) :: result
    result = x + y
  end subroutine add

  ! Subroutine to subtract two numbers
  subroutine subtract(x, y, result)
    real, intent(in) :: x, y
    real, intent(out) :: result
    result = x - y
  end subroutine subtract

  ! Subroutine to multiply two numbers
  subroutine multiply(x, y, result)
    real, intent(in) :: x, y
    real, intent(out) :: result
    result = x * y
  end subroutine multiply

end module math_operations
loki-ecmwf-0.3.6/loki/tests/sources/trivial_fortran_files/nested_if_else_statements_subroutine.f900000664000175000017500000000062715167130205034136 0ustar  alastairalastairsubroutine nested_if_example(x, y)
  integer, intent(in) :: x, y

  if (x > 0) then
    if (y > 0) then
      print *, "Both x and y are positive."
    else
      print *, "x is positive, but y is not."
    end if
  else
    if (y > 0) then
      print *, "x is not positive, but y is positive."
    else
      print *, "Both x and y are not positive."
    end if
  end if

end subroutine nested_if_example
loki-ecmwf-0.3.6/loki/tests/sources/trivial_fortran_files/if_else_statement_subroutine.f900000664000175000017500000000030715167130205032404 0ustar  alastairalastairsubroutine check_number(x)
  real, intent(in) :: x

  if (x > 0.0) then
    print *, "The number is positive."
  else
    print *, "The number is non-positive."
  end if

end subroutine check_number
loki-ecmwf-0.3.6/loki/tests/sources/header.f900000664000175000017500000000042115167130205021271 0ustar  alastairalastairmodule header
  ! header module to provide external typedefs
  integer, parameter :: jprb = selected_real_kind(13,300)

  type derived_type
    real(kind=jprb) :: scalar, vector(3), matrix(3, 3)
    real(kind=jprb) :: red_herring
  end type derived_type

end module header
loki-ecmwf-0.3.6/loki/tests/sources/stmt.func.h0000664000175000017500000000127315167130205021621 0ustar  alastairalastair!*
! ---------------------------------------------------

!   Sample of statement functions externalized into a
!   header file similar to how they are included into
!   Fortran source code in the IFS.
!
!   This is an excerpt from fcttre.func.h
!
! ---------------------------------------------------
REAL(KIND=JPRB) :: FOEDELTA
REAL(KIND=JPRB) :: PTARE
FOEDELTA (PTARE) = MAX (0.0_JPRB,SIGN(1.0_JPRB,PTARE-RTT))

REAL(KIND=JPRB) :: FOEEWMO, FOEELIQ, FOEEICE 
FOEEWMO( PTARE ) = R2ES*EXP(R3LES*(PTARE-RTT)/(PTARE-R4LES))
FOEELIQ( PTARE ) = R2ES*EXP(R3LES*(PTARE-RTT)/(PTARE-R4LES))
FOEEICE( PTARE ) = R2ES*EXP(R3IES*(PTARE-RTT)/(PTARE-R4IES))

! ---------------------------------------------------
loki-ecmwf-0.3.6/loki/tests/sources/projScopes/0000775000175000017500000000000015167130205021653 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projScopes/driver.F900000664000175000017500000000026715167130205023433 0ustar  alastairalastairsubroutine driver
    use kernel1_mod, only: kernel1 => kernel
    use kernel2_mod, only: kernel2 => kernel
    implicit none

    call kernel1
    call kernel2
end subroutine driver
loki-ecmwf-0.3.6/loki/tests/sources/projScopes/kernel1_impl.F900000664000175000017500000000021315167130205024511 0ustar  alastairalastairmodule kernel1_impl
contains
    subroutine kernel_impl
    implicit none
    ! ...
    end subroutine kernel_impl
end module kernel1_impl
loki-ecmwf-0.3.6/loki/tests/sources/projScopes/kernel2_impl.F900000664000175000017500000000021315167130205024512 0ustar  alastairalastairmodule kernel2_impl
contains
    subroutine kernel_impl
    implicit none
    ! ...
    end subroutine kernel_impl
end module kernel2_impl
loki-ecmwf-0.3.6/loki/tests/sources/projScopes/kernel2_mod.F900000664000175000017500000000025315167130205024334 0ustar  alastairalastairmodule kernel2_mod
contains
    subroutine kernel
        use kernel2_impl
        implicit none
        call kernel_impl
    end subroutine kernel
end module kernel2_mod
loki-ecmwf-0.3.6/loki/tests/sources/projScopes/kernel1_mod.F900000664000175000017500000000025315167130205024333 0ustar  alastairalastairmodule kernel1_mod
contains
    subroutine kernel
        use kernel1_impl
        implicit none
        call kernel_impl
    end subroutine kernel
end module kernel1_mod
loki-ecmwf-0.3.6/loki/tests/sources/sourcefile_pp_include.F900000664000175000017500000000030715167130205024346 0ustar  alastairalastairsubroutine routine_pp_include(a, b, c)
  implicit none
  real(kind=4), intent(in) :: a, b
  real(kind=4), intent(out) :: c

#include "some_header.h"
  c = add(a, b)
end subroutine routine_pp_include
loki-ecmwf-0.3.6/loki/tests/sources/sourcefile_item.f900000664000175000017500000000143115167130205023221 0ustar  alastairalastairsubroutine routine_a
  integer a, i
  a = 1
  i = a + 1

  call routine_b(a, i)
end subroutine routine_a

module some_module
contains
  subroutine module_routine
    integer m
    m = 2

    call routine_b(m, 6)
  end subroutine module_routine

  function module_function(n)
    integer n
    n = 3
  end function module_function
end module some_module

subroutine routine_b(i,j)
  integer, intent(in) :: i, j
  integer b
  b = 4

  call contained_c(i)

  call routine_a()
contains

  subroutine contained_c(i)
    integer, intent(in) :: i
    integer c
    c = 5
  end subroutine contained_c

  subroutine contained_d(i)
    integer, intent(in) :: i
    integer c
    c = 8
  end subroutine contained_d
end subroutine routine_b

function function_d(d)
integer d
d = 6
end function function_d
loki-ecmwf-0.3.6/loki/tests/sources/projInlineCalls/0000775000175000017500000000000015167130205022614 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projInlineCalls/driver.F900000664000175000017500000000111115167130205024361 0ustar  alastairalastairsubroutine driver(n, work)
  use some_module, only: return_one, some_var, add_args, some_type
  use vars_module, only: varA, varB
  implicit none

  interface
     real function double_real(i)
       integer, intent(in) :: i
     end function double_real
  end interface

  integer, intent(in) :: n
  real, intent(out) :: work(n)
  type(some_type) :: var
  integer :: i

  do i=1,n
    work(i) = double_real(i) + return_one()
    work(i) = work(i) + dble(some_var)
    work(i) = work(i) + add_args(i,1) + add_args(i,2)
    call var%do_something(work(i))
  enddo

end subroutine driver
loki-ecmwf-0.3.6/loki/tests/sources/projInlineCalls/vars_module.F900000664000175000017500000000013215167130205025410 0ustar  alastairalastairmodule vars_module
implicit none

   real :: varA
   real :: varB

end module vars_module
loki-ecmwf-0.3.6/loki/tests/sources/projInlineCalls/some_module.F900000664000175000017500000000152115167130205025403 0ustar  alastairalastairmodule some_module
implicit none

  integer, parameter :: some_var=1

  interface add_args
    procedure add_two_args
    procedure add_three_args
  end interface add_args

  type some_type
    real :: c = 1.
    contains
      procedure :: do_something => add_const
  end type

contains

  function return_one() result(one)
     real :: one
     one = 1.
  end function return_one

  function add_two_args(i,j) result(res)
     integer, intent(in) :: i,j
     real :: res
     res = dble(i+j)
  end function add_two_args

  function add_three_args(i,j,k) result(res)
     integer, intent(in) :: i,j,k
     real :: res
     res = dble(i+j+k)
  end function add_three_args

  subroutine add_const(self, a)
     class(some_type), intent(in) :: self
     real, intent(inout) :: a

     a = a + self%c
  end subroutine add_const

end module some_module
loki-ecmwf-0.3.6/loki/tests/sources/projInlineCalls/double_real.F900000664000175000017500000000024115167130205025346 0ustar  alastairalastairreal function double_real(i)
  use vars_module, only: varA, varB
  implicit none
  integer, intent(in) :: i

  double_real =  dble(i*2)
end function double_real
loki-ecmwf-0.3.6/loki/tests/sources/Fortran-extract-interface-source.f900000664000175000017500000001147715167130205026375 0ustar  alastairalastair! Test file lifted from Fcm interface generation test suite
! https://github.com/metomi/fcm/blob/2-3/t/Fcm/Build/Fortran-extract-interface-source.f90

! A simple function
logical function func_simple()
func_simple = .true.
end function func_simple

! A simple function, but with less friendly end
logical function func_simple_1()
func_simple_1 = .true.
end function

! A simple function, but with even less friendly end
logical function func_simple_2()
func_simple_2 = .true.
end function

! A pure simple function
pure logical function func_simple_pure()
func_simple_pure = .true.
end function func_simple_pure

! A pure recursive function
recursive pure integer function func_simple_recursive_pure(i) result(result)
integer, intent(in) :: i
if (i <= 0) then
    result = i
else
    result = i + func_simple_recursive_pure(i - 1)
end if
end function func_simple_recursive_pure

! An elemental simple function
elemental logical function func_simple_elemental()
func_simple_elemental = .true.
end function func_simple_elemental

! A module with nonsense
module bar
type food
integer :: cooking_method
end type food
type organic
integer :: growing_method
end type organic
integer, parameter :: i_am_dim = 10
end module bar

! A module with more nonsense
module foo
use bar, only: FOOD
integer :: foo_int
contains
subroutine foo_sub(egg)
integer, parameter :: egg_dim = 10
type(Food), intent(in) :: egg
write(*, *) egg
end subroutine foo_sub
elemental function foo_func() result(f)
integer :: f
f = 0
end function
end module foo

! An function with arguments and module imports
integer(selected_int_kind(0)) function func_with_use_and_args(egg, ham)
use foo
! Deliberate trailing spaces in next line
use bar, only : organic,     i_am_dim
implicit none
integer, intent(in) :: egg(i_am_dim)
integer, intent(in) :: ham(i_am_dim, 2)
real bacon
! Deliberate trailing spaces in next line
type(   organic   ) :: tomato
func_with_use_and_args = egg(1) + ham(1, 1)
end function func_with_use_and_args

! A function with some parameters
character(20) function func_with_parameters(egg, ham)
implicit none
character*(*), parameter :: x_param = '01234567890'
character(*), parameter :: & ! throw in some comments
    y_param                &
    = '!&!&!&!&!&!'          ! how to make life interesting
integer, parameter :: z = 20
character(len(x_param)), intent(in) :: egg
character(len(y_param)), intent(in) :: ham
func_with_parameters = egg // ham
end function func_with_parameters

! A function with some parameters, with a result
function func_with_parameters_1(egg, ham) result(r)
implicit none
integer, parameter :: x_param = 10
integer z_param
parameter(z_param = 2)
real, intent(in), dimension(x_param) :: egg
integer, intent(in) :: ham
logical :: r(z_param)
r(1) = int(egg(1)) + ham > 0
r(2) = .false.
end function func_with_parameters_1

! A function with a contains
character(10) function func_with_contains(mushroom, tomoato)
character(5) mushroom
character(5) tomoato
func_with_contains = func_with_contains_1()
contains
character(10) function func_with_contains_1()
func_with_contains_1 = mushroom // tomoato
end function func_with_contains_1
end function func_with_contains

! A function with its result declared after a local in the same statement
Function func_mix_local_and_result(egg, ham, bacon) Result(Breakfast)
Integer, Intent(in) :: egg, ham
Real, Intent(in) :: bacon
Real :: tomato, breakfast
Breakfast = real(egg) + real(ham) + bacon
End Function func_mix_local_and_result

! A simple subroutine
subroutine sub_simple()
end subroutine sub_simple

! A simple subroutine, with not so friendly end
subroutine sub_simple_1()
end subroutine

! A simple subroutine, with even less friendly end
subroutine sub_simple_2()
end subroutine

! A simple subroutine, with funny continuation
subroutine sub_simple_3()
end sub&
&routine&
& sub_simple_3

! A subroutine with a few contains
subroutine sub_with_contains(foo) ! " &
! Deliberate trailing spaces in next line
use Bar, only: i_am_dim
character*(len('!"&''&"!')) & ! what a mess!
    foo
call sub_with_contains_first()
call sub_with_contains_second()
call sub_with_contains_third()
print*, foo
contains
subroutine sub_with_contains_first()
interface
integer function x()
end function x
end interface
end subroutine sub_with_contains_first
subroutine sub_with_contains_second()
end subroutine
subroutine sub_with_contains_third()
end subroutine
end subroutine sub_with_contains

! A subroutine with a renamed module import
subroutine sub_with_renamed_import(i_am_dim)
use bar, only: i_am_not_dim => i_am_dim
integer, parameter :: d = 2
complex :: i_am_dim(d)
print*, i_am_dim
end subroutine sub_with_renamed_import

! A subroutine with an external argument
subroutine sub_with_external(proc)
external proc
call proc()
end subroutine sub_with_external

! A subroutine with a variable named "end"
subroutine sub_with_end()
integer :: end
end = 0
end subroutine sub_with_end
loki-ecmwf-0.3.6/loki/tests/sources/sourcefile.f900000664000175000017500000000071515167130205022207 0ustar  alastairalastairsubroutine routine_a
integer a
a = 1
end subroutine routine_a

module some_module
contains
subroutine module_routine
integer m
m = 2
end subroutine module_routine
function module_function(n)
integer n
n = 3
end function module_function
end module some_module

subroutine routine_b
integer b
b = 4
contains
subroutine contained_c
integer c
c = 5
end subroutine contained_c
end subroutine routine_b

function function_d(d)
integer d
d = 6
end function function_d
loki-ecmwf-0.3.6/loki/tests/sources/projB/0000775000175000017500000000000015167130205020600 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projB/external/0000775000175000017500000000000015167130205022422 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projB/external/ext_driver_mod.f900000664000175000017500000000064315167130205025757 0ustar  alastairalastair#ifdef HAVE_EXT_DRIVER_MODULE
module ext_driver_mod
  implicit none

contains
#endif

  subroutine ext_driver(vector, matrix)
    use ext_kernel_mod, only: jprb, ext_kernel
    implicit none
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:, :)

    call ext_kernel(vector, matrix)
  end subroutine ext_driver

#ifdef HAVE_EXT_DRIVER_MODULE
end module ext_driver_mod
#endif
loki-ecmwf-0.3.6/loki/tests/sources/projB/module/0000775000175000017500000000000015167130205022065 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projB/module/ext_kernel.f900000664000175000017500000000055215167130205024547 0ustar  alastairalastairmodule ext_kernel_mod
  integer, parameter :: jprb = 4

contains

  subroutine ext_kernel(vector, matrix)
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:, :)
    integer :: i

    do i = 1, size(vector)
      matrix(:, i) = matrix(:, i) + vector(i)
    end do
  end subroutine ext_kernel

end module ext_kernel_mod
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/0000775000175000017500000000000015167130205021440 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projBatch/source/0000775000175000017500000000000015167130205022740 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projBatch/source/comp1.F900000664000175000017500000000065115167130205024241 0ustar  alastairalastairsubroutine comp1 (arg, val)
    use t_mod, only: t, nt1
    use header_mod
    implicit none
    type(t), intent(inout) :: arg
    real(kind=k), intent(inout) :: val(:)
    integer :: jnt1
#include "comp2.intfb.h"
    call arg%proc()
    call comp2(arg, val)
    call comp2(arg, val)  ! Twice to check we're not duplicating dependencies
    do jnt1=1,nt1
        call arg%no(jnt1)%way(.true.)
    end do
end subroutine comp1
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/source/comp2.f900000664000175000017500000000046715167130205024307 0ustar  alastairalastairsubroutine comp2 (arg, val)
    use t_mod, only: t
    use header_mod, only: k
    use a_mod, only: a
    use b_mod, only: b
    implicit none
    type(t), intent(inout) :: arg
    real(kind=k), intent(inout) :: val(:)

    call a(t%yay%indirection)
    call b(val)
    call arg%yay%proc()
end subroutine comp2
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/headers/0000775000175000017500000000000015167130205023053 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projBatch/headers/header_mod.F900000664000175000017500000000013215167130205025416 0ustar  alastairalastairmodule header_mod
    implicit none
    integer, parameter :: k = 8
end module header_mod
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/include/0000775000175000017500000000000015167130205023063 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projBatch/include/comp2.intfb.h0000664000175000017500000000033215167130205025353 0ustar  alastairalastairinterface
subroutine comp2 (arg, val)
    use t_mod, only: t
    use header_mod, only: k
    implicit none
    type(t), intent(inout) :: arg
    real(kind=k), intent(inout) :: val(:)
end subroutine comp2
end interface
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/module/0000775000175000017500000000000015167130205022725 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projBatch/module/tt_mod.F900000664000175000017500000000066615167130205024503 0ustar  alastairalastairmodule tt_mod
    use header_mod, only: k
    implicit none

    integer, parameter :: nclv = 5

    type tt
        real(kind=k), allocatable :: indirection(:)
        real(kind=k) :: other(nclv)
    contains
        procedure :: proc
    end type tt

    interface intf
        module procedure proc
    end interface
contains
    subroutine proc(this)
        class(tt), intent(inout) :: this
    end subroutine proc
end module tt_mod
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/module/other_mod.F900000664000175000017500000000041115167130205025161 0ustar  alastairalastairmodule other_mod
    use tt_mod, only: tt
    use b_mod, only: b
    implicit none
contains
    subroutine mod_proc(arg)
        type(tt), intent(inout) :: arg
        call arg%proc()
        call b(arg%indirection)
    end subroutine mod_proc
end module other_mod
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/module/a_mod.F900000664000175000017500000000026215167130205024264 0ustar  alastairalastairmodule a_mod
    implicit none
contains
    subroutine a(arg)
        use header_mod, only: k
        real(kind=k), intent(inout) :: arg(:)
    end subroutine a
end module a_mod
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/module/b_mod.F900000664000175000017500000000025615167130205024270 0ustar  alastairalastairmodule b_mod
    use header_mod, only: k
    implicit none
contains
    subroutine b(arg)
        real(kind=k), intent(inout) :: arg(:)
    end subroutine b
end module b_mod
loki-ecmwf-0.3.6/loki/tests/sources/projBatch/module/t_mod.F900000664000175000017500000000130615167130205024307 0ustar  alastairalastairmodule t_mod
    use tt_mod, only: tt, intf, proc
    use a_mod, only: a
    implicit none

    integer, parameter :: nt1 = 10

    type t1
    contains
        procedure :: way => my_way
    end type t1

    type t
        type(tt) :: yay
        type(t1) :: no(nt1)
    contains
        procedure :: proc => t_proc
    end type t
contains
    subroutine t_proc(this)
        class(t), intent(inout) :: this
        call a(this%yay%other)
        call this%yay%proc()
    end subroutine t_proc

    recursive subroutine my_way(this, recurse)
        class(t1), intent(inout) :: this
        logical, intent(in) :: recurse
        if (recurse) call this%way(.false.)
    end subroutine my_way
end module t_mod
loki-ecmwf-0.3.6/loki/tests/sources/sourcefile_cpp_stmt_func.F900000664000175000017500000000264415167130205025076 0ustar  alastairalastairmodule cpp_stmt_func_mod

    IMPLICIT NONE

    ! originally declared in parkind1.F90
    INTEGER, PARAMETER :: JPIM = SELECTED_INT_KIND(9)
    INTEGER, PARAMETer :: JPRB = SELECTED_REAL_KIND(13,300)

contains

subroutine cpp_stmt_func(KIDIA, KFDIA, KLON, KLEV, ZFOEEW)
    INTEGER(KIND=JPIM),INTENT(IN)    :: KLON, KLEV
    INTEGER(KIND=JPIM),INTENT(IN)    :: KIDIA
    INTEGER(KIND=JPIM),INTENT(IN)    :: KFDIA
    REAL(KIND=JPRB)   ,INTENT(OUT)   :: ZFOEEW(KLON,KLEV)

    INTEGER(KIND=JPIM) :: JK, JL

    REAL(KIND=JPRB) :: ZTP1(KLON,KLEV)
    REAL(KIND=JPRB) :: PAP(KLON,KLEV)
    REAL(KIND=JPRB) :: ZALFA

    ! originally declared in yomcst.F90
    REAL(KIND=JPRB) :: RTT = 1._JPRB

    ! originally declared in yoethf.F90
    REAL(KIND=JPRB) :: R2ES = 2._JPRB
    REAL(KIND=JPRB) :: R3LES = 3._JPRB
    REAL(KIND=JPRB) :: R3IES = 3._JPRB
    REAL(KIND=JPRB) :: R4LES = 4._JPRB
    REAL(KIND=JPRB) :: R4IES = 4._JPRB

#include "stmt.func.h"

    ! initialize with some stupid values
    PAP(:,:) = 8._JPRB
    ZTP1(:,:) = 1._JPRB

    DO JK=1,KLEV
        DO JL=KIDIA,KFDIA
            ZALFA=FOEDELTA(ZTP1(JL,JK))

            ! this should essentially become: min((zalfa * 2 + (1-zalfa) * 2))/8,0.5) === 0.25
            ZFOEEW(JL,JK)=MIN((ZALFA*FOEELIQ(ZTP1(JL,JK))+ &
                &  (1.0_JPRB-ZALFA)*FOEEICE(ZTP1(JL,JK)))/PAP(JL,JK),0.5_JPRB)
        END DO
    END DO
end subroutine cpp_stmt_func

end module cpp_stmt_func_mod
loki-ecmwf-0.3.6/loki/tests/sources/sourcefile_cpp_preprocessing.F900000664000175000017500000000044515167130205025754 0ustar  alastairalastairsubroutine sourcefile_external_preprocessing(a, b)
  real(kind=8), intent(inout) :: a, b
#include "some_header.h"
#ifdef FLAG_SMALL
#define CONSTANT 6
#else
#define CONSTANT 123
#endif

#define ADD_ONE(x) x + 1

  a = ADD_ONE(5)
  b = CONSTANT
end subroutine sourcefile_external_preprocessing
loki-ecmwf-0.3.6/loki/tests/sources/sourcefile_pp_macros.F900000664000175000017500000000036515167130205024213 0ustar  alastairalastairsubroutine routine_pp_macros()
#define CONSTANT 123
#define FLAG
  implicit none
  integer :: y, z
#define SOME_MACRO(x) x + 1
  y = 1
#define SOME_OTHER_MACRO (x - 1)
#
#warning 'ABC'
#ifdef FLAG
  z = 3
#endif
end subroutine routine_pp_macros
loki-ecmwf-0.3.6/loki/tests/sources/call_me_trafo.py0000664000175000017500000000150515167130205022666 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki import Transformation, FindNodes, CallStatement


class CallMeMaybeTrafo(Transformation):
    """
    Test transformation used to check that transformations can be loaded
    dynamically from a remote (non-package) module.

    When applied to a subroutine, every :any:`CallStatement` found in the
    routine body is updated to use :attr:`name` as its name.
    """

    def __init__(self, name='Dave', horizontal=None):
        # Name assigned to every call statement in the processed routines
        self.name = name
        # Stored so that the constructor accepts this keyword; not read by this class
        self.horizontal = horizontal

    def transform_subroutine(self, routine, **kwargs):
        """Rename all call statements in ``routine.body`` in-place."""
        replacement = self.name
        call_nodes = FindNodes(CallStatement).visit(routine.body)
        for call_node in call_nodes:
            call_node._update(name=replacement)
loki-ecmwf-0.3.6/loki/tests/sources/projHoist/0000775000175000017500000000000015167130205021505 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projHoist/module/0000775000175000017500000000000015167130205022772 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projHoist/module/subroutines_inline_mod.f900000664000175000017500000000312215167130205030067 0ustar  alastairalastairmodule subroutines_inline_mod
    implicit none
    integer, parameter :: len = 3
contains

    function func1(a)
        integer, intent(in) :: a
        integer :: func1(a)

        func1 = 1
    end function func1

    subroutine kernel1(a, b, c)
        integer, intent(in) :: a
        integer, intent(inout) :: b(a)
        integer, intent(inout) :: c(a, len)
        real :: x(a)
        integer :: y(a, len)
        real :: k1_tmp(a, len)
        y = 11
        c = y
        x = func1(a)
    end subroutine kernel1

    subroutine kernel2(a1, b)
        integer, intent(in) :: a1
        integer, intent(inout) :: b(a1)
        real :: x(a1)
        real :: y, z
        real :: k2_tmp(a1, a1)
        call device1(a1, b, x, k2_tmp)
        call device2(a1, b, x)
    end subroutine kernel2

    subroutine device1(a1, b, x, y)
        integer, intent(in) :: a1
        integer, intent(inout) :: b(a1)
        real, intent(inout) :: x(a1)
        real, intent(inout) :: y(a1, a1)
        real :: z
        integer :: d1_tmp
        call device2(a1, b, x)
        call device2(a1, b, x)
    end subroutine device1

    function init_int(a2) result(TMP)
        integer, intent(in) :: a2
        integer :: tmp(a2)
        integer :: tmp0(a2)

        tmp0 = 42
        tmp = tmp0
    end function init_int

    subroutine device2(a2, b, x)
        integer, intent(in) :: a2
        integer, intent(inout) :: b(a2)
        real, intent(inout) :: x(a2)
        integer z(a2)
        integer :: d2_tmp(len,len)
        z = init_int(a2)
        b = z
    end subroutine device2

end module subroutines_inline_mod

loki-ecmwf-0.3.6/loki/tests/sources/projHoist/module/driver_mod.f900000664000175000017500000000156115167130205025447 0ustar  alastairalastairmodule transformation_module_hoist
    USE subroutines_mod, only: kernel1, kernel2, device1, device2, kernel3
    implicit none
    integer, parameter :: len = 3
contains

    subroutine driver(a, b, c)
        integer, intent(in) :: a
        integer, intent(inout) :: b(a)
        integer, intent(inout) :: c(a, len)
        real :: x, y
        call kernel1(a, b, c)
        call kernel2(a, b)
        call kernel1(a, b, c)
    end subroutine driver

    subroutine another_driver(a, b, c)
        integer, intent(in) :: a
        integer, intent(inout) :: b(a)
        integer, intent(inout) :: c(a, len)
        real :: x, y
        call kernel1(a, b, c)
    end subroutine another_driver

    subroutine yet_another_driver(a, a1)
        integer, intent(in) :: a, a1

        call kernel3(a, a1)
    end subroutine yet_another_driver

end module transformation_module_hoist

loki-ecmwf-0.3.6/loki/tests/sources/projHoist/module/driver_inline_mod.f900000664000175000017500000000076615167130205027013 0ustar  alastairalastairmodule transformation_module_hoist_inline
    USE subroutines_inline_mod, only: kernel1, kernel2
    implicit none
    integer, parameter :: len = 3
contains

    subroutine inline_driver(a, b, c)
        integer, intent(in) :: a
        integer, intent(inout) :: b(a)
        integer, intent(inout) :: c(a, len)
        real :: x, y
        call kernel1(a, b, c)
        call kernel2(a, b)
        call kernel1(a, b, c)
    end subroutine inline_driver

end module transformation_module_hoist_inline

loki-ecmwf-0.3.6/loki/tests/sources/projHoist/module/subroutines_mod.f900000664000175000017500000000276215167130205026542 0ustar  alastairalastairmodule subroutines_mod
    implicit none
    integer, parameter :: len = 3
contains

    subroutine kernel1(a, b, c)
        integer, intent(in) :: a
        integer, intent(inout) :: b(a)
        integer, intent(inout) :: c(a, len)
        real :: x(a)
        integer :: y(a, len)
        real :: k1_tmp(a, len)
        y = 11
        c = y
    end subroutine kernel1

    subroutine kernel2(a1, b)
        integer, intent(in) :: a1
        integer, intent(inout) :: b(a1)
        real :: x(a1)
        real :: y, z
        real :: k2_tmp(a1, a1)
        call device1(a1, b, x, k2_tmp)
        call device2(a1, b, x)
    end subroutine kernel2

    subroutine device1(a1, b, x, y)
        integer, intent(in) :: a1
        integer, intent(inout) :: b(a1)
        real, intent(inout) :: x(a1)
        real, intent(inout) :: y(a1, a1)
        real :: z
        integer :: d1_tmp
        call device2(a1, b, x)
        call device2(a1, b, x)
    end subroutine device1

    subroutine device2(a2, b, x)
        integer, intent(in) :: a2
        integer, intent(inout) :: b(a2)
        real, intent(inout) :: x(a2)
        integer z(a2)
        integer :: d2_tmp(len)
        z = 42
        b = z
    end subroutine device2

    subroutine kernel3(a, a1)
        integer, intent(in) :: a, a1

        call device3(a)
        call device3(a1)
    end subroutine kernel3

    subroutine device3(n)
        integer, intent(in) :: n
        integer :: x(n)

        x = 1
    end subroutine device3

end module subroutines_mod

loki-ecmwf-0.3.6/loki/tests/sources/projTypeBound/0000775000175000017500000000000015167130205022330 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projTypeBound/typebound_header.F900000664000175000017500000000177715167130205026145 0ustar  alastairalastairmodule typebound_header
    implicit none

    type header_type
    contains
        procedure :: member_routine => header_member_routine
        procedure :: routine_real => header_routine_real
        procedure :: routine_integer
        generic :: routine => routine_real, routine_integer
    end type header_type

contains

    subroutine header_member_routine(self, val)
        class(header_type) :: self
        integer, intent(in) :: val
        integer :: j
        j = val
    end subroutine header_member_routine

    subroutine header_routine_real(self, val)
        class(header_type) :: self
        real, intent(out) :: val
        val = 1.0
    end subroutine header_routine_real

    subroutine routine_integer(self, val)
        class(header_type) :: self
        integer, intent(out) :: val
        val = 1
    end subroutine routine_integer

    SUBROUTINE ABOR1(CDTEXT)
        CHARACTER(LEN=*) CDTEXT
        WRITE(0,*) CDTEXT
        call abort()
    END SUBROUTINE ABOR1
end module typebound_header
loki-ecmwf-0.3.6/loki/tests/sources/projTypeBound/typebound_other.F900000664000175000017500000000161515167130205026025 0ustar  alastairalastairmodule typebound_other
    use typebound_header, only: header => header_type

    implicit none

    type other_type
      type(header) :: var(2)
    contains
      procedure :: member => other_member
    end type other_type

    type outer_type
      type(other_type) :: other
    contains
      procedure :: nested_call
    end type outer_type

contains

    module subroutine other_member(self, i, m)
        use typebound_header, only: member_routine => header_member_routine
        class(other_type) :: self
        integer, intent(in) :: i, m
        call member_routine(m)
        call self%var(i)%member_routine(m)
    end subroutine other_member

    subroutine nested_call(self, m)
      class(outer_type) :: self
      integer, intent(in) :: m
      call self%other%var(1)%member_routine(m)
      call self%other%var(2)%member_routine(m)
    end subroutine nested_call

end module typebound_other
loki-ecmwf-0.3.6/loki/tests/sources/projTypeBound/typebound_item.F900000664000175000017500000000347015167130205025643 0ustar  alastairalastairmodule typebound_item
    use typebound_header
    implicit none

    type some_type
    contains
        procedure, nopass :: routine => module_routine
        procedure :: some_routine
        procedure, pass :: other_routine
        procedure :: routine1, &
            & routine2 => routine
        ! procedure :: routine1
        ! procedure :: routine2 => routine
    end type some_type
contains
    subroutine module_routine
        integer m
        m = 2
    end subroutine module_routine

    subroutine some_routine(self)
        class(some_type) :: self

        call self%routine
    end subroutine some_routine

    subroutine other_routine(self, m)
        class(some_type), intent(inout) :: self
        integer, intent(in) :: m
        integer :: j

        if (m < 0) call abor1('Error with unbalanced parenthesis)')

        j = m
        call self%routine1
        call self%routine2
    end subroutine other_routine

    subroutine routine(self)
        class(some_type) :: self
        call self%some_routine
    end subroutine routine

    subroutine routine1(self)
        class(some_type) :: self
        call module_routine
    end subroutine routine1
end module typebound_item

subroutine driver
    use typebound_item
    use typebound_header
    use typebound_other, only: other => other_type
    implicit none

    integer, parameter :: constant = 2
    type(some_type), allocatable :: obj(:), obj2(:,:)
    type(header_type) :: header
    type(other) :: other_obj, derived(constant)
    real :: x
    integer :: i

    allocate(obj(1))
    allocate(obj2(1,1))
    call obj(1)%other_routine(5)
    call obj2(1,1)%some_routine
    call header%member_routine(1)
    call header%routine(x)
    call header%routine(i)
    call other_obj%member(2, 2)
    call derived(1)% var( 2 ) % member_routine(2)
end subroutine driver
loki-ecmwf-0.3.6/loki/tests/sources/projC/0000775000175000017500000000000015167130205020601 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projC/util/0000775000175000017500000000000015167130205021556 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projC/util/proj_c_util_mod.f900000664000175000017500000000070215167130205025245 0ustar  alastairalastairmodule proj_c_util_mod
  integer, parameter :: jprb = 4

contains

  subroutine routine_one(matrix)
    real(kind=jprb), intent(inout) :: matrix(:,:)
    integer :: i

    do i = 1, size(matrix, dim=1)
      call routine_two(matrix(i,:))
    end do
  end subroutine routine_one

  subroutine routine_two(vector)
    real(kind=jprb), intent(inout) :: vector(:)

    vector(:) = vector(:) + 2.0
  end subroutine routine_two

end module proj_c_util_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/0000775000175000017500000000000015167130205020577 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projA/scheduler_partial.config0000664000175000017500000000065315167130205025464 0ustar  alastairalastair[default]
# Specifies the behaviour of auto-expanded routines
mode = 'test'
role = 'kernel'
expand = true

# Forces exceptions to be thrown during processing
strict = true

# The list files to exclude from the tree. During development this is
# usually used to exclude files we cannot yet process.
block = ['compute_l2']

[routines.compute_l1]
role = 'driver'
expand = true

[routines.another_l1]
role = 'driver'
expand = true
loki-ecmwf-0.3.6/loki/tests/sources/projA/source/0000775000175000017500000000000015167130205022077 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projA/source/another_l1.F900000664000175000017500000000032215167130205024410 0ustar  alastairalastairsubroutine another_l1(matrix)
  use header_mod, only: jprb

  implicit none

  real(kind=jprb), intent(inout) :: matrix(:,:)

#include "another_l2.intfb.h"

  call another_l2(matrix)

end subroutine another_l1
loki-ecmwf-0.3.6/loki/tests/sources/projA/source/another_l2.F900000664000175000017500000000025615167130205024417 0ustar  alastairalastairsubroutine another_l2(matrix)
  use header_mod, only: jprb

  implicit none

  real(kind=jprb), intent(inout) :: matrix(:,:)

  matrix(:,:) = 77.0

end subroutine another_l2
loki-ecmwf-0.3.6/loki/tests/sources/projA/include/0000775000175000017500000000000015167130205022222 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projA/include/another_l1.intfb.h0000664000175000017500000000026015167130205025526 0ustar  alastairalastairinterface

subroutine another_l1(matrix)
  use header_mod, only: jprb
  implicit none

  real(kind=jprb), intent(inout) :: matrix(:,:)
end subroutine another_l1

end interface
loki-ecmwf-0.3.6/loki/tests/sources/projA/include/another_l2.intfb.h0000664000175000017500000000026015167130205025527 0ustar  alastairalastairinterface

subroutine another_l2(matrix)
  use header_mod, only: jprb
  implicit none

  real(kind=jprb), intent(inout) :: matrix(:,:)
end subroutine another_l2

end interface
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/0000775000175000017500000000000015167130205022064 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projA/module/driverB_mod.f900000664000175000017500000000043115167130205024636 0ustar  alastairalastairmodule driverB_mod
  use header_mod, only: jprb, header_type
  use kernelB_mod, only: kernelB

  implicit none

contains

  subroutine driverB()
    type(header_type) :: mystruct

    call kernelB(mystruct%vector, mystruct%matrix)

  end subroutine driverB

end module driverB_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/kernelE_mod.f900000664000175000017500000000136615167130205024636 0ustar  alastairalastairmodule kernelE_mod
  use header_mod, only: jprb
  use compute_l1_mod, only: compute_l1

  implicit none

contains

  ! Two kernel routines, but we process only one!

  subroutine kernelE(vector, matrix)
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:)

    call compute_l1(vector)

    call ghost_busters(vector)

  contains

    subroutine ghost_busters(vector)
      real(kind=jprb), intent(inout) :: vector(:)

      vector(:) = 42.0
    end subroutine ghost_busters

  end subroutine kernelE

  subroutine kernelET(vector, matrix)
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:)

    call compute_l2(vector)

  end subroutine kernelET

end module kernelE_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/kernelC_mod.f900000664000175000017500000000061515167130205024630 0ustar  alastairalastairmodule kernelC_mod
  use header_mod, only: jprb
  use compute_l1_mod, only: compute_l1
  use proj_c_util_mod, only: routine_one

  implicit none

contains

  subroutine kernelC(vector, matrix)
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:)

    call compute_l1(vector)

    call routine_one(matrix)
  end subroutine kernelC

end module kernelC_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/kernelD_mod.f900000664000175000017500000000057215167130205024633 0ustar  alastairalastairmodule kernelD_mod
  use header_mod, only: jprb
  use compute_l1_mod, only: compute_l1
  use proj_c_util_mod

  implicit none

contains

  subroutine kernelD(vector, matrix)
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:)

    call compute_l1(vector)

    call routine_one(matrix)
  end subroutine kernelD

end module kernelD_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/kernelB_mod.F900000664000175000017500000000076415167130205024574 0ustar  alastairalastairmodule kernelB_mod
  use header_mod, only: jprb
  use compute_l1_mod, only: compute_l1
#ifdef HAVE_EXT_DRIVER_MODULE
  use ext_driver_mod, only: ext_driver
#endif

  implicit none

contains

  subroutine kernelB(vector, matrix)
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:)

#ifndef HAVE_EXT_DRIVER_MODULE
#include "ext_driver.intfb.h"
#endif

    call compute_l1(vector)

    call ext_driver(matrix)
  end subroutine kernelB

end module kernelB_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/driverC_mod.f900000664000175000017500000000043115167130205024637 0ustar  alastairalastairmodule driverC_mod
  use header_mod, only: jprb, header_type
  use kernelC_mod, only: kernelC

  implicit none

contains

  subroutine driverC()
    type(header_type) :: mystruct

    call kernelC(mystruct%vector, mystruct%matrix)

  end subroutine driverC

end module driverC_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/compute_l2_mod.f900000664000175000017500000000033415167130205025314 0ustar  alastairalastairmodule compute_l2_mod
  use header_mod, only: jprb

contains

  subroutine compute_l2(vector)
    real(kind=jprb), intent(inout) :: vector(:)

    vector(:) = 66.0

  end subroutine compute_l2

end module compute_l2_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/driverE_mod.f900000664000175000017500000000106315167130205024643 0ustar  alastairalastairmodule driverE_mod
  use header_mod, only: jprb, header_type
  use kernelE_mod, only: kernelE, kernelET

  implicit none

contains

  ! Two driver routines, but we process only one!

  subroutine driverE_single()
    type(header_type) :: mystruct

    call kernelE(mystruct%vector, mystruct%matrix)

  end subroutine driverE_single


  subroutine driverE_multiple()
    type(header_type) :: mystruct

    call kernelE(mystruct%vector, mystruct%matrix)

    call kernelET(mystruct%vector, mystruct%matrix)
  end subroutine driverE_multiple

end module driverE_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/kernelA_mod.F900000664000175000017500000000060315167130205024563 0ustar  alastairalastairmodule KERNELA_MOD
  use header_mod, only: jprb
  use compute_l1_mod, only: compute_l1

  implicit none

contains

  subroutine KERNELA(vector, matrix)
    real(kind=jprb), intent(inout) :: vector(:)
    real(kind=jprb), intent(inout) :: matrix(:)

#include "another_l1.intfb.h"

    call COMPUTE_L1(vector)

    call ANOTHER_L1(matrix)

  end subroutine KERNELA

end module KERNELA_MOD
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/header_mod.f900000664000175000017500000000025615167130205024476 0ustar  alastairalastairmodule header_mod

  integer, parameter :: jprb = 4

  type header_type
    real(kind=jprb) :: scalar, vector(:), matrix(3, 3)
  end type header_type


end module header_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/driverA_mod.f900000664000175000017500000000043115167130205024635 0ustar  alastairalastairmodule driverA_mod
  use header_mod, only: jprb, header_type
  use kernelA_mod, only: kernelA

  implicit none

contains

  subroutine driverA()
    type(header_type) :: mystruct

    call kernelA(mystruct%vector, mystruct%matrix)

  end subroutine driverA

end module driverA_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/compute_l1_mod.f900000664000175000017500000000041215167130205025310 0ustar  alastairalastairmodule compute_l1_mod
  use header_mod, only: jprb
  use compute_l2_mod, only: compute_l2

contains

  subroutine compute_l1(vector)
    real(kind=jprb), intent(inout) :: vector(:)

    call compute_l2(vector)

  end subroutine compute_l1

end module compute_l1_mod
loki-ecmwf-0.3.6/loki/tests/sources/projA/module/driverD_mod.f900000664000175000017500000000043015167130205024637 0ustar  alastairalastairmodule driverD_mod
  use header_mod, only: jprb, header_type
  use kernelD_mod, only: kernelD

  implicit none

contains

  subroutine driverD()
    type(header_type) :: mystruct

    call kernelD(mystruct%vector, mystruct%matrix)
  end subroutine driverD

end module driverD_mod
loki-ecmwf-0.3.6/loki/tests/sources/data_dependency_detection/0000775000175000017500000000000015167130205024671 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/data_dependency_detection/loop_carried_dependencies.f900000664000175000017500000000212515167130205032361 0ustar  alastairalastairsubroutine SimpleDependency(data, n)
  implicit none
  integer, intent(in) :: n
  real(8), dimension(n) :: data
  integer :: i

  ! Loop with a simple loop-carried dependency
  do i = 1, n
    data(i) = data(i) + data(i-1)
  end do

end subroutine SimpleDependency


subroutine NestedDependency(data, n)
  implicit none
  integer, intent(in) :: n
  real(8), dimension(n) :: data
  integer :: i, j

  ! Nested loop with a loop-carried dependency
  do i = 2, n
    do j = 1, i-1
      data(i) = data(i) + data(j)
    end do
  end do

end subroutine NestedDependency


subroutine ConditionalDependency(data, n)
  implicit none
  integer, intent(in) :: n
  real(8), dimension(n) :: data
  integer :: i

  ! Loop with a conditional loop-carried dependency
  do i = 2, n
    if (data(i-1) > 0.0) then
      data(i) = data(i) + data(i-1)
    endif
  end do

end subroutine ConditionalDependency

subroutine NoDependency(data)
  implicit none
  real(8), dimension(20) :: data
  integer :: i

  do i = 1, 10, 1
    data(2*i) = 10;
  end do

  do i = 1, 5, 1
    data(2*i + 1) = 20;
  end do
end subroutine NoDependencyloki-ecmwf-0.3.6/loki/tests/sources/data_dependency_detection/various_loops.f900000664000175000017500000000321715167130205030120 0ustar  alastairalastairsubroutine single_loop(arr, n)
  implicit none
  integer, intent(inout) :: arr(:)
  integer, intent(in) :: n
  integer :: i

  do i = 1, n
    arr(i) = arr(i) * 2
  end do
end subroutine single_loop

subroutine single_loop_split_access(arr, n)
  implicit none
  integer, intent(inout) :: arr(:)
  integer, intent(in) :: n
  integer :: i, nhalf
  nhalf = n/2
  do i = 1, nhalf
    arr(2*i) = arr(2*i) * 2
    arr(2*i + 1) = arr(2*i + 1) * 2
  end do
end subroutine single_loop_split_access

subroutine single_loop_arithmetic_operations_for_access(arr, n)
  implicit none
  integer, intent(inout) :: arr(:)
  integer, intent(in) :: n
  integer :: i

  do i = 1, n
    arr(i*i) = arr(i + i) * 2
  end do
end subroutine single_loop_arithmetic_operations_for_access

subroutine nested_loop_single_dimensions_access(arr, n)
  implicit none
  integer, intent(inout) :: arr(:)
  integer, intent(in) :: n
  integer :: i, j, nhalf

  nhalf = n/2
  do i = 1, nhalf
    do j = 1, nhalf
          arr(i + j) = arr(i + j) * 2
    end do
  end do
end subroutine nested_loop_single_dimensions_access



subroutine nested_loop_partially_used(arr, n)
  implicit none
  integer, intent(inout) :: arr(:)
  integer, intent(in) :: n
  integer :: i, j, nfourth

  nfourth = n / 4
  do i = 1, nfourth
    do j = 1, nfourth
          arr(i + j) = arr(i + j) * 2
    end do
  end do
end subroutine nested_loop_partially_used


subroutine partially_used_array(arr, n)
  implicit none
  integer, intent(out) :: arr(:)
  integer, intent(in) :: n
  integer :: i = 1 , j = 3, nhalf

  nhalf = n/2
  arr(1) = 1
  do i = 2, nhalf
    arr(i) = arr(i - 1)
  end do

  j = arr(j)
end subroutine partially_used_arrayloki-ecmwf-0.3.6/loki/tests/sources/projParametrise/0000775000175000017500000000000015167130205022673 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/sources/projParametrise/parametrise.f900000664000175000017500000000363715167130205025540 0ustar  alastairalastairmodule parametrise
    implicit none
contains

    subroutine stop_execution(msg)
        character(200), INTENT(IN) :: msg
        PRINT *, msg
        stop 1
    end subroutine stop_execution

    subroutine driver(a, b, c, d)
        integer, intent(inout) :: a, b, d(b)
        integer, intent(inout) :: c(a, b)
        real :: x, y
        call kernel1(a, b, c)
        call kernel2(a, b, d)
        call kernel1(a, b, c)
    end subroutine driver

    subroutine another_driver(a, b, c)
        integer, intent(in) :: a
        integer, intent(in) :: b
        integer, intent(inout) :: c(a, b)
        real :: x(a)
        call kernel1(a, b, c)
    end subroutine another_driver

    subroutine kernel1(a, b, c)
        integer, intent(in) :: a
        integer, intent(in) :: b
        integer, intent(inout) :: c(a, b)
        integer :: local_a
        real :: x(a)
        integer :: y(a, b)
        real :: k1_tmp(a, b)
        y = 11
        c = y
    end subroutine kernel1

    subroutine kernel2(a_new, b, d)
        integer, intent(in) :: a_new
        integer, intent(in) :: b
        integer, intent(inout) :: d(b)
        real :: x(a_new)
        real :: y, z
        real :: k2_tmp(a_new, a_new)
        call device1(a_new, b, d, x, k2_tmp)
    end subroutine kernel2

    subroutine device1(a, b, d, x, y)
        integer, intent(in) :: a
        integer, intent(in) :: b
        integer, intent(inout) :: d(b)
        real, intent(inout) :: x(a)
        real, intent(inout) :: y(a, a)
        real :: z
        integer :: d1_tmp
        call device2(a, b, d, x)
        call device2(a, b, d, x)
    end subroutine device1

    subroutine device2(a, b, d, x)
        integer, intent(in) :: a
        integer, intent(in) :: b
        integer, intent(inout) :: d(b)
        real, intent(inout) :: x(a)
        integer z(b)
        integer :: d2_tmp(b)
        z = 42
        d = z
    end subroutine device2

end module parametrise
loki-ecmwf-0.3.6/loki/tests/test_source.py0000664000175000017500000004231515167130205020757 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import re

import pytest

from loki import read_file, Source, source_to_lines, join_source_list, FortranReader


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory of this test module; the test input files live below it."""
    this_module = Path(__file__)
    return this_module.parent


def test_source(here):
    """Check the :any:`Source` constructor with and without string/file information"""
    source_file = here/'sources/sourcefile.f90'
    file_content = read_file(source_file)
    line_span = (1, file_content.count('\n') + 1)

    # Neither string nor file: only the line information is stored
    src = Source([1, 1], None)
    assert src.string is None
    assert src.lines == [1, 1]
    assert src.file is None

    # Open-ended line range with a string but without a file
    src = Source([3, None], file_content)
    assert src.string == file_content
    assert src.lines == [3, None]
    assert src.file is None

    # Fully specified: line range, string and file path
    src = Source(line_span, file_content, source_file)
    assert src.string == file_content
    assert src.lines == line_span
    assert src.file == source_file


def test_source_find(here):
    """Test the `find` utility of :any:`Source`

    ``find`` returns the ``(start, end)`` character offsets of a substring
    within the source string, or ``(None, None)`` if it cannot be located
    (including when the :any:`Source` has no string). By default the search
    ignores case and leading whitespace in the search string; both can be
    disabled via ``ignore_case=False`` / ``ignore_space=False``.
    """
    filepath = here/'sources/sourcefile.f90'
    fcode = read_file(filepath)
    lines = (1, fcode.count('\n') + 1)

    # Reference match computed independently with `re`
    routine_b_match = re.search(r'(subroutine routine_b.*?end subroutine routine_b)', fcode, re.DOTALL)
    assert routine_b_match
    routine_b_fcode = routine_b_match.group(0)

    # No string to search in
    cstart, cend = Source((1, 1), None).find(routine_b_fcode)
    assert cstart is None
    assert cend is None

    # Exact match
    cstart, cend = Source(lines, fcode).find(routine_b_fcode)
    assert (cstart, cend) == routine_b_match.span()

    # Case-insensitive match is the default...
    cstart, cend = Source(lines, fcode).find(routine_b_fcode.upper())
    assert (cstart, cend) == routine_b_match.span()

    # ...and can be switched off
    cstart, cend = Source(lines, fcode).find(routine_b_fcode.upper(), ignore_case=False)
    assert (cstart, cend) == (None, None)

    # Offsets of the bare name 'routine_b' in the 'subroutine routine_b' header
    bstart = routine_b_match.span()[0] + len('subroutine ')
    bend = bstart + len('routine_b')
    # Leading whitespace in the search string is ignored by default
    cstart, cend = Source(lines, fcode).find('   routine_b')
    assert (cstart, cend) == (bstart, bend)

    # With ignore_space=False three leading blanks do not match the single blank in the file
    cstart, cend = Source(lines, fcode).find('   routine_b', ignore_space=False)
    assert (cstart, cend) == (None, None)

    cstart, cend = Source(lines, fcode).find(' routine_b', ignore_space=False)
    assert (cstart, cend) == (bstart - 1, bend)  # start offset by 1 because leading whitespace is taken into account


def test_source_clone_with_string(here):
    """Test the `clone_with_string` utility of :any:`Source`

    The clone carries the given string. If the parent has a string to search
    in, the clone also carries the line range at which the string was found,
    counted from the parent's start line.
    """
    filepath = here/'sources/sourcefile.f90'
    fcode = read_file(filepath)
    lines = (1, fcode.count('\n') + 1)

    routine_b_match = re.search(r'(subroutine routine_b.*?end subroutine routine_b)', fcode, re.DOTALL)
    assert routine_b_match
    routine_b_fcode = routine_b_match.group(0)

    # Expected line range of routine_b when counting from line 1
    routine_b_start = fcode[:routine_b_match.span()[0]].count('\n') + 1
    routine_b_end = routine_b_start + routine_b_fcode.count('\n')
    routine_b_lines = (routine_b_start, routine_b_end)

    # Without a parent string there is nothing to search in: lines are kept as-is
    source = Source([3, None], None).clone_with_string(routine_b_fcode)
    assert source.string == routine_b_fcode
    assert source.lines == [3, None]
    assert source.file is None

    # The parent starts at line 3 instead of line 1, so the reported range is
    # shifted by two relative to ``routine_b_lines``. Deriving it from the file
    # content avoids hard-coding the layout of ``sourcefile.f90``.
    start_offset = 3 - 1
    source = Source([3, None], fcode).clone_with_string(routine_b_fcode)
    assert source.string == routine_b_fcode
    assert source.lines == (routine_b_start + start_offset, routine_b_end + start_offset)
    assert source.file is None

    source = Source(lines, fcode, filepath).clone_with_string(routine_b_fcode)
    assert source.string == routine_b_fcode
    assert source.lines == routine_b_lines
    assert source.file == filepath

    # Cloning a source with its own full string yields the full line range
    source = Source((1, routine_b_fcode.count('\n')+1), routine_b_fcode).clone_with_string(routine_b_fcode)
    assert source.string == routine_b_fcode
    assert source.lines == (1, routine_b_fcode.count('\n')+1)
    assert source.file is None

    # Case-insensitive search: the clone carries the parent's original text
    source = Source(lines, fcode, filepath).clone_with_string(routine_b_fcode.upper(), ignore_case=True)
    assert source.string == routine_b_fcode
    assert source.lines == routine_b_lines
    assert source.file == filepath


def test_source_clone_with_span(here):
    """Check that :any:`Source.clone_with_span` extracts the spanned text and its line range"""
    filepath = here/'sources/sourcefile.f90'
    fcode = read_file(filepath)
    full_lines = (1, fcode.count('\n') + 1)

    match = re.search(r'(subroutine routine_b.*?end subroutine routine_b)', fcode, re.DOTALL)
    assert match
    span = match.span()
    expected_string = match.group(0)

    first_line = fcode[:span[0]].count('\n') + 1
    expected_lines = (first_line, first_line + expected_string.count('\n'))

    # The span is applied verbatim to the parent string, so its casing is retained
    for content, expected in ((fcode, expected_string), (fcode.upper(), expected_string.upper())):
        cloned = Source(full_lines, content, filepath).clone_with_span(span)
        assert cloned.string == expected
        assert cloned.lines == expected_lines
        assert cloned.file == filepath


def test_source_clone_lines(here):
    """Test the `clone_lines` utility of :any:`Source`

    Without arguments every line of the source becomes its own single-line
    :any:`Source`. With a character span, only the lines covered by that span
    are produced, numbered relative to the file.
    """
    filepath = here/'sources/sourcefile.f90'
    fcode = read_file(filepath)
    lines = (1, fcode.count('\n') + 1)
    source = Source(lines, fcode, filepath)

    # Full source: one entry per line of the file
    source_lines = source.clone_lines()
    str_lines = fcode.splitlines()
    assert len(source_lines) == len(str_lines)

    for idx, (source_line, str_line) in enumerate(zip(source_lines, str_lines)):
        assert source_line.string == str_line
        assert source_line.lines[0] == idx+1
        assert source_line.lines[1] == idx+1
        assert source_line.file == filepath

    # Restricted to the span of routine_b; fail clearly if the fixture changed
    routine_b_match = re.search(r'(subroutine routine_b.*?end subroutine routine_b)', fcode, re.DOTALL)
    assert routine_b_match
    routine_b_source = source.clone_with_span(routine_b_match.span())

    source_lines = source.clone_lines(routine_b_match.span())
    routine_b_str_lines = str_lines[routine_b_source.lines[0]-1:routine_b_source.lines[1]]
    assert len(source_lines) == len(routine_b_str_lines)

    for idx, (source_line, str_line) in enumerate(zip(source_lines, routine_b_str_lines)):
        assert source_line.string == str_line
        assert source_line.lines[0] == idx+routine_b_source.lines[0]
        assert source_line.lines[1] == idx+routine_b_source.lines[0]
        assert source_line.file == filepath


def test_source_to_lines():
    """
    Test :any:`source_to_lines`, which splits a :any:`Source` into one object
    per logical line, merging Fortran line continuations.
    """
    fcode = """
module some_module
    implicit none

    integer :: var1, &
        & var2, &
        & var3, &
 var4, &
            &var5,&
            &var6, &
            & var7

    ! This is a &
    ! & comment
contains
    subroutine my_routine
      integer j
      j = var1 &
        &+1
    end subroutine my_routine
end module some_module
    """.strip()

    first_line, last_line = 1, fcode.count('\n') + 1
    source_lines = source_to_lines(Source((first_line, last_line), fcode))

    # Every physical line number is covered by some logical line...
    covered = set()
    for item in source_lines:
        covered.update(range(item.lines[0], item.lines[1] + 1))
    assert covered == set(range(first_line, last_line + 1))

    # ...and consecutive logical lines are well-formed and directly adjacent
    for prev, nxt in zip(source_lines, source_lines[1:]):
        assert prev.lines[0] <= prev.lines[1]
        assert prev.lines[1] + 1 == nxt.lines[0]
        assert nxt.lines[0] <= nxt.lines[1]

    # Known continuations: the declaration (lines 4-10) and the assignment (17-18)
    declaration, assignment = source_lines[3], source_lines[10]
    assert declaration.lines == (4, 10)
    assert assignment.lines == (17, 18)
    assert '    integer ::' in declaration.string
    assert ',  var7' in declaration.string
    assert '      j = var1 +1' in assignment.string


@pytest.mark.parametrize('source_list,expected',(
    (
        [], None
    ),  (
        [Source((1, 2), 'subroutine my_routine\nimplicit none'), Source((3, None), 'end subroutine my_routine')],
        Source((1, 3), 'subroutine my_routine\nimplicit none\nend subroutine my_routine')
    ), (
        [
            Source([1, None], 'subroutine my_routine'),
            Source([2, None], '  use iso_fortran_env, only: real64'),
            Source([3, None], '  implicit none'),
            Source([4, None], '  real(kind=real64) ::'),
            Source([4, 5], ' var_1, &\n   & var_2'),
            Source([6, 7], '  var_1 = 1._real64\n  var_2 = 2._real64'),
            Source([8, None], 'end subroutine my_routine'),
        ],
        Source((1, 8), '''
subroutine my_routine
  use iso_fortran_env, only: real64
  implicit none
  real(kind=real64) :: var_1, &
   & var_2
  var_1 = 1._real64
  var_2 = 2._real64
end subroutine my_routine
        '''.strip())
    ), (
        [
            Source((5, 5), 'integer ::'),
            Source((5, None), ' var1,'),
            Source((5, 5), ' var2')
        ],
        Source((5, 5), 'integer :: var1, var2')
    ), (
        [Source((1, 1), 'print *,* "hello world!"')], Source((1, 1), 'print *,* "hello world!"')
    ), (
        [Source((13, 19), '! line with less line breaks than reported'), Source((20, None), '! here')],
        Source((13, 20), '! line with less line breaks than reported\n! here')
    ), (
        [Source((7, None), '! Some line'), Source([12, None], '! Some other line')],
        Source((7, 12), '! Some line\n\n\n\n\n! Some other line')
    ), (
        [Source((3, 4), '! Some line\n! With line break'), Source([6, None], '! Other line\n! And new line')],
        Source((3, 7), '! Some line\n! With line break\n\n! Other line\n! And new line')
    )
))
def test_join_source_list(source_list, expected):
    """
    Test :any:`join_source_list`, which merges a list of :any:`Source` objects
    into a single one, padding gaps in the line numbering with empty lines.

    An empty input list must yield ``None``.
    """
    joined = join_source_list(source_list)
    if expected is None:
        assert joined is None
        return

    assert isinstance(joined, Source)
    for attr in ('lines', 'string', 'file'):
        assert getattr(joined, attr) == getattr(expected, attr)


def test_fortran_reader(here):
    """
    Test the :any:`FortranReader` constructor and its span-extraction utilities.

    Covers the sanitized string (line continuations joined), the leading and
    trailing comment sources, and sub-readers created from sanitized spans:
    with and without padding, nested, at the start and end of the file, and
    open-ended. It also covers the :any:`Source` objects derived from them.

    The expected line numbers are tied to the contents of the fixture file
    ``sources/Fortran-extract-interface-source.f90``.
    """
    filepath = here/'sources/Fortran-extract-interface-source.f90'
    fcode = read_file(filepath)
    lines = (1, fcode.count('\n') + 1)

    reader = FortranReader(fcode)

    # Line continuations exist in the raw code but are joined in the sanitized string
    _re_line_cont = re.compile(r'&([ \t]*)\n([ \t]*)(?:&|(?!\!)(?=\S))', re.MULTILINE)
    assert _re_line_cont.search(fcode) is not None
    assert _re_line_cont.search(reader.sanitized_string) is None
    assert 'end subroutine sub_simple_3' in reader.sanitized_string

    # Sanity check: the sanitized lines stay within the file's line range.
    # The upper limit is checked against the *last* sanitized line.
    assert reader.sanitized_lines[0].span[0] >= lines[0]
    assert reader.sanitized_lines[-1].span[1] <= lines[1]
    assert len(reader.source_lines) == lines[1] - lines[0]

    # The comment block at the top of the file is not part of the sanitized
    # code, but is available as the reader's head
    source = reader.source_from_head()
    assert source.lines == (1, 4)
    assert all(line.strip().startswith('!') or not line.strip() for line in source.string.splitlines())

    # The file ends with code, so there is no trailing comment block
    assert reader.source_from_tail() is None

    # Extract the module 'foo' from the middle of the file
    start = reader.sanitized_string.find('module foo')
    end = reader.sanitized_string.find('end module foo') + len('end module foo')
    assert 0 < start < end

    # without padding
    new_reader = reader.reader_from_sanitized_span((start, end))
    assert new_reader.sanitized_lines[0].span[0] == 51
    assert new_reader.sanitized_lines[-1].span[1] == 64

    source = new_reader.to_source()
    assert source.lines == (51, 64)
    assert source.string.startswith('module foo')
    assert source.string.endswith('end module foo')

    assert new_reader.source_from_tail() is None

    # with padding: the sanitized lines are unchanged, but to_source can include
    # the surrounding comments/blank lines on request
    new_reader = reader.reader_from_sanitized_span((start, end), include_padding=True)
    assert new_reader.sanitized_lines[0].span[0] == 51
    assert new_reader.sanitized_lines[-1].span[1] == 64

    source = new_reader.to_source()
    assert source.lines == (51, 64)
    assert source.string.startswith('module foo')
    assert source.string.endswith('end module foo')

    source = new_reader.to_source(include_padding=True)
    assert source.lines == (49, 66)
    assert source.string.startswith('\n!')
    assert source.string.splitlines()[-1].startswith('!')

    source = new_reader.source_from_tail()
    assert source.lines == (65, 66)
    assert all(line.strip().startswith('!') or not line.strip() for line in source.string.splitlines())

    # The same sources can be obtained directly from the parent reader
    source = reader.source_from_sanitized_span((start, end))
    assert source.lines == (51, 64)
    assert source.string.startswith('module foo')
    assert source.string.endswith('end module foo')

    source = reader.source_from_sanitized_span((start, end), include_padding=True)
    assert source.lines == (49, 66)
    assert source.string.startswith('\n!')
    assert source.string.splitlines()[-1].startswith('!')

    # Nested reader: spans are relative to the sub-reader's sanitized string,
    # while line numbers still refer to the original file
    start = new_reader.sanitized_string.find('subroutine foo_sub')
    end = new_reader.sanitized_string.find('end subroutine foo_sub') + len('end subroutine foo_sub')
    assert 0 < start < end < len(new_reader.sanitized_string)

    nested_reader = new_reader.reader_from_sanitized_span((start, end))
    assert nested_reader.sanitized_lines[0].span[0] == 55
    assert nested_reader.sanitized_lines[-1].span[1] == 59

    source = nested_reader.to_source()
    assert source.lines == (55, 59)
    assert source.string.startswith('subroutine foo_sub')

    source = new_reader.source_from_sanitized_span((start, end))
    assert source.lines == (55, 59)
    assert source.string.startswith('subroutine foo_sub')
    assert source.string.splitlines()[-1].startswith('end subroutine foo_sub')

    # Substring at the very start of the sanitized string
    start = reader.sanitized_string.find('logical function func_simple')
    end = reader.sanitized_string.find('end function func_simple') + len('end function func_simple')
    assert start == 0 and end > start

    new_reader = reader.reader_from_sanitized_span((start, end))
    assert new_reader.sanitized_lines[0].span[0] == 5
    assert new_reader.sanitized_lines[-1].span[1] == 7

    source = reader.source_from_sanitized_span((start, end), include_padding=True)
    assert source.lines == (1, 9)
    assert source.string.startswith('!')
    assert source.string.splitlines()[-1].startswith('!')

    source = reader.source_from_sanitized_span((start, end))
    assert source.lines == (5, 7)
    assert source.string.startswith('logical function func_simple')
    assert source.string.endswith('end function func_simple')

    # Substring at the very end of the sanitized string
    start = reader.sanitized_string.find('subroutine sub_with_end')
    end = reader.sanitized_string.find('end subroutine sub_with_end') + len('end subroutine sub_with_end')
    assert 0 < start < end == len(reader.sanitized_string)

    new_reader = reader.reader_from_sanitized_span((start, end))
    assert new_reader.sanitized_lines[0].span[0] == 181
    assert new_reader.sanitized_lines[-1].span[1] == 184

    source = reader.source_from_sanitized_span((start, end))
    assert source.lines == (181, 184)
    assert source.string.startswith('subroutine sub_with_end')
    assert source.string.splitlines()[-1].startswith('end subroutine')

    # Open-ended span (end=None) behaves like a span to the end of the string
    end = None

    new_reader = reader.reader_from_sanitized_span((start, end))
    assert new_reader.sanitized_lines[0].span[0] == 181
    assert new_reader.sanitized_lines[-1].span[1] == 184

    source = reader.source_from_sanitized_span((start, end))
    assert source.lines == (181, 184)
    assert source.string.startswith('subroutine sub_with_end')
    assert source.string.splitlines()[-1].startswith('end subroutine')


def test_fortran_reader_iterate(here):
    """
    Test iterating over a :any:`FortranReader`.

    Iteration must visit every sanitized line in order, expose the visited
    line as ``current_line`` (``None`` outside of iteration), and allow
    creating a :any:`Source` for the current line that keeps the original
    formatting.
    """
    filepath = here/'sources/Fortran-extract-interface-source.f90'
    fcode = read_file(filepath)

    reader = FortranReader(fcode)
    sanitized_code = reader.sanitized_string

    # No line is current before iteration starts
    assert reader.current_line is None

    # Joining the yielded items reproduces the sanitized code
    assert '\n'.join(item.line for item in reader) == sanitized_code

    # ``current_line`` follows the item being visited during iteration...
    assert '\n'.join(reader.current_line.line for _ in reader) == sanitized_code

    # ...and is reset once iteration has finished
    assert reader.current_line is None

    # Sources created while iterating keep the original formatting; only
    # blank lines and comment lines disappear through sanitization
    def sanitize_empty_lines_and_comments(string):
        kept = (
            line for line in string.splitlines()
            if line.lstrip() and not line.lstrip().startswith('!')
        )
        return ''.join(f'{line}\n' for line in kept)

    per_line_sources = ''.join(
        f'{reader.source_from_current_line().string}\n' for _ in reader
    )
    assert sanitize_empty_lines_and_comments(fcode) == per_line_sources


@pytest.mark.parametrize('fcode', ['', '\n'])
def test_fortran_reader_empty(fcode):
    """
    Test that :any:`FortranReader` handles empty or newline-only input.

    The reader must hold no original or sanitized lines, and converting it to
    a :any:`Source` must give an empty string on line 1.
    """
    empty_reader = FortranReader(fcode)
    assert isinstance(empty_reader, FortranReader)
    assert not empty_reader.source_lines
    assert not empty_reader.sanitized_lines

    result = empty_reader.to_source()
    assert isinstance(result, Source)
    assert result.string == ''
    assert result.lines == (1, 1)
loki-ecmwf-0.3.6/loki/tests/include/0000775000175000017500000000000015167130205017464 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tests/include/some_header.h0000664000175000017500000000011115167130205022101 0ustar  alastairalastair
! Oh yes, a statement function!
real(kind=4) :: add
add( a, b ) = a + b
loki-ecmwf-0.3.6/loki/tests/test_dimension.py0000664000175000017500000001323515167130205021443 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Dimension, FindNodes, Loop
from loki.batch import SchedulerConfig
from loki.expression import symbols as sym
from loki.frontend import available_frontends
from loki.types import BasicType, Scope, SymbolAttributes


def test_dimension_properties():
    """
    Test that :any:`Dimension` turns its string arguments into the expected
    expression objects, for single and multiple indices/bounds. Also check the
    derived ``bounds`` and ``range`` properties.
    """
    scope = Scope()
    type_int = SymbolAttributes(dtype=BasicType.INTEGER)

    def int_var(name):
        # Reference integer variable to compare Dimension attributes against
        return sym.Variable(name=name, type=type_int, scope=scope)

    i, n, z = int_var('i'), int_var('n'), int_var('z')
    one, two = sym.IntLiteral(1), sym.IntLiteral(2)

    # Minimal definition: index, upper bound and size
    simple = Dimension('simple', index='i', upper='n', size='z')
    assert simple.index == i
    assert simple.upper == n
    assert simple.size == z

    # Single index with explicit lower bound and step
    detail = Dimension(index='i', lower='1', upper='n', step='2', size='z')
    assert detail.index == i
    assert detail.lower == one
    assert detail.upper == n
    assert detail.step == two
    assert detail.size == z
    assert detail.bounds == (one, n)
    assert detail.range == sym.LoopRange((1, n))

    # Multiple alternative indices and bounds: the scalar accessors return the
    # first entry, while the derived properties use the first bounds pair
    multi = Dimension(
        index=('i', 'idx'), lower=('1', 'start'), upper=('n', 'end'), size='z'
    )
    assert multi.index == i
    assert multi.indices == (i, int_var('idx'))
    assert multi.lower == (one, int_var('start'))
    assert multi.upper == (n, int_var('end'))
    assert multi.size == z
    assert multi.bounds == (one, n)
    assert multi.range == sym.LoopRange((1, n))


@pytest.mark.parametrize('frontend', available_frontends())
def test_dimension_size(frontend):
    """
    Test that a :any:`Dimension` matches the size expressions used in
    declarations, including equivalent (aliased) forms of the size.
    """
    fcode = """
subroutine test_dimension_size(nlon, start, end, arr)
  integer, intent(in) :: NLON, START, END
  real, intent(inout) :: arr(nlon)
  real :: local_arr(1:nlon)
  real :: range_arr(end-start+1)

  arr(start:end) = 1.
end subroutine test_dimension_size
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    dim = Dimension(name='test_dim', size='nlon', bounds=('start', 'end'))
    variables = routine.variable_map

    # The plain size matches the size variable and the dummy argument's shape
    assert variables['nlon'] == dim.size
    assert variables['arr'].dimensions[0] == dim.size

    # Equivalent spellings of the size (``1:nlon``, ``end-start+1``) are also
    # recognised through ``size_expressions``
    for expr in (
        variables['nlon'],
        variables['local_arr'].dimensions[0],
        variables['range_arr'].dimensions[0],
    ):
        assert expr in dim.size_expressions


@pytest.mark.parametrize('frontend', available_frontends())
def test_dimension_index_range(frontend):
    """
    Test that a :any:`Dimension` matches the loop index and the loop range
    (and its individual bounds) of a loop over that dimension.
    """
    fcode = """
subroutine test_dimension_index(nlon, start, end, arr)
  integer, intent(in) :: NLON, START, END
  real, intent(inout) :: arr(nlon)
  integer :: I

  do i=start, end
    arr(I) = 1.
  end do
end subroutine test_dimension_index
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    dim = Dimension(name='test_dim', index='i', bounds=('start', 'end'))
    assert routine.variable_map['i'] == dim.index

    # The range of the (only) loop and its lower/upper bounds match the dimension
    loop_bounds = FindNodes(Loop).visit(routine.body)[0].bounds
    assert loop_bounds == dim.range
    assert loop_bounds.lower == dim.bounds[0]
    assert loop_bounds.upper == dim.bounds[1]

    # Creating a dimension with derived-type member aliases as bounds must not fail
    Dimension('test_dim_alias', bounds_aliases=('bnds%start', 'bnds%end'))


def test_dimension_config(tmp_path):
    """
    Test that :any:`Dimension` objects are created correctly from the
    ``dimensions`` section of a :any:`SchedulerConfig` file.

    Two forms are covered: explicit ``bounds`` and lists of alternative
    ``lower``/``upper`` expressions.
    """
    scope = Scope()
    int_type = SymbolAttributes(dtype=BasicType.INTEGER)
    ibl, nblocks, start, end = (
        sym.Variable(name=name, type=int_type, scope=scope)
        for name in ('ibl', 'nblocks', 'start', 'end')
    )
    dim = sym.Variable(name='dim', type=SymbolAttributes(dtype=BasicType.DEFERRED), scope=scope)
    one = sym.IntLiteral(1)

    # NOTE(review): the content is TOML syntax although the file is named
    # ``.yml``; from_file evidently accepts this — confirm the suffix is intended
    config_str = """
[dimensions.dim_a]
  size = 'NBLOCKS'
  index = 'IBL'
  bounds = ['START', 'END']
  aliases = ['DIM%START', 'DIM%END']

[dimensions.dim_b]
  size = 'nblocks'
  index = 'ibl'
  lower = ['1', 'start', 'dim%start']
  upper = ['nblocks', 'end', 'dim%end']
"""
    cfg_path = tmp_path/'test_config.yml'
    cfg_path.write_text(config_str)
    dimensions = SchedulerConfig.from_file(cfg_path).dimensions

    # dim_a: explicit bounds pair; names in the config are matched case-insensitively
    dim_a = dimensions['dim_a']
    assert dim_a.size == nblocks
    assert dim_a.index == ibl
    assert dim_a.bounds == (start, end)

    # dim_b: alternative lower/upper expressions, incl. derived-type members;
    # the first entries form the bounds
    dim_b = dimensions['dim_b']
    assert dim_b.size == nblocks
    assert dim_b.index == ibl
    assert dim_b.bounds == (sym.IntLiteral(1), nblocks)
    assert dim_b.lower == (one, start, start.clone(parent=dim))  # pylint: disable=no-member
    assert dim_b.upper == (nblocks, end, end.clone(parent=dim))  # pylint: disable=no-member
loki-ecmwf-0.3.6/loki/backend/0000775000175000017500000000000015167130205016266 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/backend/__init__.py0000664000175000017500000000154415167130205020403 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Backend classes that convert Loki IR into output code in various languages.
"""

from loki.backend.cgen import * # noqa
from loki.backend.cppgen import * # noqa
from loki.backend.cudagen import * # noqa
from loki.backend.cufgen import * # noqa
from loki.backend.dacegen import * # noqa
from loki.backend.fgen import * # noqa
from loki.backend.fgencon import * # noqa
from loki.backend.pygen import * # noqa
from loki.backend.pprint import * # noqa
from loki.backend.style import * # noqa
loki-ecmwf-0.3.6/loki/backend/tests/0000775000175000017500000000000015167130205017430 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/backend/tests/__init__.py0000664000175000017500000000057015167130205021543 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/backend/tests/test_pygen.py0000664000175000017500000004254215167130205022172 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import sys
from pathlib import Path
from importlib import import_module, reload, invalidate_caches
from collections import namedtuple
import pytest
import numpy as np

from loki import Subroutine
from loki.backend import pygen
from loki.jit_build import jit_compile, clean_test
from loki.frontend import available_frontends, OMNI
from loki.transformations.transpile import FortranPythonTransformation


def load_module(tmp_path, module):
    """
    Import (or re-import) ``module`` from the directory ``tmp_path``.

    The directory is prepended to ``sys.path`` if it is not already there. A
    module that was imported before is reloaded, so freshly generated code is
    picked up. Otherwise the module is imported. If that first attempt raises
    :class:`ModuleNotFoundError`, the import caches are invalidated and the
    import is retried once.
    """
    search_dir = str(Path(tmp_path).absolute())
    if search_dir not in sys.path:
        sys.path.insert(0, search_dir)

    if module in sys.modules:
        reload(sys.modules[module])
        return sys.modules[module]

    try:
        imported = import_module(module)
    except ModuleNotFoundError:
        # Finder caches may not yet know about freshly written files
        invalidate_caches()
        imported = import_module(module)
    return imported


@pytest.mark.parametrize('frontend', available_frontends())
def test_pygen_simple_loops(tmp_path, frontend):
    """
    Test Python transpilation of simple and nested loops.

    The Fortran kernel is JIT-compiled as a reference, then transpiled to
    Python with :any:`FortranPythonTransformation`. Both versions must produce
    the same results.
    """

    fcode = """
subroutine pygen_simple_loops(n, m, scalar, vector, tensor)
  use iso_fortran_env, only: real64
  implicit none
  integer, intent(in) :: n, m
  real(kind=real64), intent(inout) :: scalar
  real(kind=real64), intent(inout) :: vector(n), tensor(n, m)

  integer :: i, j

  ! For testing, the operation is:
  do i=1, n
     vector(i) = vector(i) + tensor(i, 1) + 1.0
  end do

  do j=1, m
     do i=1, n
        tensor(i, j) = 10.* j + i
     end do
  end do
end subroutine pygen_simple_loops
"""

    n, m = 3, 4
    expected_tensor = [[11., 21., 31., 41.],
                       [12., 22., 32., 42.],
                       [13., 23., 33., 43.]]

    def fresh_args():
        # New inputs for every run, since vector/tensor are updated in place
        vector = np.zeros(shape=(n,), order='F') + 3.
        tensor = np.zeros(shape=(n, m), order='F') + 4.
        return 2.0, vector, tensor

    # Reference: compile and run the Fortran kernel
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'pygen_simple_loops_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='pygen_simple_loops')

    scalar, vector, tensor = fresh_args()
    function(n, m, scalar, vector, tensor)
    assert np.all(vector == 8.)
    assert np.all(tensor == expected_tensor)

    # Unique routine name per frontend, so cached module imports do not clash
    routine.name = f'{routine.name}_{str(frontend)}'

    # Transpile to Python, import the generated module and rerun the kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    func = getattr(load_module(tmp_path, f2p.mod_name), f2p.mod_name)

    scalar, vector, tensor = fresh_args()
    func(n, m, scalar, vector, tensor)
    assert np.all(vector == 8.)
    assert np.all(tensor == expected_tensor)

    clean_test(filepath)
    f2p.py_path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_pygen_arguments(tmp_path, frontend):
    """
    Test that arguments with ``in``, ``out`` and ``inout`` intents are exchanged
    correctly, both for the compiled Fortran reference and for the transpiled
    Python kernel.
    """

    fcode = """
subroutine pygen_arguments(n, array, array_io, a, b, c, a_io, b_io, c_io)
  use iso_fortran_env, only: real32, real64
  implicit none

  integer, intent(in) :: n
  real(kind=real64), intent(inout) :: array(n)
  real(kind=real64), intent(out) :: array_io(n)

  integer, intent(out) :: a
  real(kind=real32), intent(out) :: b
  real(kind=real64), intent(out) :: c
  integer, intent(inout) :: a_io
  real(kind=real32), intent(inout) :: b_io
  real(kind=real64), intent(inout) :: c_io

  integer :: i

  do i=1, n
     array(i) = 3.
     array_io(i) = array_io(i) + 3.
  end do

  a = 2**3
  b = 3.2_real32
  c = 4.1_real64

  a_io = a_io + 2
  b_io = b_io + real(3.2, kind=real32)
  c_io = c_io + 4.1
end subroutine pygen_arguments
"""

    n = 3

    def fresh_args():
        # Scalar inout arguments are passed as 0-d numpy arrays so they can be updated
        array = np.zeros(shape=(n,), order='F')
        array_io = np.zeros(shape=(n,), order='F') + 3.
        return array, array_io, np.array(1), np.array(2.), np.array(3.)

    # NOTE(review): the kernel reads ``array_io`` although it is declared
    # ``intent(out)``. Fortran leaves its incoming value undefined, so the
    # ``== 6.`` checks below rely on the input data being kept.

    # Reference: compile and run the Fortran kernel
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'pygen_arguments_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='pygen_arguments')

    # The compiled function returns the intent(out) scalars; inout data is updated in place
    array, array_io, a_io, b_io, c_io = fresh_args()
    a, b, c = function(n, array, array_io, a_io, b_io, c_io)

    assert np.all(array == 3.) and array.size == n
    assert np.all(array_io == 6.)
    assert a_io == 3 and np.isclose(b_io, 5.2) and np.isclose(c_io, 7.1)
    assert a == 8 and np.isclose(b, 3.2) and np.isclose(c, 4.1)

    # Unique routine name per frontend, so cached module imports do not clash
    routine.name = f'{routine.name}_{str(frontend)}'

    # Transpile to Python, import the generated module and rerun the kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    func = getattr(load_module(tmp_path, f2p.mod_name), f2p.mod_name)

    # The transpiled function returns both the intent(out) and intent(inout) scalars
    array, array_io, a_io, b_io, c_io = fresh_args()
    a, b, c, a_io, b_io, c_io = func(n, array, array_io, a_io, b_io, c_io)

    assert np.all(array == 3.) and array.size == n
    assert np.all(array_io == 6.)
    assert a_io == 3 and np.isclose(b_io, 5.2) and np.isclose(c_io, 7.1)
    assert a == 8 and np.isclose(b, 3.2) and np.isclose(c, 4.1)

    clean_test(filepath)
    f2p.py_path.unlink()


# TODO: implement and test transpilation of derived types

# TODO: implement and test transpilation of associates

# TODO: implement and test transpilation of modules


@pytest.mark.parametrize('frontend', available_frontends())
def test_pygen_vectorization(tmp_path, frontend):
    """
    Test transpilation of Fortran array (vector) notation and of local
    multi-dimensional arrays.
    """

    fcode = """
subroutine pygen_vectorization(n, m, scalar, v1, v2)
  use iso_fortran_env, only: real64
  implicit none
  integer, intent(in) :: n, m
  real(kind=real64), intent(inout) :: scalar
  real(kind=real64), intent(inout) :: v1(n), v2(n)

  real(kind=real64) :: matrix(n, m)

  integer :: i

  v1(:) = scalar + 1.0
  matrix(:, 1:m) = scalar + 2.
  v2(:n) = matrix(:, 2)
  v2(1) = 1.
end subroutine pygen_vectorization
"""

    n, m = 3, 4

    def fresh_args():
        # New scalar and output vectors for every run
        return 2.0, np.zeros(shape=(n,), order='F'), np.zeros(shape=(n,), order='F')

    # Reference: compile and run the Fortran kernel
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'pygen_vectorization_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='pygen_vectorization')

    scalar, v1, v2 = fresh_args()
    function(n, m, scalar, v1, v2)
    assert np.all(v1 == 3.)
    assert v2[0] == 1. and np.all(v2[1:] == 4.)

    # Unique routine name per frontend, so cached module imports do not clash
    routine.name = f'{routine.name}_{str(frontend)}'

    # Transpile to Python, import the generated module and rerun the kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    func = getattr(load_module(tmp_path, f2p.mod_name), f2p.mod_name)

    # The returned (inout) scalar is not checked here
    scalar, v1, v2 = fresh_args()
    func(n, m, scalar, v1, v2)
    assert np.all(v1 == 3.)
    assert v2[0] == 1. and np.all(v2[1:] == 4.)

    clean_test(filepath)
    f2p.py_path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_pygen_intrinsics(tmp_path, frontend):
    """
    Test transpilation of supported intrinsic functions
    (MIN, MAX, ABS, EXP, SQRT, SIGN), including nested MIN/MAX calls.
    """

    fcode = """
subroutine pygen_intrinsics(v1, v2, v3, v4, vmin, vmax, vabs, vmin_nested, vmax_nested, vexp, vsqrt, vsign)
  ! Test supported intrinsic functions
  use iso_fortran_env, only: real64
  real(kind=real64), intent(in) :: v1, v2, v3, v4
  real(kind=real64), intent(out) :: vmin, vmax, vabs, vmin_nested, vmax_nested, vexp, vsqrt, vsign

  vmin = MIN(v1, v2)
  vmax = MAX(v1, v2)
  vabs = ABS(v1 - v2)
  vmin_nested = MIN(MIN(MAX(v1, -1._real64), v2), MIN(v3, v4))
  vmax_nested = MAX(MAX(v1, v2), MAX(v3, v4))
  vexp = EXP(v2)
  vsqrt = SQRT(v2)
  vsign = SIGN(v4, v1-v2)
end subroutine pygen_intrinsics
"""

    v1, v2, v3, v4 = 2., 4., 1., 5.

    def check(results):
        """Verify the tuple of intent(out) values returned by either kernel."""
        vmin, vmax, vabs, vmin_nested, vmax_nested, vexp, vsqrt, vsign = results
        assert vmin == 2. and vmax == 4. and vabs == 2.
        assert vmin_nested == 1. and vmax_nested == 5.
        # EXP from the compiler's math library and numpy's exp may differ in
        # the last bit, so compare with a tolerance. SQRT is correctly rounded
        # and exact for this input.
        assert np.isclose(vexp, np.exp(4.)) and vsqrt == 2.
        assert vsign == -5.

    # Reference: compile and run the Fortran kernel
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'pygen_intrinsics_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='pygen_intrinsics')
    check(function(v1, v2, v3, v4))

    # Unique routine name per frontend, so cached module imports do not clash
    routine.name = f'{routine.name}_{str(frontend)}'

    # Transpile to Python, import the generated module and rerun the kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    mod = load_module(tmp_path, f2p.mod_name)
    func = getattr(mod, f2p.mod_name)
    check(func(v1, v2, v3, v4))

    clean_test(filepath)
    f2p.py_path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_pygen_loop_indices(tmp_path, frontend):
    """
    Test that loop indexing (1-based Fortran DO and DO WHILE loops) translates
    correctly to 0-based Python indexing.
    """

    fcode = """
subroutine pygen_loop_indices(n, idx, mask1, mask2, mask3)
  ! Test to ensure loop indexing translates correctly
  use iso_fortran_env, only: real64
  integer, intent(in) :: n, idx
  integer, intent(inout) :: mask1(n), mask2(n)
  real(kind=real64), intent(inout) :: mask3(n)

  integer :: i

  do i=1, n
     if (i < idx) then
        mask1(i) = 1
     elseif (i == idx) then
        mask1(i) = 2
     else
        mask1(i) = 3
     end if

     mask2(i) = i
  end do

  i = 1
  do while (i <= idx)
    mask3(i) = 2.0
    i = i + 1
  end do
  mask3(n) = 3.0
end subroutine pygen_loop_indices
"""

    n = 6
    # The same position as a 0-based (C/Python) and a 1-based (Fortran) index
    cidx, fidx = 3, 4

    def fresh_masks():
        # New zeroed output arrays for every run
        return (np.zeros(shape=(n,), order='F', dtype=np.int32),
                np.zeros(shape=(n,), order='F', dtype=np.int32),
                np.zeros(shape=(n,), order='F', dtype=np.float64))

    def check(mask1, mask2, mask3):
        """Verify the masks produced by either kernel."""
        # Fortran i < idx -> 1, i == idx -> 2, otherwise 3. In 0-based terms:
        # all entries before cidx are 1 (previously only [:cidx-1] was checked)
        assert np.all(mask1[:cidx] == 1)
        assert mask1[cidx] == 2
        assert np.all(mask1[cidx+1:] == 3)
        assert np.all(mask2 == np.arange(n, dtype=np.int32) + 1)
        # DO WHILE fills Fortran entries 1..idx, i.e. 0-based [:fidx]
        assert np.all(mask3[:fidx] == 2.)
        assert mask3[-1] == 3.

    # Reference: compile and run the Fortran kernel
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'pygen_loop_indices_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='pygen_loop_indices')

    mask1, mask2, mask3 = fresh_masks()
    function(n=n, idx=fidx, mask1=mask1, mask2=mask2, mask3=mask3)
    check(mask1, mask2, mask3)

    # Unique routine name per frontend, so cached module imports do not clash
    routine.name = f'{routine.name}_{str(frontend)}'

    # Transpile to Python, import the generated module and rerun the kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    mod = load_module(tmp_path, f2p.mod_name)
    func = getattr(mod, f2p.mod_name)

    mask1, mask2, mask3 = fresh_masks()
    func(n=n, idx=fidx, mask1=mask1, mask2=mask2, mask3=mask3)
    check(mask1, mask2, mask3)

    clean_test(filepath)
    f2p.py_path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_pygen_logical_statements(tmp_path, frontend):
    """
    A simple test routine to test logical statements
    """

    fcode = """
subroutine pygen_logical_statements(v1, v2, v_xor, v_xnor, v_nand, v_neqv, v_val)
  logical, intent(in) :: v1, v2
  logical, intent(out) :: v_xor, v_nand, v_xnor, v_neqv, v_val(2)

  v_xor = (v1 .and. .not. v2) .or. (.not. v1 .and. v2)
  v_xnor = v1 .eqv. v2
  v_nand = .not. (v1 .and. v2)
  v_neqv = v1 .neqv. v2
  v_val(1) = .true.
  v_val(2) = .false.

end subroutine pygen_logical_statements
"""

    # Generate reference code, compile run and verify
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'pygen_logical_statements_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='pygen_logical_statements')

    # Test the reference solution
    for v1 in range(2):
        for v2 in range(2):
            v_val = np.zeros(shape=(2,), order='F', dtype=np.int32)
            v_xor, v_xnor, v_nand, v_neqv = function(v1, v2, v_val)
            assert v_xor == (v1 and not v2) or (not v1 and v2)
            assert v_xnor == (v1 and v2) or not (v1 or v2)
            assert v_nand == (not (v1 and v2))
            assert v_neqv == ((not (v1 and v2)) and (v1 or v2))
            assert v_val[0] and not v_val[1]

    # Rename routine to avoid problems with module import caching
    routine.name = f'{routine.name}_{str(frontend)}'

    # Generate and test the transpiled Python kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    mod = load_module(tmp_path, f2p.mod_name)
    func = getattr(mod, f2p.mod_name)

    for v1 in range(2):
        for v2 in range(2):
            v_val = np.zeros(shape=(2,), order='F', dtype=np.int32)
            v_xor, v_xnor, v_nand, v_neqv = func(v1, v2, v_val)
            assert v_xor == (v1 and not v2) or (not v1 and v2)
            assert v_xnor == (v1 and v2) or not (v1 or v2)
            assert v_nand == (not (v1 and v2))
            assert v_neqv == ((not (v1 and v2)) and (v1 or v2))
            assert v_val[0] and not v_val[1]

    clean_test(filepath)
    f2p.py_path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_pygen_downcasing(tmp_path, frontend):
    """
    A simple test routine to test the conversion to lower case.
    """

    fcode = """
subroutine pygen_downcasing(n, ScalaR, VectOr)
  use iso_fortran_env, only: real64
  implicit none
  integer, intent(in) :: N
  real(kind=real64), intent(inout) :: scalar
  real(kind=real64), intent(inout) :: vector(n)

  integer :: i
  real(kind=real64) :: a, tmp

  real(kind=real64) :: sTmT_F
  sTmT_F(a) = a + 2.

  do i=1, n
     tmp = stmt_F(scalar)
     veCtor(i) = vecTor(i) + tmp
  end do

end subroutine pygen_downcasing
"""

    # Generate reference code, compile run and verify
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'pygen_downcasing_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='pygen_downcasing')

    n = 3
    scalar = 2.0
    vector = np.zeros(shape=(n,), order='F') + 2.
    function(n, scalar, vector)
    assert np.all(vector == 6.)

    # Rename routine to avoid problems with module import caching
    routine.name = f'{routine.name}_{str(frontend)}'

    # Generate and test the transpiled Python kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    mod = load_module(tmp_path, f2p.mod_name)
    func = getattr(mod, f2p.mod_name)

    assert pygen(routine).islower()

    n = 3
    scalar = 2.0
    vector = np.zeros(shape=(n,), order='F') + 2.
    func(n, scalar, vector)
    assert np.all(vector == 6.)


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'OMNI strictly needs type definitions')])
)
def test_pygen_derived_type_members(tmp_path, frontend):
    """
    A simple test to check derived type member usage.
    """

    fcode = """
subroutine pygen_derived_type_members(n, MyObject)
  use iso_fortran_env, only: real64
  use some_module, only: my_TYPE
  implicit none

  integer, intent(in) :: N
  type(my_TYPE), intent(in) :: MyObject

  integer :: i
  real(kind=real64) :: tmp

  do i=1, n
     tmp = myobject%vector(i) + myobject%scalar
     myobject%vector(i) = tmp
  end do

end subroutine pygen_derived_type_members
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # TODO: Implement type definition representation and test!
    # Without the TypeDef, we can't test the reference either.

    # Generate and test the transpiled Python kernel
    f2p = FortranPythonTransformation(suffix='_py')
    f2p.apply(source=routine, path=tmp_path)
    mod = load_module(tmp_path, f2p.mod_name)
    func = getattr(mod, f2p.mod_name)

    n = 3
    MyType = namedtuple('MyType', ['scalar', 'vector'])
    obj = MyType(scalar=40.0, vector=np.zeros(shape=(n,), order='F') + 2.)
    func(n, obj)
    assert np.all(obj.vector == 42.)
loki-ecmwf-0.3.6/loki/backend/tests/test_fstyle.py0000664000175000017500000001046015167130205022350 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module
from loki.backend import fgen, FortranStyle, IFSFortranStyle
from loki.frontend import available_frontends, OMNI


@pytest.fixture(scope='module', name='fcode')
def fixture_fcode():
    return """
MODULE ONCE_UPON
IMPLICIT NONE

INTEGER, PARAMETER :: ATIME = 8

CONTAINS

  SUBROUTINE THERE_WERE ( N, M, RICK, DAVE, NEVER )
    INTEGER(KIND=4), INTENT(IN) :: N, M
    REAL(KIND=ATIME), INTENT(INOUT) :: RICK, DAVE(N)
    REAL(KIND=ATIME), INTENT(OUT) :: NEVER(N, M)
    INTEGER :: I, J, IJK

    DO I=1, N
      DAVE(I) = DAVE(I) + RICK
    END DO

    IF (RICK > 0.5) THEN
      DO I=1, N
        DO J=1, M
          NEVER(I, J) = RICK + DAVE(I)
        END DO
      ENDDO
    END IF
  END SUBROUTINE
END MODULE
"""


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, "OMNI enforces its own stylistic quirks!")]
))
def test_fgen_default_style(frontend, tmp_path, fcode):
    """ Test the default Fortran styles of the fgen backend """
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Test the default Fortran layout
    generated_default = fgen(module)
    assert generated_default == """
MODULE ONCE_UPON
  IMPLICIT NONE

  INTEGER, PARAMETER :: ATIME = 8

  CONTAINS

  SUBROUTINE THERE_WERE (N, M, RICK, DAVE, NEVER)
    INTEGER(KIND=4), INTENT(IN) :: N, M
    REAL(KIND=ATIME), INTENT(INOUT) :: RICK, DAVE(N)
    REAL(KIND=ATIME), INTENT(OUT) :: NEVER(N, M)
    INTEGER :: I, J, IJK

    DO I=1,N
      DAVE(I) = DAVE(I) + RICK
    END DO

    IF (RICK > 0.5) THEN
      DO I=1,N
        DO J=1,M
          NEVER(I, J) = RICK + DAVE(I)
        END DO
      END DO
    END IF
  END SUBROUTINE THERE_WERE
END MODULE ONCE_UPON
""".strip()


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, "OMNI enforces its own stylistic quirks!")]
))
def test_fgen_custom_style(frontend, tmp_path, fcode):
    """ Test a custom Fortran styles of the fgen backend """
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Test a custom Fortran layout
    custom_style = FortranStyle(
        conditional_indent=4,
        conditional_end_space=False,
        loop_indent=3,
        loop_end_space=False,
        procedure_spec_indent=5,
        procedure_body_indent=1,
        # procedure_contains_indent=2,
        procedure_end_named=False,
        module_spec_indent=3,
        module_contains_indent=1,
        module_end_named=False,
    )
    generated_custom = fgen(module, style=custom_style)
    assert generated_custom == """
MODULE ONCE_UPON
   IMPLICIT NONE

   INTEGER, PARAMETER :: ATIME = 8

 CONTAINS

 SUBROUTINE THERE_WERE (N, M, RICK, DAVE, NEVER)
      INTEGER(KIND=4), INTENT(IN) :: N, M
      REAL(KIND=ATIME), INTENT(INOUT) :: RICK, DAVE(N)
      REAL(KIND=ATIME), INTENT(OUT) :: NEVER(N, M)
      INTEGER :: I, J, IJK

  DO I=1,N
     DAVE(I) = DAVE(I) + RICK
  ENDDO

  IF (RICK > 0.5) THEN
      DO I=1,N
         DO J=1,M
            NEVER(I, J) = RICK + DAVE(I)
         ENDDO
      ENDDO
  ENDIF
 END SUBROUTINE
END MODULE
""".strip()


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, "OMNI enforces its own stylistic quirks!")]
))
def test_fgen_ifs_style(frontend, tmp_path, fcode):
    """ Test an IFS-specific Fortran styles of the fgen backend """
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    generated = fgen(module, style=IFSFortranStyle())
    assert generated == """
MODULE ONCE_UPON
IMPLICIT NONE

INTEGER, PARAMETER :: ATIME = 8

  CONTAINS

  SUBROUTINE THERE_WERE (N, M, RICK, DAVE, NEVER)
  INTEGER(KIND=4), INTENT(IN) :: N, M
  REAL(KIND=ATIME), INTENT(INOUT) :: RICK, DAVE(N)
  REAL(KIND=ATIME), INTENT(OUT) :: NEVER(N, M)
  INTEGER :: I, J, IJK

  DO I=1,N
    DAVE(I) = DAVE(I) + RICK
  ENDDO

  IF (RICK > 0.5) THEN
    DO I=1,N
      DO J=1,M
        NEVER(I, J) = RICK + DAVE(I)
      ENDDO
    ENDDO
  ENDIF
  END SUBROUTINE THERE_WERE
END MODULE ONCE_UPON
""".strip()
loki-ecmwf-0.3.6/loki/backend/tests/test_stringifier.py0000664000175000017500000001660215167130205023373 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from io import StringIO

import pytest

from loki import Module, Subroutine
from loki.backend import DefaultStyle, Stringifier, pprint
from loki.frontend import available_frontends, OMNI


@pytest.mark.parametrize('frontend', available_frontends())
def test_stringifier(frontend, tmp_path):
    """
    Test basic stringifier capability for most IR nodes.
    """
    fcode = """
MODULE some_mod
  INTEGER :: n
  !$loki dimension(klon)
  REAL :: arr(:)
  CONTAINS
    SUBROUTINE some_routine (x, y)
      ! This is a basic subroutine with some loops
      IMPLICIT NONE
      REAL, INTENT(IN) :: x
      REAL, INTENT(OUT) :: y
      INTEGER :: i
      ! And now to the content
      IF (x < 1E-8 .and. x > -1E-8) THEN
        x = 0.
      ELSE IF (x > 0.) THEN
        DO WHILE (x > 1.)
          x = x / 2.
        ENDDO
      ELSE
        x = -x
      ENDIF
      y = 0
      DO i=1,n
        y = y + x*x
      ENDDO
      y = my_sqrt(y) + 1. + 1. + 1. + 1. + 1. + 1. + 1. + 1. + 1. + 1. + 1. + 1. + 1.
    END SUBROUTINE some_routine
    FUNCTION my_sqrt (arg)
      IMPLICIT NONE
      REAL, INTENT(IN) :: arg
      REAL :: my_sqrt
      my_sqrt = SQRT(arg)
    END FUNCTION my_sqrt
  SUBROUTINE other_routine (m)
    ! This is just to have some more IR nodes
    ! with multi-line comments and everything...
    IMPLICIT NONE
    INTEGER, INTENT(IN) :: m
    REAL, ALLOCATABLE :: var(:)
    !$loki some pragma
    SELECT CASE (m)
      CASE (0)
        m = 1
      CASE (1:10)
        PRINT *, '1 to 10'
      CASE (-1, -2)
        m = 10
      CASE DEFAULT
        PRINT *, 'Default case'
    END SELECT
    ASSOCIATE (x => arr(m))
      x = x * 2.
    END ASSOCIATE
    ALLOCATE(var, source=arr)
    CALL some_routine (arr(1), var(1))
    arr(:) = arr(:) + var(:)
    DEALLOCATE(var)
  END SUBROUTINE other_routine
END MODULE some_mod
    """.strip()
    ref_lines = [
        "",  # l. 1
        "#",
        "##",
        "##",
        "##",
        "#",
        "##",
        "##",
        "###",
        "###",  # l. 10
        "###",
        "###",
        "##",
        "###",
        "###",
        "#### -1E-8>",
        "#####",
        "####",
        "#####",
        "###### 0.>",  # l. 20
        "####### 1.>",
        "########",
        "######",
        "#######",
        "###",
        "###",
        "####",
        "###",
        "#", # l. 30
        "##",
        "###",
        "###",
        "###",
        "##",
        "###",
        "#",
        "##",
        "##",
        "###",  # l. 40
        "###",
        "###",
        "##",
        "###",
        "###",
        "####",
        "#####",
        "####",
        "#####",
        "####",  # l. 50
        "#####",
        "####",
        "#####",
        "###",
        "####",
        "###",
        "###",
        "###",
        "###",
    ]

    if frontend == OMNI:
        # Some string inconsistencies
        ref_lines[15] = ref_lines[15].replace('1E-8', '1e-8')
        ref_lines[35] = ref_lines[35].replace('SQRT', 'sqrt')
        ref_lines[48] = ref_lines[48].replace('PRINT', 'print')
        ref_lines[52] = ref_lines[52].replace('PRINT', 'print')

    cont_index = 27  # line number where line continuation is happening
    ref = '\n'.join(ref_lines)
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Test custom indentation
    def line_cont(indent):
        return f'\n{"...":{max(len(indent), 1)}} '

    assert Stringifier(
        style=DefaultStyle(indent_char='#', indent_default=1), line_cont=line_cont
    ).visit(module).strip() == ref.strip()

    # Test default
    ref_lines = ref.strip().replace('#', '  ').splitlines()
    ref_lines[cont_index] = '      '] + ref_lines[cont_index+2:]
    w_ref = '\n'.join(ref_lines)
    assert Stringifier(
        style=DefaultStyle(indent_char='#', indent_default=1, linewidth=44), line_cont=line_cont
    ).visit(module).strip() == w_ref


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI fails to read without full module')]))
def test_pprint_select_type(frontend, tmp_path):
    fcode = """
subroutine select_type_routine(arg)
    use type_mod
    implicit none
    class(base), intent(inout) :: arg
    select type( arg )
        class is(derived1)
            print *, 'derived1'
        type is(derived2)
            print *, 'derived2'
        class default
            print *, 'default'
    end select
    print *, 'after select'
end subroutine select_type_routine
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    stream = StringIO()
    pprint(routine, stream=stream)
    text = stream.getvalue()
    assert '' in text
    assert '' in text
    assert '' in text
    assert '' in text
loki-ecmwf-0.3.6/loki/backend/tests/test_fgen.py0000664000175000017500000003773715167130205022001 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Subroutine, Sourcefile
from loki.backend import fgen
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import DataDeclaration
from loki.types import ProcedureType, BasicType


@pytest.mark.parametrize('frontend', available_frontends())
def test_fgen_literal_list_linebreak(frontend, tmp_path):
    """
    Test correct handling of linebreaks for LiteralList expression nodes
    """
    fcode = """
module some_mod
  implicit none
  INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
  interface
    subroutine config_gas_optics_sw_spectral_def_allocate_bands_only(a, b)
        INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
        real(kind=jprb), intent(in) :: a(:), b(:)
    end subroutine config_gas_optics_sw_spectral_def_allocate_bands_only
  end interface
contains
  subroutine literal_list_linebreak
    implicit none
    real(jprb), parameter, dimension(1,140) :: frac &
        = reshape( (/ 0.21227E+00, 0.18897E+00, 0.25491E+00, 0.17864E+00, 0.11735E+00, 0.38298E-01, 0.57871E-02, &
        &    0.31753E-02, 0.53169E-03, 0.76476E-04, 0.16388E+00, 0.15241E+00, 0.14290E+00, 0.12864E+00, &
        &    0.11615E+00, 0.10047E+00, 0.80013E-01, 0.60445E-01, 0.44918E-01, 0.63395E-02, 0.32942E-02, &
        &    0.54541E-03, 0.15380E+00, 0.15194E+00, 0.14339E+00, 0.13138E+00, 0.11701E+00, 0.10081E+00, &
        &    0.82296E-01, 0.61735E-01, 0.41918E-01, 0.45918E-02, 0.37743E-02, 0.30121E-02, 0.22500E-02, &
        &    0.14490E-02, 0.55410E-03, 0.78364E-04, 0.15938E+00, 0.15146E+00, 0.14213E+00, 0.13079E+00, &
        &    0.11672E+00, 0.10053E+00, 0.81566E-01, 0.61126E-01, 0.41150E-01, 0.44488E-02, 0.36950E-02, &
        &    0.29101E-02, 0.21357E-02, 0.19609E-02, 0.14134E+00, 0.14390E+00, 0.13913E+00, 0.13246E+00, &
        &    0.12185E+00, 0.10596E+00, 0.87518E-01, 0.66164E-01, 0.44862E-01, 0.49402E-02, 0.40857E-02, &
        &    0.32288E-02, 0.23613E-02, 0.15406E-02, 0.58258E-03, 0.82171E-04, 0.29127E+00, 0.28252E+00, &
        &    0.22590E+00, 0.14314E+00, 0.45494E-01, 0.71792E-02, 0.38483E-02, 0.65712E-03, 0.29810E+00, &
        &    0.27559E+00, 0.11997E+00, 0.10351E+00, 0.84515E-01, 0.62253E-01, 0.41050E-01, 0.44217E-02, &
        &    0.36946E-02, 0.29113E-02, 0.34290E-02, 0.55993E-03, 0.31441E+00, 0.27586E+00, 0.21297E+00, &
        &    0.14064E+00, 0.45588E-01, 0.65665E-02, 0.34232E-02, 0.53199E-03, 0.19811E+00, 0.16833E+00, &
        &    0.13536E+00, 0.11549E+00, 0.10649E+00, 0.93264E-01, 0.75720E-01, 0.56405E-01, 0.41865E-01, &
        &    0.59331E-02, 0.26510E-02, 0.40040E-03, 0.32328E+00, 0.26636E+00, 0.21397E+00, 0.14038E+00, &
        &    0.52142E-01, 0.38852E-02, 0.14601E+00, 0.13824E+00, 0.27703E+00, 0.22388E+00, 0.15446E+00, &
        &    0.48687E-01, 0.98054E-02, 0.18870E-02, 0.11961E+00, 0.12106E+00, 0.13215E+00, 0.13516E+00, &
        &    0.25249E+00, 0.16542E+00, 0.68157E-01, 0.59725E-02, 0.49258E+00, 0.33651E+00, 0.16182E+00, &
        &    0.90984E-02, 0.95202E+00, 0.47978E-01, 0.91716E+00, 0.82857E-01, 0.77464E+00, 0.22536E+00 /), (/ 1,140 /) )
    call config_gas_optics_sw_spectral_def_allocate_bands_only( &
         &  [2600.0_jprb, 3250.0_jprb, 4000.0_jprb, 4650.0_jprb, 5150.0_jprb, 6150.0_jprb, 7700.0_jprb, &
         &   8050.0_jprb, 12850.0_jprb, 16000.0_jprb, 22650.0_jprb, 29000.0_jprb, 38000.0_jprb, 820.0_jprb], &
         &  [3250.0_jprb, 4000.0_jprb, 4650.0_jprb, 5150.0_jprb, 6150.0_jprb, 7700.0_jprb, 8050.0_jprb, &
         &   12850.0_jprb, 16000.0_jprb, 22650.0_jprb, 29000.0_jprb, 38000.0_jprb, 50000.0_jprb, 2600.0_jprb])
  end subroutine literal_list_linebreak
end module some_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['literal_list_linebreak']

    # Make sure all lines are continued correctly
    code = module.to_fortran()
    code_lines = code.splitlines()
    assert len(code_lines) in (35, 36) # OMNI produces an extra line
    assert all(line.strip(' &\n') for line in code_lines)
    assert all(len(line) < 132 for line in code_lines)

    # Make sure it works also with less indentation
    spec_code = fgen(routine.spec)
    assert spec_code.count('&') == 32
    spec_lines = spec_code.splitlines()
    assert len(spec_lines) == 18
    assert all(len(line) < 132 for line in spec_code.splitlines())

    body_code = fgen(routine.body)
    assert body_code.count(',') == 27
    assert body_code.count('(/') == 2
    assert body_code.count('/)') == 2
    assert body_code.count('&') == 6
    body_lines = body_code.splitlines()
    assert len(body_lines) == 4
    assert all(len(line) < 132 for line in body_lines)


@pytest.mark.parametrize('frontend', available_frontends())
def test_character_list_linebreak(frontend, tmp_path):
    fcode = """
module some_mod
  implicit none
  character(len=*), parameter :: IceModelName(0:5) = (/ 'Monochromatic         ', &
       &                                                'Fu-IFS                ', &
       &                                                'Baran-EXPERIMENTAL    ', &
       &                                                'Baran2016             ', &
       &                                                'Baran2017-EXPERIMENTAL', &
       &                                                'Yi                    ' /)
end module some_mod
    """
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    generated_fcode = module.to_fortran()
    for ice_model_name in (
        "'Monochromatic         '",
        "'Fu-IFS                '",
        "'Baran-EXPERIMENTAL    '",
        "'Baran2016             '",
        "'Baran2017-EXPERIMENTAL'",
        "'Yi                    '"
    ):
        assert ice_model_name in generated_fcode


@pytest.mark.parametrize('frontend', available_frontends())
def test_fgen_data_stmt(frontend):
    """
    Test correct formatting of data declaration statements
    """
    fcode = """
subroutine data_stmt
    implicit none
    INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
    REAL(KIND=JPRB) :: ZAMD
    INTEGER :: KXINDX(35)
    data ZAMD   /  28.970_JPRB    /
    DATA KXINDX /0,2,3,0,31*0/
end subroutine data_stmt
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert isinstance(routine.spec.body[-1], DataDeclaration)
    spec_code = fgen(routine.spec)
    assert spec_code.lower().count('data ') == 2
    assert spec_code.count('/') == 4
    if frontend != OMNI:
        # OMNI seems to evaluate constant expressions, replacing 31*0 by 0,
        # although it's not a product here but a repeat specifier (great job, Fortran!)
        assert '31*0' in spec_code


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Loki likes only valid code')]))
def test_multiline_inline_conditional(frontend):
    """
    Test correct formatting of an inline :any:`Conditional` that
    contains a multi-line :any:`CallStatement`.
    """
    fcode = """
subroutine test_fgen(DIMS, ZSURF_LOCAL)
  type(DIMENSION_TYPE), intent(in) :: DIMS
  type(SURFACE_TYPE), intent(inout) :: ZSURF_LOCAL
  type(STATE_TYPE) :: TENDENCY_LOC
contains
subroutine test_inline_multiline(KDIMS, LBUD23)

  DO JKGLO=1,NGPTOT,NPROMA
    ! Add saturation adjustment tendencies to cloud scheme (LBUD23)
    IF (LBUD23) CALL UPDATE_FIELDS(YDPHY2,1,DIMS%KIDIA,DIMS%KFDIA,DIMS%KLON,DIMS%KLEV,&
     & PTA1=TENDENCY_LOC%T, PO1=ZSURF_LOCAL%GSD_XA%PGROUP(:,:,19),&
     & PTA2=TENDENCY_LOC%Q, PO2=ZSURF_LOCAL%GSD_XA%PGROUP(:,:,20),&
     & LDV3=YGFL%YL%LT1, PTA3=TENDENCY_LOC%CLD(:,:,NCLDQL), PO3=ZSURF_LOCAL%GSD_XA%PGROUP(:,:,21),&
     & LDV4=YGFL%YI%LT1, PTA4=TENDENCY_LOC%CLD(:,:,NCLDQI), PO4=ZSURF_LOCAL%GSD_XA%PGROUP(:,:,22))
  ENDDO
end subroutine test_inline_multiline
end subroutine test_fgen
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    out = fgen(routine)
    for line in out.splitlines():
        assert line.count('&') <= 2
        if line.count('&') == 2:
            assert len(line.split('&')[1]) > 60


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Loki likes only valid code')]))
def test_multiline_inline_conditional_long(frontend):
    """
    Test correct formatting of an inline :any:`Conditional` that
    that creates a particularly long line.
    """
    fcode = """
subroutine test_inline_multiline_long(array, flag)
  real, intent(inout) :: array
  logical, intent(in) :: flag

  if (flag) call a_subroutine_with_an_exquisitely_loong_and_expertly_chosen_name_and_a_few_keyword_arguments(my_favourite_array=array)
end subroutine test_inline_multiline_long
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    out = fgen(routine)
    for line in out.splitlines():
        assert len(line) < 132
        assert line.count('&') <= 2


@pytest.mark.parametrize('frontend', available_frontends())
def test_fgen_save_attribute(frontend, tmp_path):
    """
    Make sure the SAVE attribute on declarations is preserved (#164)
    """
    fcode = """
MODULE test
    INTEGER, SAVE :: variable
END MODULE test
    """.strip()
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    assert module['variable'].type.save is True
    assert len(module.declarations) == 1
    assert 'SAVE' in fgen(module.declarations[0])
    assert 'SAVE' in module.to_fortran()


@pytest.mark.parametrize('frontend', available_frontends())
def test_fgen_protected_attribute(frontend, tmp_path):
    """
    Make sure the PROTECTED attribute on declarations is preserved (#506).

    This test mimics the `test_fgen_save_attribute` test.
    """
    fcode = """
MODULE test
    INTEGER, PROTECTED :: variable
END MODULE test
    """.strip()
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    assert module['variable'].type.protected is True
    assert len(module.declarations) == 1
    assert 'PROTECTED' in fgen(module.declarations[0])
    assert 'PROTECTED' in module.to_fortran()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('external_decl', ('real :: x\n    external x', 'real, external :: x'))
@pytest.mark.parametrize('body', ('', 'y = x()'))
def test_fgen_external_procedure(frontend, external_decl, body):
    fcode = f"""
SUBROUTINE DRIVER
    implicit none
    real :: y
    {external_decl}
    {body}
END SUBROUTINE DRIVER
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    x = routine.variable_map['x']
    assert x.type.external
    assert isinstance(x.type.dtype, ProcedureType)
    assert x.type.dtype.return_type.dtype == BasicType.REAL
    assert isinstance(x, (sym.Scalar, sym.ProcedureSymbol))
    assert 'real, external :: x' in routine.to_fortran().lower()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_module', (True, False))
def test_fgen_procedure_pointer(frontend, use_module, tmp_path):
    """
    Test correct code generation for procedure pointers

    This was reported in #393
    """
    fcode_module = """
MODULE SPSI_MODNEW
IMPLICIT NONE
INTERFACE
    REAL FUNCTION SPNSI ()
    END FUNCTION SPNSI
END INTERFACE
END MODULE SPSI_MODNEW
    """.strip()

    fcode = """
SUBROUTINE SPCMNEW(FUNC)
USE SPSI_MODNEW, ONLY : SPNSI
IMPLICIT NONE
PROCEDURE(SPNSI), POINTER :: SPNSIPTR
PROCEDURE(REAL), POINTER, INTENT(OUT) :: FUNC
FUNC => SPNSIPTR
END SUBROUTINE SPCMNEW
    """.strip()

    if frontend == OMNI and not use_module:
        pytest.skip('Parsing without module definitions impossible in OMNI')

    definitions = []
    if use_module:
        module = Sourcefile.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
        definitions.extend(module.definitions)
    source = Sourcefile.from_source(fcode, frontend=frontend, definitions=definitions, xmods=[tmp_path])
    routine = source['spcmnew']
    ptr = routine.variable_map['spnsiptr']
    func = routine.variable_map['func']

    # Make sure we always have procedure type as dtype for the declared pointers
    assert isinstance(ptr.type.dtype, ProcedureType)
    assert isinstance(func.type.dtype, ProcedureType)

    # We should have the inter-procedural annotation in place if the module
    # definition was provided
    if use_module:
        assert ptr.type.dtype.procedure is module['spnsi'].body[0]
    else:
        assert ptr.type.dtype.procedure == BasicType.DEFERRED

    # With an implicit interface routine like this, we will never have
    # procedure information in place
    assert func.type.dtype.procedure == BasicType.DEFERRED
    assert func.type.dtype.return_type.dtype == BasicType.REAL

    # Check the interfaces declared on the variable declarations
    assert tuple(decl.interface for decl in routine.declarations) == ('SPNSI', BasicType.REAL)

    # Ensure that the fgen backend does the right thing
    assert 'procedure(spnsi), pointer :: spnsiptr' in source.to_fortran().lower()
    assert 'procedure(real), pointer, intent(out) :: func' in source.to_fortran().lower()


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'SELECT TYPE not implemented')]))
def test_fgen_select_type(frontend, tmp_path):
    fcode_module = """
module select_type_mod
    implicit none

    type, abstract :: base_type
    end type base_type

    type, extends(base_type) :: my_type
        integer :: val
    contains
        procedure :: get_val_ptr => get_val_ptr_my
    end type my_type

    type, extends(base_type) :: other_type
        integer :: val(2)
    contains
        procedure :: get_val_ptr => get_val_ptr_other
    end type other_type

    type :: container_type
        type(base_type), allocatable :: arr(:)
        integer :: n_arr
    end type container_type

contains

    subroutine get_val_ptr_my(self, ptr)
        class(my_type), intent(inout) :: self
        integer, pointer, intent(out) :: ptr

        ptr => self%val
    end subroutine get_val_ptr_my

    subroutine get_val_ptr_other(self, ptr)
        class(other_type), intent(inout) :: self
        integer, pointer, intent(out) :: ptr(:)

        ptr => self%val
    end subroutine get_val_ptr_other

    subroutine my_routine(container)
        use, intrinsic :: iso_fortran_env, only: error_unit

        type(container_type), intent(inout) :: container
        type(my_type), pointer :: my_t => null()
        type(other_type), pointer :: other_t => null()
        integer, pointer :: my_ptr => null()
        integer, pointer :: other_ptr(:) => null()
        integer :: jj

        do jj=1,self%n_arr
            associate( t_ptr => self%arr(j) )
                select type( t_ptr )
                    class is( my_type )
                        my_t => t_ptr
                        call my_t%get_val_ptr(my_ptr)
                    class is( other_type )
                        other_t => t_ptr
                        call other_t%get_val_ptr(other_ptr)
                    class default
                        write(error_unit, '(a)') 'unexpected class for t_ptr'
                        call abor1('error')
                end select

                named_cond: select type( t_ptr )
                    type is( my_type ) named_cond
                        print *, 'my_type'
                end select named_cond
            end associate
        end do
    end subroutine my_routine
end module select_type_mod
    """.strip()

    module = Sourcefile.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    code = module.to_fortran()

    assert 'SELECT TYPE (t_ptr)' in code
    assert 'CLASS IS (my_type)' in code
    assert 'CLASS DEFAULT' in code
    assert 'named_cond: SELECT TYPE (t_ptr)' in code
    assert 'TYPE IS (my_type) named_cond' in code

    reparsed = Sourcefile.from_source(code, frontend=frontend, xmods=[tmp_path])
    assert reparsed.to_fortran() == module.to_fortran()
loki-ecmwf-0.3.6/loki/backend/tests/test_conservative.py0000664000175000017500000002652115167130205023557 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Module, Sourcefile, config_override
from loki.backend import fgen
from loki.ir import nodes as ir, FindNodes, SubstituteExpressions
from loki.frontend import FP, OMNI, REGEX, SourceStatus, available_frontends


def test_fgen_conservative_routine():
    """
    Check that the conservative Fortran backend reproduces an untouched
    subroutine (and its spec, body and loop sub-trees) string-identically,
    and that after renaming ``RICK`` to ``BOB`` via
    :any:`SubstituteExpressions` the output differs only by that rename.
    """
    fcode = """
SUBROUTINE MY_TEST_ROUTINE( N,DAVE)
  USE MY_MOD, ONLY : AKIND, RTYPE
  IMPLICIT NONE
  INTEGER,          INTENT(IN) :: N
  tyPe(RTYPE)                 :: RICK ! CAN'T MAKE ARGUMENT YET!
  REAL(KIND=AKIND), INTENT(INOUT) :: DAVE(N)
  REAL(KIND=AKIND) :: TMP
  INTEGER :: I

  DO I=1, N
    IF (  DAVE(I)    > 0.5) THEN
      ! Who is DAVE = ...
      TMP = RICK%A
      DAvE( I)   = RICK%A

         ! BUT ALSO...
         RICK%B = DaVe(  i)
       ELSE
          ! ... AND ...
            DaVE( I ) = 66.6
    END IF

      ! BECAUSE DAVE WILL ...
      CALL  NEVER_GONNA ( DAVE%YOU_UP   )
  END DO
END SUBROUTINE   MY_TEST_ROUTINE
"""
    with config_override({'frontend-store-source': True}):
        routine = Subroutine.from_source(fcode, frontend=FP)

    # Check the untouched output of a few nodes
    s_routine = fgen(routine, conservative=True)
    assert fcode.strip() == s_routine.strip()

    # The slices below index into ``fcode.splitlines()``, whose element 0 is
    # the empty string produced by the literal's leading newline
    str_spec = fgen(routine.spec, conservative=True)
    exp_spec = '\n'.join(fcode.splitlines()[2:9])
    assert exp_spec.strip() == str_spec.strip()

    str_body = fgen(routine.body, conservative=True)
    exp_body = '\n'.join(fcode.splitlines()[9:26])
    assert exp_body.strip() == str_body.strip()

    # The loop is compared without stripping, so its indentation must match too
    str_loop = fgen(FindNodes(ir.Loop).visit(routine.body), conservative=True)
    exp_loop = '\n'.join(fcode.splitlines()[10:26])
    assert exp_loop == str_loop

    # Use `SubstituteExpressions` to replace RICK with BOB, including in spec!
    rick = routine.variable_map['RICK']
    bob = routine.Variable(name='BOB', type=rick.type)  # This replicates type info!
    sub_rick = SubstituteExpressions({rick: bob}, invalidate_source=True)

    routine.spec = sub_rick.visit(routine.spec)
    routine.body = sub_rick.visit(routine.body)

    routine.source.status = SourceStatus.INVALID_CHILDREN
    assert routine.spec.source.status == SourceStatus.INVALID_CHILDREN
    assert routine.body.source.status == SourceStatus.INVALID_CHILDREN

    # Nodes touched by the substitution must no longer report valid source
    decls = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    assert 'bob' in decls[1].symbols and not decls[1].source.status == SourceStatus.VALID
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert 'bob%a' == assigns[0].rhs and not assigns[0].source.status == SourceStatus.VALID
    assert 'bob%b' == assigns[2].lhs and not assigns[2].source.status == SourceStatus.VALID

    # And now check the actual output formatting...
    expected = fcode.replace('RICK', 'BOB').strip()
    generated = fgen(routine, conservative=True).strip()
    assert generated == expected


def test_fgen_conservative_module():
    """
    Check string-identical re-generation of two modules with the
    conservative backend, then rename ``AKIND`` to ``BKIND`` in both (the
    kind parameter declaration in ``MY_MOD`` and the import plus dependent
    declarations in ``MY_TEST_MOD``) and verify that the output differs
    only by that rename.
    """
    fcode_type = """
MODULE MY_MOD
  INTEGER, PARAMETER :: AKIND = 8
  TYPE RTYPE
    real(kind = 8) ::   DaVe
  END TYPE RTYPE
END MODULE   MY_MOD
"""

    fcode = """
MODULE MY_TEST_MOD
  USE MY_MOD, ONLY: AKIND, RTYPE
  IMPLICIT NONE

  REAL(KIND=AKIND) ::    DaVE(  5 )

  CONTAINS

  SUBROUTINE A_SHORT_ROUTINE( N,DAVE)
    INTEGER,          INTENT(IN) :: N
    REAL(KIND=AKIND), INTENT(INOUT) :: DAVE(N)

    DaVE(:) = DaVE(:) + 2.0
  END SUBROUTINE   A_SHORT_ROUTINE
END MODULE MY_TEST_MOD
"""
    with config_override({'frontend-store-source': True}):
        type_mod = Module.from_source(fcode_type, frontend=FP)
        module = Module.from_source(fcode, frontend=FP)

    routine = module['a_short_routine']

    # Check modules can be re-created string-identically
    s_type_mod = fgen(type_mod, conservative=True)
    assert fcode_type.strip() == s_type_mod.strip()
    s_module = fgen(module, conservative=True)
    assert fcode.strip() == s_module.strip()

    # Type Module: Use `SubstituteExpressions` to replace AKIND with BKIND
    akind = FindNodes(ir.VariableDeclaration).visit(type_mod.spec)[0].symbols[0]
    assert akind == 'AKIND'
    sub_expr = SubstituteExpressions(
        {akind: akind.clone(name='BKIND')}, invalidate_source=True
    )
    type_mod.spec = sub_expr.visit(type_mod.spec)
    # Mark the module itself so the backend rebuilds it from its children
    type_mod.source.status = SourceStatus.INVALID_CHILDREN
    assert type_mod.spec.source.status == SourceStatus.INVALID_CHILDREN

    # Type Module: Check that substitutions have invalidated relevant nodes
    decls = FindNodes(ir.VariableDeclaration).visit(type_mod.spec)
    assert 'bkind' in decls[0].symbols and not decls[0].source.status == SourceStatus.VALID

    # Type Module: Check the actual output formatting of type module
    type_mod_expected = fcode_type.replace('AKIND', 'BKIND').strip()
    type_mod_generated = fgen(type_mod, conservative=True).strip()
    assert type_mod_generated == type_mod_expected

    # Main Module: Use `SubstituteExpressions` to replace AKIND with BKIND in main module
    # (here the symbol is taken from the USE import, not from a declaration)
    akind = FindNodes(ir.Import).visit(module.spec)[0].symbols[0]
    assert akind == 'AKIND'
    sub_expr = SubstituteExpressions(
        {akind: akind.clone(name='BKIND')}, invalidate_source=True
    )
    module.spec = sub_expr.visit(module.spec)
    module.source.status = SourceStatus.INVALID_CHILDREN
    assert module.spec.source.status == SourceStatus.INVALID_CHILDREN
    # TODO: Change routine directly, as Transformer does not recurse into program unit yet
    routine.spec = sub_expr.visit(module['a_short_routine'].spec)
    routine.source.status = SourceStatus.INVALID_CHILDREN
    assert routine.spec.source.status == SourceStatus.INVALID_CHILDREN

    # The contained routine changed, so its enclosing `contains` section must be rebuilt
    module.contains.source.status = SourceStatus.INVALID_CHILDREN

    # Main Module: Check that substitutions have invalidated relevant nodes
    decls_m = FindNodes(ir.VariableDeclaration).visit(module.spec)
    assert len(decls_m) == 1
    assert decls_m[0].symbols[0] == 'dave(5)'
    assert decls_m[0].symbols[0].type.kind == 'bkind'
    assert decls_m[0].source.status == SourceStatus.INVALID_NODE

    # Main Module: Check the actual output formatting
    module_expected = fcode.replace('AKIND', 'BKIND').strip()
    module_generated = fgen(module, conservative=True).strip()
    assert module_generated == module_expected


@pytest.mark.parametrize('frontend', available_frontends(
    include_regex=True, skip=[(OMNI, 'OMNI is not string-conservative')]
))
def test_fgen_conservative_sourcefile(frontend):
    """
    Test outer program unit conservation via `ir.Section` and REGEX frontend.

    Two routines are parsed from one source file. After each routine's name is
    changed and only the routine node is invalidated, conservative output must
    regenerate just the header and footer, keeping spec and body verbatim.
    """

    fcode = """
subroutine some_routine
implicit none
end subroutine some_routine

subroutine OTHER_ROUTINE
implicit none
call some_routine
end subroutine OTHER_ROUTINE
"""
    with config_override({'frontend-store-source': True}):
        sourcefile = Sourcefile.from_source(fcode, frontend=frontend)

    assert sourcefile.source
    assert sourcefile.ir.source

    routines = sourcefile.routines
    assert len(routines) == 2
    assert routines[0].source
    if frontend == REGEX:
        # The REGEX frontend only produces incomplete routines; complete them
        # so that spec and body (and their source) are populated
        sourcefile.routines[0].make_complete()
        sourcefile.routines[1].make_complete()
    assert routines[0].spec.source
    assert routines[0].body.source
    assert routines[1].source
    assert routines[1].spec.source
    assert routines[1].body.source

    # Modify the subroutine objects only
    routines[0].name = routines[0].name.upper()
    routines[0].source.status = SourceStatus.INVALID_NODE

    routines[1].name = routines[1].name.lower()
    routines[1].source.status = SourceStatus.INVALID_NODE

    # Ensure only header/footer are changed
    assert routines[0].to_fortran(conservative=True) == """
SUBROUTINE SOME_ROUTINE ()
implicit none

END SUBROUTINE SOME_ROUTINE
""".strip()

    assert routines[1].to_fortran(conservative=True) == """
SUBROUTINE other_routine ()
implicit none
call some_routine
END SUBROUTINE other_routine
""".strip()


def test_fgen_conservative_rebuild():
    """
    Test that triggers a near complete re-build.

    Renames kind parameters, the module, the routine and nearly every
    variable, so that almost all nodes are invalidated. The conservative
    backend must then produce exactly the same text as the
    non-conservative one.
    """

    fcode = """
MODULE MY_TEST_MOD
  use type_mod, only: akind, ikind, rtype
  ! use func_mod, only: my_func
  implicit none

  REAL(KIND=AKIND) ::    DaVE(  5 )

  CONTAINS

  SUBROUTINE A_SHORT_ROUTINE( N, DAVE)
    INTEGER,          INTENT(IN) :: N
    REAL(KIND=AKIND), INTENT(INOUT) :: DAVE(N)
    integer( kind =ikind) :: i

    DaVE( : ) = DaVE(:) + 2.0
    do    i=1, n
      if   (  DaVe(i) ==    0.0)  then
        dave (i) = 3.0
      end  if
    enddo

    CALL My_Func(n, daVE(1:n))
  END SUBROUTINE   A_SHORT_ROUTINE
END MODULE MY_TEST_MOD
"""
    with config_override({'frontend-store-source': True}):
        module = Module.from_source(fcode, frontend=FP)
    routine = module['a_short_routine']

    # Change nearly every line to trigger full re-build
    # Module level: rename the imported kinds and the module variable `dave`
    smap = module.imported_symbol_map
    vmap = module.variable_map
    subs_module = SubstituteExpressions(
        {
            smap['ikind']: smap['ikind'].clone(name='ikinder'),
            smap['akind']: smap['akind'].clone(name='akinder'),
            vmap['dave'].symbol: vmap['dave'].symbol.clone(name='rick')
        }, invalidate_source=True
    )
    module.spec = subs_module.visit(module.spec)
    module.contains = subs_module.visit(module.contains)
    assert module.spec.source.status == SourceStatus.INVALID_CHILDREN
    assert module.contains.source.status == SourceStatus.INVALID_CHILDREN
    module.name = 'A_NEW_MOD'
    module.source.status = SourceStatus.INVALID_NODE

    # Routine level: rename the argument array, its size and the loop index
    vmap = routine.variable_map
    subs_routine = SubstituteExpressions(
        {
            vmap['dave'].symbol: vmap['dave'].symbol.clone(name='rick'),
            vmap['n']: vmap['n'].clone(name='m'),
            vmap['i']: vmap['i'].clone(name='j'),
        }, invalidate_source=True
    )
    routine.spec = subs_routine.visit(routine.spec)
    routine.body = subs_routine.visit(routine.body)
    assert routine.spec.source.status == SourceStatus.INVALID_CHILDREN
    assert routine.body.source.status == SourceStatus.INVALID_CHILDREN
    routine.name = 'A_CHANGED_ROUTINE'
    routine.source.status = SourceStatus.INVALID_NODE

    # Check that changes have indeed invalidated nodes
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 2
    assert assigns[0].lhs == 'rick(:)' and assigns[0].rhs == 'rick(:) + 2.0'
    assert assigns[0].source.status == SourceStatus.INVALID_NODE
    assert assigns[1].lhs == 'rick(j)' and assigns[1].rhs == '3.0'
    assert assigns[1].source.status == SourceStatus.INVALID_NODE

    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1
    assert loops[0].variable == 'j' and loops[0].bounds == '1:m'
    assert loops[0].source.status == SourceStatus.INVALID_NODE

    conds = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conds) == 1
    assert conds[0].condition == 'rick(j) == 0.0'
    assert conds[0].source.status == SourceStatus.INVALID_NODE

    # Check that fully generated and conservative agree
    routine_expected = fgen(module, conservative=False).strip()
    routine_generated = fgen(module, conservative=True).strip()
    assert routine_generated == routine_expected
loki-ecmwf-0.3.6/loki/backend/tests/test_cufgen.py0000664000175000017500000000751415167130205022317 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
from pydantic import ValidationError

from loki import Module
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.expression import symbols as sym
from loki.frontend import available_frontends


@pytest.mark.parametrize('frontend', available_frontends())
def test_cufgen(frontend, tmp_path):
    """
    A simple test routine to test the Cuda Fortran (CUF) backend.

    Checks that:
    - without CUDA annotations the CUF output equals plain Fortran output;
    - CUDA variable attributes (DEVICE, MANAGED, CONSTANT, SHARED, PINNED,
      TEXTURE) are emitted;
    - malformed kernel-launch chevrons are rejected;
    - a valid chevron and procedure prefixes appear in the generated code.
    """

    fcode = """
module transformation_module_cufgen
  implicit none
  integer, parameter :: len = 10
contains

  subroutine driver(a, b, c)
    integer, intent(inout) :: a
    integer, intent(inout) :: b(len)
    integer, intent(inout) :: c(a, len)
    integer :: var_device
    integer :: var_managed
    integer :: var_constant
    integer :: var_shared
    integer :: var_pinned
    integer :: var_texture
    call kernel(a, b)
  end subroutine driver

  subroutine kernel(a, b)
    integer, intent(inout) :: a
    integer, intent(inout) :: b(len)
    real :: x(a)
    real :: k2_tmp(a, a)
    call device1(x, k2_tmp)
  end subroutine kernel

  subroutine device(x, y)
    real, intent(inout) :: x(len)
    real, intent(inout) :: y(len, len)
  end subroutine device

end module transformation_module_cufgen
"""

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    driver = module['driver']
    kernel = module['kernel']
    device_subroutine = module['device']

    assert driver
    # With no CUDA-specific annotations yet, CUF output must equal plain Fortran
    assert module.to_fortran(cuf=True) == module.to_fortran()

    # Flag each driver variable with the CUDA attribute named in its identifier
    for var in driver.variables:
        if "device" in var.name:
            var.type = var.type.clone(device=True)
        if "managed" in var.name:
            var.type = var.type.clone(managed=True)
        if "constant" in var.name:
            var.type = var.type.clone(constant=True)
        if "shared" in var.name:
            var.type = var.type.clone(shared=True)
        if "pinned" in var.name:
            var.type = var.type.clone(pinned=True)
        if "texture" in var.name:
            var.type = var.type.clone(texture=True)

    # Malformed chevrons (five elements, a tuple of plain ints, a bare int) must
    # fail IR node validation; a four-element chevron of IntLiterals is attached
    call_map = {}
    for call in FindNodes(ir.CallStatement).visit(driver.body):
        if "kernel" in str(call.name):
            with pytest.raises(ValidationError):
                _ = call.clone(chevron=(sym.IntLiteral(1), sym.IntLiteral(1), sym.IntLiteral(1), sym.IntLiteral(1),
                                        sym.IntLiteral(1)))
            with pytest.raises(ValidationError):
                _ = call.clone(chevron=(1, 1))
            with pytest.raises(ValidationError):
                _ = call.clone(chevron=2)

            call_map[call] = call.clone(chevron=(sym.IntLiteral(1), sym.IntLiteral(1),
                                                 sym.IntLiteral(1), sym.IntLiteral(1)))

    driver.body = Transformer(call_map).visit(driver.body)

    # Mark kernel and device routines with CUDA procedure attributes via their prefix
    kernel.prefix = ("ATTRIBUTES(GLOBAL)",)
    device_subroutine.prefix = ("ATTRIBUTES(DEVICE)",)

    cuf_driver_str = driver.to_fortran(cuf=True)
    cuf_kernel_str = kernel.to_fortran(cuf=True)
    cuf_device_str = device_subroutine.to_fortran(cuf=True)

    # Every variable attribute must appear in the generated declarations
    assert "INTEGER, DEVICE" in cuf_driver_str
    assert "INTEGER, MANAGED" in cuf_driver_str
    assert "INTEGER, CONSTANT" in cuf_driver_str
    assert "INTEGER, SHARED" in cuf_driver_str
    assert "INTEGER, PINNED" in cuf_driver_str
    assert "INTEGER, TEXTURE" in cuf_driver_str

    # The kernel call must carry the launch chevron
    assert "<<<" in cuf_driver_str and ">>>" in cuf_driver_str

    assert "ATTRIBUTES(GLOBAL) SUBROUTINE kernel" in cuf_kernel_str
    assert "ATTRIBUTES(DEVICE) SUBROUTINE device" in cuf_device_str
loki-ecmwf-0.3.6/loki/backend/dacegen.py0000664000175000017500000001025515167130205020231 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.backend.pygen import PyCodegen
from loki.backend.style import DefaultStyle

from loki.expression import symbols as sym, ExpressionRetriever
from loki.ir import is_loki_pragma
from loki.tools import OrderedSet
from loki.types import BasicType

__all__ = ['dacegen', 'DaceCodegen']


def dace_type(_type):
    """
    Return the name of the Dace data type matching a Loki type.

    Reals map to ``dace.float32`` when their kind is ``real32`` and to
    ``dace.float64`` otherwise; integers map to ``dace.int32`` and logicals
    to ``dace.bool``. Any other dtype raises :class:`ValueError`.
    """
    dtype = _type.dtype
    if dtype == BasicType.REAL:
        single_precision = str(_type.kind) in ('real32',)
        return 'dace.float32' if single_precision else 'dace.float64'
    if dtype == BasicType.INTEGER:
        return 'dace.int32'
    if dtype == BasicType.LOGICAL:
        return 'dace.bool'
    raise ValueError(str(_type))


class DaceCodegen(PyCodegen):
    """
    Tree visitor that extends `PyCodegen` with Dace-specific language variations.

    Subroutines are emitted as ``@dace.program`` functions whose array-shape
    scalars become module-level ``dace.symbol`` objects. Loops tagged with a
    ``dataflow`` Loki pragma are emitted as ``dace.map`` iterations.
    """

    # Handler for outer objects

    def visit_Subroutine(self, o, **kwargs):
        """
        Format as:
            import dace
            import numpy as np
            <sym> = dace.symbol("<sym>")      (one line per shape symbol)
            @dace.program
            def <name>(<arg>: <dace type>, ...):
                ...spec...
                ...body...
        """
        lines = [self.format_line('import ', module) for module in ('dace', 'numpy as np')]

        # TODO: imports from the spec

        # Scalars used in array-argument shapes become Dace symbols, so they
        # are removed from the function signature
        scalar_finder = ExpressionRetriever(lambda e: isinstance(e, sym.Scalar))
        shape_symbols = OrderedSet()
        for argument in o.arguments:
            if not isinstance(argument, sym.Array):
                continue
            found = scalar_finder.retrieve(argument.shape)
            shape_symbols |= OrderedSet(var.name.lower() for var in found)

        signature = []
        for argument in o.arguments:
            arg_name = argument.name.lower()
            if arg_name in shape_symbols:
                continue
            signature.append(f'{arg_name}: {self.visit(argument.type, **kwargs)}')

        for symbol_name in shape_symbols:
            lines.append(self.format_line(f'{symbol_name} = dace.symbol("{symbol_name}")'))
        lines.append(self.format_line('@dace.program'))
        lines.append(self.format_line('def ', o.name.lower(), '(', self.join_items(signature), '):'))

        # Spec and body live inside the function definition
        self.depth += 1
        spec = self.visit(o.spec, **kwargs)
        body = self.visit(o.body, **kwargs)
        self.depth -= 1

        return self.join_lines(*lines, spec, body)

    def visit_Module(self, o, **kwargs):
        raise NotImplementedError()

    # Handler for IR nodes

    def visit_Loop(self, o, **kwargs):
        """
        Format a loop tagged with a ``dataflow`` Loki pragma as
          for <var> in dace.map[<start>:<stop>+<step>:<step>]:
            ...body...
        or, without an explicit step, ``dace.map[<start>:<stop>+1]``.
        All other loops are formatted by :any:`PyCodegen`.
        """
        if not is_loki_pragma(o.pragma, starts_with='dataflow'):
            return super().visit_Loop( o, **kwargs)

        loop_var = self.visit(o.variable, **kwargs)
        lower = self.visit(o.bounds.start, **kwargs)
        upper = self.visit(o.bounds.stop, **kwargs)
        if not o.bounds.step:
            iteration = f'dace.map[{lower}:{upper}+1]'
        else:
            step = self.visit(o.bounds.step, **kwargs)
            iteration = f'dace.map[{lower}:{upper}+{step}:{step}]'

        header = self.format_line('for ', loop_var, ' in ', iteration, ':')
        self.depth += 1
        body = self.visit(o.body, **kwargs)
        self.depth -= 1
        return self.join_lines(header, body)

    def visit_SymbolAttributes(self, o, **kwargs):
        """
        Format a type as ``<dace dtype>`` or, for arrays,
        ``<dace dtype>[<dim>, ...]`` (empty dimensions are skipped).
        """
        dtype = dace_type(o)
        if o.shape is None:
            return dtype
        dims = [self.visit(dim, **kwargs) for dim in o.shape]
        non_empty = [d for d in dims if d]
        return dtype + '[' + ', '.join(non_empty) + ']'



def dacegen(ir):
    """
    Render one or many IR objects/trees as standard Python 3 code with
    Dace specializations (and Numpy), using a line width of 300 characters.
    """
    style = DefaultStyle(linewidth=300)
    codegen = DaceCodegen(style=style)
    return codegen.visit(ir)
loki-ecmwf-0.3.6/loki/backend/fgencon.py0000664000175000017500000002440415167130205020263 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.backend.fgen import FortranCodegen
from loki.frontend.source import SourceStatus
from loki.tools.util import as_tuple


__all__ = ['FortranCodegenConservative']


class FortranCodegenConservative(FortranCodegen):
    """
    Strictly conservative version of :any:`FortranCodegen` visitor
    that will attempt to use existing :any:`Source` information from
    the frontends where possible.

    A node whose :any:`Source` status is ``VALID`` is emitted verbatim from
    its stored source string. For some node types, an ``INVALID_NODE`` or
    ``INVALID_CHILDREN`` status triggers partial re-generation: unchanged
    fragments of the stored source (indentation, header and footer lines) are
    re-used. All other cases fall back to :any:`FortranCodegen`.
    """

    def visit_Node(self, o, *args, **kwargs):
        """ Emit any node verbatim while its stored source is still valid. """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string
        return super().visit_Node(o, *args, **kwargs)

    def visit_Assignment(self, o, *args, **kwargs):
        """
        Emit an assignment, re-using the unchanged side(s) of the original
        source line if only the node itself has been invalidated.

        Pointer assignments are split at ``=>`` rather than ``=``, so the
        ``>`` is not mistaken for part of the right-hand side.
        """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string

        if o.source and o.source.status == SourceStatus.INVALID_NODE:
            # Attempt to recreate some structure by locating the assignment operator
            operator = '=>' if o.ptr else '='
            slist = o.source.string.split(operator, maxsplit=1)
            assert len(slist) == 2

            # If LHS hasn't changed, prefer source
            lhs = self.visit(o.lhs, **kwargs) + ' '
            if slist[0] == o.lhs:
                lhs = slist[0]
            else:
                # At least prescribe previous indent...
                ind = len(o.source.string) - len(o.source.string.lstrip())
                lhs = ind*' ' + lhs

            # If RHS hasn't changed, prefer source
            rhs = ' ' + self.visit(o.rhs, **kwargs)
            if slist[1] == o.rhs:
                rhs = slist[1]

            comment = str(self.visit(o.comment, **kwargs)) if o.comment else ''

            # In both branches above the LHS already carries the original
            # indentation, so no extra indent may be added here
            return self.format_line(lhs, operator, rhs, comment=comment, no_indent=True)

        return super().visit_Assignment(o, *args, **kwargs)

    def visit_CallStatement(self, o, *args, **kwargs):
        """ Emit a call verbatim while its stored source is still valid. """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string
        return super().visit_CallStatement(o, *args, **kwargs)

    def visit_Comment(self, o, *args, **kwargs):
        """
        Emit a comment from source while valid. If code precedes the ``!``
        on the stored line (an inline comment), return only the whitespace
        directly before the ``!`` and the comment text.
        """
        if o.source and o.source.status == SourceStatus.VALID:
            slist = o.source.string.split('!', maxsplit=1)
            pre, txt = (slist[0], slist[1]) if len(slist) > 1 else (None, slist[0])
            if pre and not pre.isspace():
                # An inline comment, only return text and leading whitespace
                ws = ' '*(len(pre)-len(pre.rstrip(' ')))
                return f'{ws}!{txt}'
            return o.source.string
        return super().visit_Comment(o, *args, **kwargs)

    def visit_Conditional(self, o, *args, **kwargs):
        """
        Emit a conditional. If only its children changed, the ``IF`` header,
        ``ELSE`` line and footer line are re-used from the stored source.
        """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string

        if o.source and o.source.status == SourceStatus.INVALID_CHILDREN:
            if o.inline:
                # TODO: Deal with inline conditionals properly
                return super().visit_Conditional(o, *args, **kwargs)

            header = o.source.string.splitlines()[0]

            self.depth += self.style.conditional_indent
            body = self.visit(o.body, **kwargs)
            if o.has_elseif:
                self.depth -= self.style.conditional_indent
                else_body = [self.visit(o.else_body, is_elseif=True, **kwargs)]
            else:
                else_body = [self.visit(o.else_body, **kwargs)]
                self.depth -= self.style.conditional_indent
                if o.else_body:
                    # Get the `ELSE` from source to get its indentation; if there is
                    # no bare `ELSE` line (e.g. it carries an inline comment), fall
                    # back to a freshly generated one instead of failing
                    elseline = [
                        s for s in o.source.string.splitlines()
                        if s.upper().strip() == 'ELSE'
                    ]
                    else_line = elseline[-1] if elseline else self.format_line('ELSE')
                    else_body = [else_line] + else_body

                # Recapture the footer line from source
                footer = o.source.string.splitlines()[o.source.lines[1]-o.source.lines[0]]
                else_body += [footer]

            return self.join_lines(header, body, *else_body)

        return super().visit_Conditional(o, *args, **kwargs)

    def visit_VariableDeclaration(self, o, *args, **kwargs):
        """
        Emit a declaration. If only the node changed, the attribute part
        (before ``::``) and/or the variable list (after ``::``) are re-used
        from source when they still match the node.
        """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string

        if o.source and o.source.status == SourceStatus.INVALID_NODE:
            no_indent = False

            # Attempt to recreate some structure by locating `::`
            slist = o.source.string.split('::', maxsplit=1)
            assert len(slist) == 2

            # If type attributes haven't changed, prefer source (and don't indent)
            attributes = str(self.join_items(self._construct_type_attributes(o, **kwargs))) + ' '
            # TODO: Type attributes don't string compare cleanly yet; needs fixing!
            if slist[0].strip().lower() == attributes.strip().lower():
                attributes = slist[0]
                no_indent = True

            # If declared variables haven't changed, prefer source
            variables = ' ' + str(self.join_items(self._construct_decl_variables(o, **kwargs)))
            if o.symbols == as_tuple(slist[1].split(',')):
                variables = slist[1]

            comment = str(self.visit(o.comment, **kwargs)) if o.comment else ''

            return self.format_line(attributes, '::', variables, comment, no_indent=no_indent)

        return super().visit_VariableDeclaration(o, *args, **kwargs)

    def visit_Import(self, o, *args, **kwargs):
        """ Emit an import verbatim while its stored source is still valid. """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string
        return super().visit_Import(o, *args, **kwargs)

    def visit_Loop(self, o, *args, **kwargs):
        """
        Emit a loop. If only its children changed, the header and footer
        lines are re-used from the stored source.
        """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string

        if o.source and o.source.status == SourceStatus.INVALID_CHILDREN:
            # Recapture header and footer from source
            header = o.source.string.splitlines()[0]
            footer = o.source.string.splitlines()[o.source.lines[1]-o.source.lines[0]]

            pragma = self.visit(o.pragma, **kwargs)
            pragma_post = self.visit(o.pragma_post, **kwargs)
            self.depth += self.style.loop_indent
            body = self.visit(o.body, **kwargs)
            self.depth -= self.style.loop_indent
            return self.join_lines(pragma, header, body, footer, pragma_post)

        return super().visit_Loop(o, *args, **kwargs)

    def visit_Section(self, o, *args, **kwargs):
        """ Emit a section verbatim while its stored source is still valid. """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string
        return super().visit_Section(o, *args, **kwargs)

    def visit_Subroutine(self, o, *args, **kwargs):
        """
        Emit a subroutine. If only its children changed, the header (all
        source lines before the first docstring/spec/body line) and a
        one-line ``END ...`` footer are re-used from the stored source.
        """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string

        if o.source and o.source.status == SourceStatus.INVALID_CHILDREN:
            # Re-construct header and footer from source if possible
            h_end = o.body.source.lines[0] if o.body.source else o.source.lines[1]
            h_end = min(h_end, o.spec.source.lines[0]) if o.spec.source else h_end
            # The docstring may have been created without source information
            if o.docstring and o.docstring[0].source:
                h_end = min(h_end, o.docstring[0].source.lines[0])

            if h_end < o.source.lines[1]:
                header = '\n'.join(o.source.string.splitlines()[:h_end-o.source.lines[0]])
            else:
                header = self._construct_subroutine_header(o, **kwargs)

            # For one-line footers reconstruct from source
            foot = o.source.string.splitlines()[o.source.lines[1]-o.source.lines[0]]
            if 'END ' in foot.upper():
                footer = foot
            else:
                footer = self._construct_procedure_footer(o, **kwargs)

            self.depth += self.style.procedure_spec_indent
            docstring = self.visit(o.docstring, **kwargs)
            spec = self.visit(o.spec, **kwargs)
            self.depth -= self.style.procedure_spec_indent

            self.depth += self.style.procedure_body_indent
            body = self.visit(o.body, **kwargs)
            self.depth -= self.style.procedure_body_indent

            self.depth += self.style.procedure_contains_indent
            contains = self.visit(o.contains, **kwargs)
            self.depth -= self.style.procedure_contains_indent
            if contains:
                return self.join_lines(header, docstring, spec, body, contains, footer)

            return self.join_lines(header, docstring, spec, body, footer)

        return super().visit_Subroutine(o, *args, **kwargs)

    def visit_Module(self, o, *args, **kwargs):
        """
        Emit a module. If only its children changed, the header (all source
        lines before the first spec/contains line) and a one-line ``END ...``
        footer are re-used from the stored source.
        """
        if o.source and o.source.status == SourceStatus.VALID:
            return o.source.string

        if o.source and o.source.status == SourceStatus.INVALID_CHILDREN:
            # Re-construct header and footer from source if possible
            h_end = o.contains.source.lines[0] if o.contains and o.contains.source else o.source.lines[1]
            h_end = min(h_end, o.spec.source.lines[0]) if o.spec.source else h_end

            if h_end < o.source.lines[1]:
                header = '\n'.join(o.source.string.splitlines()[:h_end-o.source.lines[0]])
            else:
                header = self._construct_module_header(o, **kwargs)

            # For one-line footers reconstruct from source
            foot = o.source.string.splitlines()[o.source.lines[1]-o.source.lines[0]]
            if 'END ' in foot.upper():
                footer = foot
            else:
                footer = self._construct_module_footer(o, **kwargs)

            self.depth += self.style.module_spec_indent
            spec = self.visit(o.spec, **kwargs)
            self.depth -= self.style.module_spec_indent

            self.depth += self.style.module_contains_indent
            contains = self.visit(o.contains, **kwargs)
            self.depth -= self.style.module_contains_indent
            if contains:
                return self.join_lines(header, spec, contains, footer)

            return self.join_lines(header, spec, footer)

        return super().visit_Module(o, *args, **kwargs)
loki-ecmwf-0.3.6/loki/backend/cufgen.py0000664000175000017500000000620215167130205020107 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.backend.fgen import FortranCodegen
from loki.backend.style import FortranStyle

__all__ = ['cufgen', 'CudaFortranCodegen']


class CudaFortranCodegen(FortranCodegen):
    """
    Tree visitor that extends :any:`FortranCodegen` with Cuda Fortran (CUF) language variations.

    It adds kernel-launch chevrons to call statements and CUDA variable
    attributes (DEVICE, MANAGED, ...) to declarations.
    """

    def visit_CallStatement(self, o, **kwargs):
        """
        Format call statement as
          CALL <name><<<chevron>>>(<args>)
        where the optional chevron holds the launch configuration for device
        offloading, resulting in something like
          CALL kernel<<<grid,block>>>(arg1,arg2,...)
        """
        pragma = self.visit(o.pragma, **kwargs)
        name = self.visit(o.name, **kwargs)
        args = self.visit_all(o.arguments, **kwargs)

        if o.chevron is None:
            chevron = ""
        else:
            launch_config = ','.join(str(elem) for elem in o.chevron)
            chevron = "<<<" + launch_config + ">>>"

        if o.kwarguments:
            keyword_args = tuple(
                f'{self.visit(pair[0], **kwargs)}={self.visit(pair[1], **kwargs)}'
                for pair in o.kwarguments
            )
            args += keyword_args

        statement = self.format_line('CALL ', name, chevron, '(', self.join_items(args), ')')
        return self.join_lines(pragma, statement)

    def visit_SymbolAttributes(self, o, **kwargs):
        """
        Format declaration attributes as
          <standard Fortran attributes>[, <CUDA attribute>, ...]
        where CUDA attributes are appended in the order DEVICE, MANAGED,
        CONSTANT, SHARED, PINNED, TEXTURE for each flag set on the type.
        """
        cuda_attributes = (
            ("device", "DEVICE"),
            ("managed", "MANAGED"),
            ("constant", "CONSTANT"),
            ("shared", "SHARED"),
            ("pinned", "PINNED"),
            ("texture", "TEXTURE"),
        )
        items = [super().visit_SymbolAttributes(o, **kwargs)]
        for flag, keyword in cuda_attributes:
            if getattr(o, flag):
                items.append(keyword)
        return self.join_items(items)


def cufgen(ir, style=None, depth=0):
    """
    Generate CUDA Fortran code from one or many IR objects/trees.

    This uses :class:`CudaFortranCodegen`, which extends :class:`FortranCodegen`
    with CUDA Fortran specific syntax. Refer to the CUDA_FORTRAN_PROGRAMMING_GUIDE_
    for more information on the language.

    Supported subset of the CUDA Fortran specifications:

    * variable qualifiers e.g. ``attributes(device)``
    * chevron syntax to launch kernels e.g. ``call kernel<<<grid,block>>>(arg1,arg2,...)``

    Natively supported (via :class:`FortranCodegen`):

    * subroutine/function qualifiers e.g. ``attributes(global)`` via :py:attr:`loki.Subroutine.prefix`
    * kernel loop directives via :class:`loki.ir.Pragma`

    Parameters
    ----------
    ir : :any:`Node` or list of nodes
        The IR to convert.
    style : :any:`FortranStyle`, optional
        Output style; a default :any:`FortranStyle` is used when not given.
    depth : int, optional
        Initial indentation depth.

    .. _CUDA_FORTRAN_PROGRAMMING_GUIDE: https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/index.html
    """
    if not style:
        style = FortranStyle()
    generator = CudaFortranCodegen(style=style, depth=depth)
    return generator.visit(ir)
loki-ecmwf-0.3.6/loki/backend/pygen.py0000664000175000017500000002755615167130205020001 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pymbolic.mapper.stringifier import PREC_NONE, PREC_CALL

from loki.backend.pprint import Stringifier
from loki.backend.style import DefaultStyle

from loki.expression import symbols as sym, LokiStringifyMapper
from loki.types import BasicType, DerivedType, SymbolAttributes


__all__ = ['pygen', 'PyCodegen', 'PyCodeMapper']


def numpy_type(_type):
    """
    Map Loki type attributes to the name of the Python/numpy type used in
    generated Python code.

    Parameters
    ----------
    _type : :any:`SymbolAttributes`
        The type attributes of a symbol.

    Returns
    -------
    str
        ``'np.ndarray'`` for anything with a shape. Otherwise one of:

        * ``'bool'`` for logicals;
        * ``'np.int64'`` for integers of kind ``int64``, else ``'np.int32'``;
        * ``'np.float32'`` for reals of kind ``real32``, else ``'np.float64'``;
        * the type name for derived types.

        Kind names are matched case-insensitively.

    Raises
    ------
    ValueError
        If the type is not one of the cases above.
    """
    if _type.shape is not None:
        return 'np.ndarray'
    if _type.dtype == BasicType.LOGICAL:
        return 'bool'
    if _type.dtype == BasicType.INTEGER:
        # Kind names are Fortran identifiers and therefore case-insensitive
        # (``INT64`` and ``int64`` denote the same constant)
        if str(_type.kind).lower() in ('int64',):
            return 'np.int64'
        return 'np.int32'
    if _type.dtype == BasicType.REAL:
        # Same case-insensitive match as for integer kinds above
        if str(_type.kind).lower() in ('real32',):
            return 'np.float32'
        return 'np.float64'
    if isinstance(_type.dtype, DerivedType):
        return _type.dtype.name
    raise ValueError(str(_type))


class PyCodeMapper(LokiStringifyMapper):
    """
    Generate Python representation of expression trees using numpy syntax.

    Derived-type member access (``a%b``) is rendered as attribute access
    (``a.b``) and array subscripts use square brackets.
    """
    # pylint: disable=abstract-method, unused-argument

    def map_logic_literal(self, expr, enclosing_prec, *args, **kwargs):
        """Render logical literals as ``True``/``False``."""
        return 'True' if bool(expr.value) else 'False'

    def map_float_literal(self, expr, enclosing_prec, *args, **kwargs):
        """Render numeric literals by their value only (any kind is dropped)."""
        return str(expr.value)

    map_int_literal = map_float_literal

    def map_cast(self, expr, enclosing_prec, *args, **kwargs):
        """Render a type conversion as a call to the matching numpy/Python type."""
        _type = SymbolAttributes(BasicType.from_fortran_type(expr.name), kind=expr.kind)
        expression = self.parenthesize_if_needed(
            self.join_rec('', expr.parameters, PREC_NONE, *args, **kwargs),
            PREC_CALL, PREC_NONE)
        return self.parenthesize_if_needed(
            self.format('%s(%s)', numpy_type(_type), expression), enclosing_prec, PREC_CALL)

    def map_variable_symbol(self, expr, enclosing_prec, *args, **kwargs):
        """Render a variable name, turning ``%`` member access into ``.``."""
        return expr.name.replace('%', '.')

    def map_meta_symbol(self, expr, enclosing_prec, *args, **kwargs):
        """Delegate to the wrapped symbol."""
        return self.rec(expr._symbol, enclosing_prec, *args, **kwargs)

    map_scalar = map_meta_symbol
    map_array = map_meta_symbol

    def map_array_subscript(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a subscript as ``name[i, j, ...]``.

        Empty index renderings are dropped. If no index remains, only the name
        is emitted.
        """
        name_str = self.rec(expr.aggregate, PREC_NONE, *args, **kwargs)
        # Rendered indices are used verbatim. Passing them through
        # ``self.format`` without arguments would apply ``%``-formatting and
        # fail on any literal ``%`` they contain.
        dims = [self.rec(d, PREC_NONE, *args, **kwargs) for d in expr.index_tuple]
        dims = [d for d in dims if d]
        if not dims:
            index_str = ''
        else:
            index_str = f'[{", ".join(dims)}]'
        return self.format('%s%s', name_str, index_str)

    map_string_subscript = map_array_subscript

    def map_string_concat(self, expr, enclosing_prec, *args, **kwargs):
        """Render string concatenation with Python's ``+`` operator."""
        return ' + '.join(self.rec(c, enclosing_prec, *args, **kwargs) for c in expr.children)

    def map_inline_call(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a function call as ``f(a, b, k=v)``.

        Positional arguments come first, followed by ``name=value`` keyword
        arguments.
        """
        arg_strs = [self.rec(p, PREC_NONE, *args, **kwargs) for p in expr.parameters]

        if expr.kw_parameters:
            # Collected into the same list so that keyword-only calls do not
            # get a leading ", " separator
            arg_strs += [
                f'{self.rec(k, PREC_NONE, *args, **kwargs)}={self.rec(v, PREC_NONE, *args, **kwargs)}'
                for k, v in expr.kw_parameters.items()
            ]

        f = self.rec(expr.function, PREC_NONE, *args, **kwargs)
        # Returned directly rather than via ``self.format`` so that a literal
        # ``%`` (e.g. inside a string argument) does not break formatting
        return f'{str(f)}({", ".join(arg_strs)})'

    def map_deferred_type_symbol(self, expr, *args, **kwargs):
        """Render a symbol of deferred type, turning ``%`` into ``.``."""
        return str(expr.name).replace('%', '.')


class PyCodegen(Stringifier):
    """
    Tree visitor to generate standard Python code (with Numpy) from IR.

    Expressions are rendered by :class:`PyCodeMapper`. Wrapped lines continue
    on a new line indented by two extra spaces.
    """

    def __init__(self, style, depth=0):
        super().__init__(
            style=style, depth=depth, symgen=PyCodeMapper(), line_cont='\n{}  '.format
        )

    # Handler for outer objects

    def visit_Sourcefile(self, o, **kwargs):
        """
        Format as
          ...modules...
          ...subroutines...
        """
        return self.visit(o.ir, **kwargs)

    def visit_Module(self, o, **kwargs):
        """
        Modules are not supported by the Python backend.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()

    def visit_Subroutine(self, o, **kwargs):
        """
        Format as:
            import numpy as np
            def <name>(<args>):
                ...spec without imports and only declarations with initial values...
                ...body...
                return <scalar inout/out args>

        Scalar ``intent(out)`` arguments are omitted from the signature. They
        and scalar ``intent(inout)`` arguments are returned instead, in
        argument-list order.
        """
        # Some boilerplate imports...
        standard_imports = ['numpy as np']
        header = [self.format_line('import ', name) for name in standard_imports]

        # ...and imports from the spec
        # TODO

        # Generate header with argument signature
        # Note: we skip scalar out arguments and add a return statement for those below
        scalar_args = [a for a in o.arguments if isinstance(a, sym.Scalar)]
        inout_args = [a for a in scalar_args if a.type.intent and a.type.intent.lower() == 'inout']
        out_args = [a for a in scalar_args if a.type.intent and a.type.intent.lower() == 'out']
        # Build the returned-argument list once instead of re-concatenating it
        # for every membership test further down
        returned_args = inout_args + out_args
        arguments = [arg for arg in o.arguments if arg not in out_args]
        arg_str = []
        for arg in arguments:
            if isinstance(arg.type.dtype, DerivedType):
                # Derived-type arguments are emitted without a type annotation
                arg_str += [f'{arg.name}']
            else:
                dtype = self.visit(arg.type, **kwargs)
                arg_str += [f'{arg.name}: {dtype}']
        header += [self.format_line('def ', o.name, '(', self.join_items(arg_str), '):')]

        # ...and generate the spec without imports and only declarations for variables that
        # either are local arrays or are assigned an initial value
        self.depth += self.style.indent_default
        body = [self.visit(o.spec, **kwargs)]

        # Fill the body
        body += [self.visit(o.body, **kwargs)]

        # Add return statement for scalar out arguments and close everything off
        ret_args = [arg for arg in o.arguments if arg in returned_args]
        body += [self.format_line('return ', self.join_items(self.visit_all(ret_args, **kwargs)))]
        self.depth -= self.style.indent_default

        return self.join_lines(*header, *body)

    # Handler for IR nodes

    def visit_Intrinsic(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Format intrinsic nodes by emitting their text verbatim (left-stripped).
        """
        return self.format_line(str(o.text).lstrip())

    def visit_Comment(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Format comments.

        The first ``!`` is replaced by ``#``, and the line is never wrapped.
        """
        text = o.text or o.source.string
        text = str(text).lstrip().replace('!', '#', 1)
        return self.format_line(text, no_wrap=True)

    def visit_CommentBlock(self, o, **kwargs):
        """
        Format comment blocks as one line per contained comment.
        """
        comments = self.visit_all(o.comments, **kwargs)
        return self.join_lines(*comments)

    def visit_VariableDeclaration(self, o, **kwargs):
        """
        Format declaration as
          <name> = np.ndarray(order="F", shape=(<dims>,))    for local arrays
          <name> = <initial value>                           for initialised symbols

        Returns `None` (nothing emitted) for all other symbols, e.g. arguments
        or scalars without an initial value.
        """
        comment = self.visit(o.comment, **kwargs) if o.comment else None

        # Initialise local arrays via numpy
        local_arrays = [v for v in o.symbols if isinstance(v, sym.Array) and not v.type.intent]
        array_decls = tuple(
            self.format_line(
                v.name, ' = np.ndarray(order="F", shape=(',
                self.join_items(self.visit_all(v.dimensions, **kwargs)), ',))'
            ) for v in local_arrays
        )

        # Assign initial values, if given
        init_decls = tuple(
            self.format_line(v.name, ' = ', self.visit(v.initial, **kwargs))
            for v in o.symbols if hasattr(v, 'initial') and v.initial is not None
        )

        # Break out early to avoid needless newlines
        if not comment and not array_decls and not init_decls:
            return None

        return self.join_lines(comment, *array_decls, *init_decls)

    def visit_Import(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Skip imports
        """
        return None

    def visit_Loop(self, o, **kwargs):
        """
        Format loop with explicit range as
          for <var> in range(<start>, <end> + <incr>, <incr>):
            ...body...

        The increment is added to the upper bound because the Fortran loop
        bound is inclusive. Without an explicit step the increment is 1.
        """
        var = self.visit(o.variable, **kwargs)
        start = self.visit(o.bounds.start, **kwargs)
        end = self.visit(o.bounds.stop, **kwargs)
        if o.bounds.step:
            incr = self.visit(o.bounds.step, **kwargs)
            cntrl = f'range({start}, {end} + {incr}, {incr})'
        else:
            cntrl = f'range({start}, {end} + 1)'
        header = self.format_line('for ', var, ' in ', cntrl, ':')
        self.depth += self.style.indent_default
        body = self.visit(o.body, **kwargs)
        self.depth -= self.style.indent_default
        return self.join_lines(header, body)

    def visit_WhileLoop(self, o, **kwargs):
        """
        Format loop as:
          while <condition>:
            ...body...

        A missing condition becomes ``True`` (infinite loop).
        """
        if o.condition is not None:
            condition = self.visit(o.condition, **kwargs)
        else:
            condition = 'True'
        header = self.format_line('while ', condition, ':')
        self.depth += self.style.indent_default
        body = self.visit(o.body, **kwargs)
        self.depth -= self.style.indent_default
        return self.join_lines(header, body)

    def visit_Conditional(self, o, **kwargs):
        """
        Format conditional as
        if <condition>:
          ...body...
        [elif <condition>:]
          [...body...]
        [else:]
          [...body...]
        """
        is_elseif = kwargs.pop('is_elseif', False)
        keyword = 'elif' if is_elseif else 'if'
        header = self.format_line(keyword, ' ', self.visit(o.condition, **kwargs), ':')
        self.depth += self.style.indent_default
        body = self.visit(o.body, **kwargs)
        if o.has_elseif:
            # The nested conditional in the else-branch is emitted at this
            # level as an ``elif``
            self.depth -= self.style.indent_default
            else_body = [self.visit(o.else_body, is_elseif=True, **kwargs)]
        else:
            else_body = [self.visit(o.else_body, **kwargs)]
            self.depth -= self.style.indent_default
            if o.else_body:
                else_body = [self.format_line('else:')] + else_body
        return self.join_lines(header, body, *else_body)

    def visit_Assignment(self, o, **kwargs):
        """
        Format statement as
          <lhs> = <rhs> [<comment>]
        """
        lhs = self.visit(o.lhs, **kwargs)
        rhs = self.visit(o.rhs, **kwargs)
        comment = None
        if o.comment:
            comment = f'  {self.visit(o.comment, **kwargs)}'
        return self.format_line(lhs, ' = ', rhs, comment=comment)

    def visit_Section(self, o, **kwargs):
        """
        Format the section's body.
        """
        return self.visit(o.body, **kwargs)

    def visit_CallStatement(self, o, **kwargs):
        """
        Format call statement as
          <name>(<args>, <kw>=<value>, ...)
        """
        args = self.visit_all(o.arguments, **kwargs)
        # Tolerate calls without keyword arguments being stored as None
        kw_args = tuple(f'{kw}={self.visit(arg, **kwargs)}' for kw, arg in (o.kwarguments or ()))
        return self.format_line(o.name, '(', self.join_items(args + kw_args), ')')

    def visit_SymbolAttributes(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Format type attributes as the corresponding Python/numpy type name.
        """
        return numpy_type(o)

    def visit_StatementFunction(self, o, **kwargs):
        """
        Format a statement function as a nested function:
          def <name>(<args>):
            return <rhs>
        """
        args = tuple(self.visit(a, **kwargs) for a in o.arguments)
        header = self.format_line('def ', o.variable.name, f'({self.join_items(args)}):')

        self.depth += self.style.indent_default
        body = self.format_line('return ', self.visit(o.rhs, **kwargs))
        self.depth -= self.style.indent_default
        return self.join_lines(header, body)


def pygen(ir):
    """
    Produce standard Python 3 source code (using Numpy) from one or many IR
    objects/trees.

    Code is generated by :class:`PyCodegen` with a :any:`DefaultStyle`
    whose line width is 300 characters.
    """
    codegen_style = DefaultStyle(linewidth=300)
    codegen = PyCodegen(style=codegen_style)
    return codegen.visit(ir)
loki-ecmwf-0.3.6/loki/backend/style.py0000664000175000017500000000345115167130205020003 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pydantic.dataclasses import dataclass

__all__ = ['DefaultStyle', 'FortranStyle', 'IFSFortranStyle']

@dataclass
class DefaultStyle:
    """
    Default style class that defines the formatting of generated code.

    Language-specific backends use subclasses that add further options
    (e.g. :class:`FortranStyle`).
    """
    # Line width limit after which generated lines are wrapped
    linewidth: int = 90

    # Amount by which the indentation depth grows per nesting level
    indent_default: int = 2
    # Character repeated ``depth`` times to form the indentation string
    indent_char: str = ' '


@dataclass
class FortranStyle(DefaultStyle):
    """
    Style class that defines the output code style for a Fortran backend.

    Extends :class:`DefaultStyle` with Fortran-specific indentation and
    ``END`` statement options. The default line width is raised to 132
    characters, matching the free-form Fortran line length limit.
    """
    linewidth: int = 132

    # NOTE(review): the fields below are consumed by the Fortran code
    # generator, which is not part of this file; the descriptions are
    # inferred from the field names — TODO confirm against FortranCodegen.

    # Indentation of the body of ASSOCIATE blocks
    associate_indent: int = 2

    # Indentation of conditional bodies; ``*_end_space`` presumably selects
    # ``END IF`` over ``ENDIF``
    conditional_indent: int = 2
    conditional_end_space: bool = True

    # Indentation of loop bodies; ``*_end_space`` presumably selects
    # ``END DO`` over ``ENDDO``
    loop_indent: int = 2
    loop_end_space: bool = True

    # Indentation of the spec, body and CONTAINS sections of procedures;
    # ``*_end_named`` presumably appends the name to the END statement
    procedure_spec_indent: int = 2
    procedure_body_indent: int = 2
    procedure_contains_indent: int = 2
    procedure_end_named: bool = True

    # Same options as above, for modules
    module_spec_indent: int = 2
    module_contains_indent: int = 2
    module_end_named: bool = True


@dataclass
class IFSFortranStyle(FortranStyle):
    """
    Style class for a Fortran backend using an IFS-style code layout.

    Compared to :class:`FortranStyle`:

    * ASSOCIATE bodies, procedure spec/body sections and module spec sections
      are not indented (indent 0);
    * ``conditional_end_space`` and ``loop_end_space`` are disabled.

    All other fields repeat the :class:`FortranStyle` defaults.
    """
    linewidth: int = 132

    # No indentation inside ASSOCIATE blocks (FortranStyle: 2)
    associate_indent: int = 0

    # FortranStyle enables the end-space option; disabled here
    conditional_indent: int = 2
    conditional_end_space: bool = False

    # FortranStyle enables the end-space option; disabled here
    loop_indent: int = 2
    loop_end_space: bool = False

    # Procedure spec and body are not indented (FortranStyle: 2)
    procedure_spec_indent: int = 0
    procedure_body_indent: int = 0
    procedure_contains_indent: int = 2
    procedure_end_named: bool = True

    # Module spec is not indented (FortranStyle: 2)
    module_spec_indent: int = 0
    module_contains_indent: int = 2
    module_end_named: bool = True
loki-ecmwf-0.3.6/loki/backend/cppgen.py0000664000175000017500000000611415167130205020116 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.backend.cgen import CCodegen, CCodeMapper, IntrinsicTypeC
from loki.backend.style import DefaultStyle

from loki.expression import Array
from loki.types import BasicType

__all__ = ['cppgen', 'CppCodegen', 'CppCodeMapper', 'IntrinsicTypeCpp']


class IntrinsicTypeCpp(IntrinsicTypeC):
    """
    Map a Fortran type to the corresponding C++ type name.

    Integers become ``int``, or ``const int`` when declared as a parameter.
    All other types use the C mapping of :class:`IntrinsicTypeC`.
    """

    def get_str_from_symbol_attr(self, _type, *args, **kwargs):
        if _type.dtype == BasicType.INTEGER:
            return 'const int' if _type.parameter else 'int'
        return super().get_str_from_symbol_attr(_type, *args, **kwargs)

cpp_intrinsic_type = IntrinsicTypeCpp()


class CppCodeMapper(CCodeMapper):
    """
    Expression stringifier (a :class:`StringifyMapper` derivative) that
    produces standardized C++, building on :class:`CCodeMapper`.
    """
    # pylint: disable=abstract-method, unused-argument

    def map_inline_call(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render inline calls as in C, except for the ``PRESENT`` intrinsic.

        ``PRESENT(x)`` is replaced by just the argument name ``x``. Optional
        arguments are pointers defaulting to ``nullptr`` in the C++ backend,
        so the name itself serves as the presence test.
        """
        if expr.function.name.lower() != 'present':
            return super().map_inline_call(expr, enclosing_prec, *args, **kwargs)
        queried = expr.parameters[0]
        return self.format('%s', queried.name)


class CppCodegen(CCodegen):
    """
    Tree visitor to generate standardized C++ code from IR.

    Differences from :class:`CCodegen` implemented here:

    * array arguments with ``intent(in)`` are qualified ``const``;
    * optional arguments get a ``nullptr`` default value;
    * when the ``extern`` keyword argument is set in the kwargs seen by
      ``_subroutine_declaration``/``_subroutine_footer``, the routine is
      wrapped in an ``extern "C" { ... }`` block.
    """

    def __init__(self, depth=0, indent='  ', linewidth=90, **kwargs):
        # Use the C++ expression mapper unless the caller supplied its own
        symgen = kwargs.pop('symgen', CppCodeMapper(cpp_intrinsic_type))

        super().__init__(depth=depth, indent=indent, linewidth=linewidth,
                         symgen=symgen, **kwargs)

    def _subroutine_argument_keyword(self, a):
        """
        Return the qualifier prepended to argument ``a`` in the signature:
        ``'const '`` for arrays declared ``intent(in)``, otherwise ``''``.
        """
        # Dummy arguments without an explicit INTENT have ``intent`` set to
        # None; treat them as non-const instead of failing on ``None.lower()``
        if isinstance(a, Array) and a.type.intent and a.type.intent.lower() == "in":
            return 'const '
        return ''

    def _subroutine_declaration(self, o, **kwargs):
        """
        Return the declaration lines of :class:`CCodegen`, preceded by an
        opening ``extern "C" {`` if ``extern`` is set in ``kwargs``.
        """
        opt_extern = kwargs.get('extern', False)
        declaration = [self.format_line('extern "C" {\n')] if opt_extern else []
        declaration += super()._subroutine_declaration(o, **kwargs)
        return declaration

    def _subroutine_body(self, o, **kwargs):
        """
        Return the body lines unchanged from :class:`CCodegen`.
        """
        body = super()._subroutine_body(o, **kwargs)
        return body

    def _subroutine_footer(self, o, **kwargs):
        """
        Return the footer lines of :class:`CCodegen`, followed by the closing
        brace of the ``extern "C"`` block if ``extern`` is set in ``kwargs``.
        """
        opt_extern = kwargs.get('extern', False)
        footer = super()._subroutine_footer(o, **kwargs)
        footer += [self.format_line('\n} // extern')] if opt_extern else []
        return footer

    def _subroutine_optional_args(self, a):
        """
        Return the default-value suffix for argument ``a``: ``' = nullptr'``
        for optional arguments, otherwise ``''``.
        """
        if a.type.optional:
            return ' = nullptr'
        return ''

def cppgen(ir, **kwargs):
    """
    Generate standardized C++ code from one or many IR objects/trees.

    The optional keyword arguments ``style`` (default: a new
    :any:`DefaultStyle`) and ``depth`` (default: 0) configure
    :class:`CppCodegen`. All remaining keyword arguments are forwarded to the
    visitor call.
    """
    codegen = CppCodegen(
        style=kwargs.pop('style', DefaultStyle()),
        depth=kwargs.pop('depth', 0),
    )
    return codegen.visit(ir, **kwargs)
loki-ecmwf-0.3.6/loki/backend/pprint.py0000664000175000017500000002704415167130205020163 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Pretty-printer classes for IR
"""

from sys import stdout

from loki.backend.style import DefaultStyle

from loki.tools import JoinableStringList, is_iterable, as_tuple
from loki.ir.visitor import Visitor


__all__ = ['Stringifier', 'pprint']


class Stringifier(Visitor):
    """
    Convert a given IR tree to a string representation.

    This serves as base class for backends and provides a number of helpful
    routines that ease implementing automatic recursion and line wrapping.
    It doubles as a means to produce a human readable representation of the
    IR, which is useful for debugging purposes.

    Parameters
    ----------
    style : :any:`DefaultStyle`
        Formatting settings. This class reads ``linewidth`` (the line width
        limit after which to break a line), ``indent_default`` (depth added
        per nesting level) and ``indent_char`` (indentation character).
    depth : int, optional
        The initial indentation depth, i.e. the number of
        ``style.indent_char`` characters prepended to each line.
    symgen : optional
        A function handle that accepts a :any:`pymbolic.primitives.Expression`
        and produces a string representation for that. Defaults to `str`.
    line_cont : optional
        A function handle that accepts the current indentation string
        (:attr:`Stringifier.indent`) and returns the string for line
        continuation. This is inserted between two lines when they need to
        wrap to stay within the line width limit. Defaults to newline character
        plus indentation.
    """

    # pylint: disable=arguments-differ

    def __init__(
            self, style, depth=0, symgen=str,
            line_cont=lambda indent: '\n' + indent
    ):
        super().__init__()
        self.style = style
        # Current indentation depth; visitors raise and lower it by
        # ``style.indent_default`` around nested bodies
        self.depth = depth

        self.line_cont = line_cont
        self._symgen = symgen

    @property
    def symgen(self):
        """
        Formatter for expressions.
        """
        return self._symgen

    @property
    def indent(self):
        """
        Yield indentation string according to current depth.

        Returns
        -------
        str
            A string containing ``style.indent_char * depth``.
        """
        return self.style.indent_char * self.depth

    @staticmethod
    def join_lines(*lines):
        """
        Combine multiple lines into a long string, inserting line breaks in between.
        Entries that are `None` are skipped.

        Parameters
        ----------
        lines : list
             The lines to be combined.

        Returns
        -------
        str or `None`
            The combined string, or `None` if no lines were given at all.
            If lines were given but all of them are `None`, the result is an
            empty string.
        """
        if not lines:
            return None
        return '\n'.join(line for line in lines if line is not None)

    def join_items(self, items, sep=', ', separable=True):
        """
        Concatenate a list of items into :any:`JoinableStringList`.

        The return value can be passed to :meth:`format_line` or
        :meth:`format_node` or converted to a string with `str`, using
        the :any:`JoinableStringList` as an argument.
        Upon expansion, lines will be wrapped automatically to stay within
        the linewidth limit (``style.linewidth``), using :attr:`line_cont`
        applied to the current indentation as the continuation string.

        Parameters
        ----------
        items : list
            The list of strings to be joined.
        sep : str, optional
            The separator to be inserted between items.
        separable : bool, optional
            Allow line breaks between individual :data:`items`.

        Returns
        -------
        :any:`JoinableStringList`
        """
        return JoinableStringList(
            items, sep=sep, width=self.style.linewidth,
            cont=self.line_cont(self.indent), separable=separable
        )

    def format_node(self, name, *items):
        """
        Default format for a node.

        Creates an indented string of the form ``<name items>``, or
        ``<name>`` if no items are given.
        """
        if items:
            return self.format_line('<', name, ' ', self.join_items(items), '>')
        return self.format_line('<', name, '>')

    def format_line(
            self, *items, comment=None, no_wrap=False,
            no_indent=False, trim_spaces=True
    ):
        """
        Format a line by concatenating all items and applying indentation while observing
        the allowed line width limit.

        Note that the provided comment will simply be appended to the line and no line
        width limit will be enforced for that. When a comment is given, trailing
        whitespace is not trimmed.

        :param list items: the items to be put on that line.
        :param str comment: an optional inline comment to be put at the end of the line.
        :param bool no_wrap: disable line wrapping.
        :param bool no_indent: do not apply indentation.
        :param bool trim_spaces: strip trailing whitespace (only if no comment is given).

        :return: the string of the current line, potentially including line breaks if
                 required to observe the line width limit.
        :rtype: str
        """
        if not no_indent:
            items = [self.indent, *items]
        if no_wrap:
            # Simply concatenate items and append the comment
            line = ''.join(str(item) for item in items)
        else:
            # Use join_items to concatenate items
            line = str(self.join_items(items, sep=''))
        if comment:
            return line + comment
        line = line.rstrip() if trim_spaces else line
        return line

    def visit_all(self, item, *args, **kwargs):
        """
        Convenience function to call :meth:`visit` for all given arguments.

        If only a single argument is given that is iterable,
        :meth:`visit` is called on all of its elements instead.
        Arguments (or elements) that are `None` are skipped; the results are
        returned as a tuple.
        """
        if is_iterable(item) and not args:
            return as_tuple(self.visit(i, **kwargs) for i in item if i is not None)
        return as_tuple(self.visit(i, **kwargs) for i in [item, *args] if i is not None)

    # Handler for outer objects

    def visit_Module(self, o, **kwargs):
        """
        Format a :any:`Module` as

        .. code-block:: none

           <repr(Module)>
             ...spec...
             ...routines...
        """
        header = self.format_node(repr(o))
        self.depth += self.style.indent_default
        spec = self.visit(o.spec, **kwargs)
        routines = self.visit(o.subroutines, **kwargs)
        self.depth -= self.style.indent_default
        return self.join_lines(header, spec, routines)

    def visit_Subroutine(self, o, **kwargs):
        """
        Format a :any:`Subroutine` as

        .. code-block:: none

           <repr(Subroutine)>
             ...docstring...
             ...spec...
             ...body...
             ...members...
        """
        header = self.format_node(repr(o))
        self.depth += self.style.indent_default
        docstring = self.visit(o.docstring, **kwargs)
        spec = self.visit(o.spec, **kwargs)
        body = self.visit(o.body, **kwargs)
        members = self.visit(o.members, **kwargs)
        self.depth -= self.style.indent_default
        return self.join_lines(header, docstring, spec, body, members)

    # Handler for AST base nodes

    def visit_Node(self, o, **kwargs):
        """
        Format a :any:`Node` as

        .. code-block:: none

           <repr(Node)>
        """
        return self.format_node(repr(o))

    def visit_Expression(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Dispatch routine to expression tree stringifier
        :attr:`Stringifier.symgen`.
        """
        return self.symgen(o)

    def visit_tuple(self, o, **kwargs):
        """
        Recurse for each item in the tuple and return as separate lines.

        An empty tuple yields `None` (see :meth:`join_lines`).
        """
        lines = (self.visit(item, **kwargs) for item in o)
        return self.join_lines(*lines)

    visit_list = visit_tuple

    # Handler for IR nodes

    def visit_InternalNode(self, o, **kwargs):
        """
        Format :any:`InternalNode` as

        .. code-block:: none

           <repr(InternalNode)>
             ...body...
        """
        header = self.format_node(repr(o))
        self.depth += self.style.indent_default
        body = self.visit(o.body, **kwargs)
        self.depth -= self.style.indent_default
        return self.join_lines(header, body)


    def visit_Conditional(self, o, **kwargs):
        """
        Format :any:`Conditional` as

        .. code-block:: none

           <repr(Conditional)>
             <If condition>
               ...
             <Else>
               ...
        """
        header = self.format_node(repr(o))
        # Branch labels are one level below the header...
        self.depth += self.style.indent_default
        conditions = [self.format_node('If', self.visit(o.condition, **kwargs))]
        if o.else_body:
            conditions.append(self.format_node('Else'))
        # ...and branch bodies one level below the labels
        self.depth += self.style.indent_default
        bodies = self.visit_all(o.body, o.else_body, **kwargs)
        self.depth -= self.style.indent_default
        self.depth -= self.style.indent_default
        # Interleave labels and bodies; zip drops the else body if there is no <Else> label
        body = [item for branch in zip(conditions, bodies) for item in branch]
        return self.join_lines(header, *body)

    def visit_MultiConditional(self, o, **kwargs):
        """
        Format :any:`MultiConditional` as

        .. code-block:: none

           <repr(MultiConditional)>
             <Case (values)>
               ...
             <Case (values)>
               ...
             <Default>
               ...
        """
        header = self.format_node(repr(o))
        # Case labels are one level below the header
        self.depth += self.style.indent_default
        values = []
        for expr in o.values:
            value = f'({", ".join(self.visit_all(expr, **kwargs))})'
            values += [self.format_node('Case', value)]
        if o.else_body:
            values += [self.format_node('Default')]
        # Case bodies are one level below the labels
        self.depth += self.style.indent_default
        bodies = self.visit_all(*o.bodies, o.else_body, **kwargs)
        self.depth -= self.style.indent_default
        self.depth -= self.style.indent_default
        # Interleave labels and bodies; zip drops the else body if there is no <Default> label
        body = [item for branch in zip(values, bodies) for item in branch]
        return self.join_lines(header, *body)

    def visit_TypeConditional(self, o, **kwargs):
        """
        Format :any:`TypeConditional` as

        .. code-block:: none

           <repr(TypeConditional)>
             <Type|Class value>
               ...
             <Type|Class value>
               ...
             <Default>
               ...
        """
        header = self.format_node(repr(o))
        # Type/Class labels are one level below the header
        self.depth += self.style.indent_default
        values = []
        for expr in o.values:
            value = self.visit(expr[0], **kwargs)
            # expr[1] selects a ``Class`` label instead of ``Type``
            values += [self.format_node('Class' if expr[1] else 'Type', value)]
        if o.else_body:
            values += [self.format_node('Default')]
        # Branch bodies are one level below the labels
        self.depth += self.style.indent_default
        bodies = self.visit_all(*o.bodies, o.else_body, **kwargs)
        self.depth -= self.style.indent_default
        self.depth -= self.style.indent_default
        # Interleave labels and bodies; zip drops the else body if there is no <Default> label
        body = [item for branch in zip(values, bodies) for item in branch]
        return self.join_lines(header, *body)


def pprint(ir, stream=None):
    """
    Pretty-print the given IR using :class:`Stringifier`.

    Parameters
    ----------
    ir : :any:`Node`
        The IR node starting from which to print the tree
    stream : optional
        File-like object whose ``write`` method receives the output. Defaults
        to :any:`sys.stdout`.
    """
    if stream is None:
        stream = stdout
    text = Stringifier(style=DefaultStyle()).visit(ir)
    # ``visit`` yields None for inputs that produce no output (e.g. an empty
    # tuple or list of nodes); ``write`` only accepts strings, so skip it then
    if text is not None:
        stream.write(text)
loki-ecmwf-0.3.6/loki/backend/fgen.py0000664000175000017500000012033515167130205017563 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pymbolic.mapper.stringifier import (
    PREC_UNARY, PREC_LOGICAL_AND, PREC_LOGICAL_OR, PREC_COMPARISON, PREC_NONE
)
from pymbolic.primitives import FloorDiv, Remainder

from loki.backend.pprint import Stringifier
from loki.backend.style import FortranStyle

from loki.expression import LokiStringifyMapper, StringLiteral
from loki.ir import get_pragma_parameters
from loki.tools import as_tuple, JoinableStringList, flatten
from loki.types import DataType, BasicType, DerivedType, ProcedureType


__all__ = ['fgen', 'fexprgen', 'FortranCodegen', 'FCodeMapper']


class FCodeMapper(LokiStringifyMapper):
    """
    Pymbolic :class:`StringifyMapper` specialisation that renders expression
    trees as strings conforming to the Fortran standard.
    """
    # pylint: disable=abstract-method

    # Translation table from Pymbolic's C-style comparison operators to their
    # Fortran spelling (only inequality is spelled differently)
    COMPARISON_OP_TO_FORTRAN = {
        '==': '==',
        '!=': '/=',
        '<=': '<=',
        '>=': '>=',
        '<': '<',
        '>': '>',
    }

    def map_logic_literal(self, expr, enclosing_prec, *args, **kwargs):
        """Render a logical literal as ``.true.`` or ``.false.``."""
        if expr.value:
            return '.true.'
        return '.false.'

    def map_float_literal(self, expr, enclosing_prec, *args, **kwargs):
        """Render a numeric literal, appending ``_<kind>`` when a kind is attached."""
        text = str(expr.value)
        if expr.kind is None:
            return text
        return text + '_' + str(expr.kind)

    # Integer literals use the same ``value[_kind]`` rendering
    map_int_literal = map_float_literal

    def map_logical_not(self, expr, enclosing_prec, *args, **kwargs):
        """Render a logical negation with the Fortran ``.not.`` operator."""
        operand = self.rec(expr.child, PREC_UNARY, *args, **kwargs)
        return self.parenthesize_if_needed('.not.' + operand, enclosing_prec, PREC_UNARY)

    def map_logical_and(self, expr, enclosing_prec, *args, **kwargs):
        """Render a conjunction with the Fortran ``.and.`` operator."""
        joined = self.join_rec(' .and. ', expr.children, PREC_LOGICAL_AND, *args, **kwargs)
        return self.parenthesize_if_needed(joined, enclosing_prec, PREC_LOGICAL_AND)

    def map_logical_or(self, expr, enclosing_prec, *args, **kwargs):
        """Render a disjunction with the Fortran ``.or.`` operator."""
        joined = self.join_rec(' .or. ', expr.children, PREC_LOGICAL_OR, *args, **kwargs)
        return self.parenthesize_if_needed(joined, enclosing_prec, PREC_LOGICAL_OR)

    def map_comparison(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a comparison. Pymbolic stores comparison operators in C-style
        notation, which is translated here to the Fortran spelling
        (e.g. ``!=`` becomes ``/=``).
        """
        lhs = self.rec(expr.left, PREC_COMPARISON, *args, **kwargs)
        operator = self.COMPARISON_OP_TO_FORTRAN[expr.operator]
        rhs = self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)
        return self.parenthesize_if_needed(
            self.format('%s %s %s', lhs, operator, rhs), enclosing_prec, PREC_COMPARISON
        )

    def map_literal_list(self, expr, enclosing_prec, *args, **kwargs):
        """Render an array constructor as ``(/ [<type> ::] a, b, ... /)``."""
        rendered = [self.rec(elem, PREC_NONE, *args, **kwargs) for elem in expr.elements]
        values = ', '.join(rendered)
        if expr.dtype is None:
            return f'(/ {values} /)'
        return f'(/ {fgen(expr.dtype)} :: {values} /)'

    def map_foreign(self, expr, *args, **kwargs):
        """
        Delegate non-Pymbolic objects to the parent mapper. If the parent
        raises :class:`ValueError`, emit a Fortran comment marking the
        object as unsupported instead.
        """
        try:
            result = super().map_foreign(expr, *args, **kwargs)
        except ValueError:
            result = f'! Not supported: {str(expr)}\n'
        return result

    def map_loop_range(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a loop range as ``start,stop[,step]``. Missing children become
        empty strings, and a step of ``1`` (or no step) is omitted.
        """
        parts = []
        for child in expr.children:
            parts.append('' if child is None else self.rec(child, PREC_NONE, *args, **kwargs))
        # A unit step is implicit, so drop the trailing step component
        if expr.step is None or str(expr.step) == '1':
            parts = parts[:-1]
        return self.parenthesize_if_needed(self.join(',', parts), enclosing_prec, PREC_NONE)

    # Override Pymbolic's conservative default bracketing by restricting the
    # multiplicative primitives to `FloorDiv` and `Remainder`, which excludes
    # `Product` and `Quotient` nodes. The default bracketing can cause
    # round-off deviations with aggressively optimising compilers. Since
    # bracketing is handled explicitly in our expression nodes, it can be
    # dropped here... famous last words!
    multiplicative_primitives = (FloorDiv, Remainder)


class FortranCodegen(Stringifier):
    """
    Tree visitor to generate standardized Fortran code from IR.
    """
    # pylint: disable=unused-argument

    def __init__(self, style, depth=0):
        """
        Create a Fortran code generator.

        :param style: the formatting style to apply.
        :param int depth: the initial indentation depth.
        """
        # Continuation template: trailing ' &', a newline, then the '{}'
        # placeholder (presumably filled with indentation by Stringifier) and '& '
        continuation = ' &\n{}& '.format
        super().__init__(style=style, depth=depth, line_cont=continuation, symgen=FCodeMapper())

    def apply_label(self, line, label):
        """
        Prefix a formatted line with a statement label. The label replaces
        the line's leading indentation.

        :param str line: the already formatted line.
        :param label: the label to apply, or `None` for no label.
        :type label: str or NoneType

        :return: the labelled line, or ``line`` unchanged if ``label`` is `None`.
        :rtype: str
        """
        if label is None:
            return line
        stripped = line.lstrip()
        # Field width for the label: the old indentation minus one column for
        # the separating blank, but never less than one
        width = max(1, len(line) - len(stripped) - 1)
        return f'{label:{width}} {stripped}'

    # Handler for outer objects

    def visit_Sourcefile(self, o, **kwargs):
        """
        Render a source file by rendering its IR, i.e. the contained
        modules and subroutines in order.
        """
        ir = o.ir
        return self.visit(ir, **kwargs)

    def _construct_module_header(self, o, **kwargs):
        """Build the ``MODULE <name>`` opening line."""
        name = o.name
        return self.format_line('MODULE ', name)

    def _construct_module_footer(self, o, **kwargs):
        """
        Build the ``END MODULE [<name>]`` closing line. The name is included
        only when the style's ``module_end_named`` option is set.
        """
        if self.style.module_end_named:
            return self.format_line('END MODULE ', o.name)
        return self.format_line('END MODULE ', '')

    def visit_Module(self, o, **kwargs):
        """
        Render a module as::

          MODULE <name>
            ...docstring...
            ...spec (including access specifiers)...
          CONTAINS
            ...routines...
          END MODULE [<name>]
        """
        header = self._construct_module_header(o, **kwargs)
        footer = self._construct_module_footer(o, **kwargs)

        # Specification part
        self.depth += self.style.module_spec_indent
        docstring = self.visit(o.docstring, **kwargs)

        # Collect default and explicit PUBLIC/PRIVATE access statements
        access_lines = []
        if o.default_access_spec is not None:
            access_lines.append(self.format_line(o.default_access_spec))
        for keyword, names in (('PUBLIC :: ', o.public_access_spec), ('PRIVATE :: ', o.private_access_spec)):
            if names:
                access_lines.append(self.format_line(keyword, ', '.join(names)))

        if not access_lines:
            spec = self.visit(o.spec, **kwargs)
        else:
            # Render the spec piecewise so the access statements land after the
            # import and implicit parts and before the declarations
            import_part, implicit_part, decl_part = o.spec_parts
            pieces = []
            if import_part:
                pieces.append(self.visit(import_part, **kwargs) + '\n')
            if implicit_part:
                pieces.append(self.visit(implicit_part, **kwargs) + '\n')
            pieces.append(self.join_lines(*access_lines) + '\n')
            if decl_part:
                pieces.append(self.visit(decl_part, **kwargs) + '\n')
            spec = ''.join(pieces)
        self.depth -= self.style.module_spec_indent

        # Contained routines
        self.depth += self.style.module_contains_indent
        contains = self.visit(o.contains, **kwargs)
        self.depth -= self.style.module_contains_indent

        return self.join_lines(header, docstring, spec, contains, footer)

    def _construct_procedure_footer(self, o, **kwargs):
        """
        Build the ``END FUNCTION|SUBROUTINE [<name>]`` closing line. The name
        is included only when the style's ``procedure_end_named`` option is set.
        """
        keyword = 'SUBROUTINE'
        if o.is_function:
            keyword = 'FUNCTION'
        name = o.name if self.style.procedure_end_named else ''
        return self.format_line('END ', keyword, ' ', name)

    def _construct_function_header(self, o, **kwargs):
        """
        Build the opening line of a function::

          [<prefix>] [<type>] FUNCTION <name> (<args>) [RESULT(<res>)] [BIND(c, name=...)]
        """
        prefix = self.join_items(o.prefix, sep=' ')
        if o.prefix:
            prefix += ' '
        # Without a declared result variable, the return type goes into the header
        if o.result_name not in o.variable_map:
            prefix += f'{self.visit(o.return_type)} '
        arguments = self.join_items(o.argnames)

        # RESULT(...) is only needed when the result name differs from the function name
        result = ''
        if o.result_name and o.result_name.lower() != o.name.lower():
            result = f' RESULT({o.result_name})'

        bind_c = ''
        if isinstance(o.bind, str):
            bind_c = f' BIND(c, name="{o.bind}")'
        elif isinstance(o.bind, StringLiteral):
            bind_c = f' BIND(c, name={o.bind})'

        return self.format_line(prefix, 'FUNCTION ', o.name, ' (', arguments, ')', result, bind_c)

    def visit_Function(self, o, **kwargs):
        """
        Render a function as::

          [<prefix>] FUNCTION <name> ([<args>]) [RESULT(<res>)] [BIND(c, name=<name>)]
            ...docstring...
            ...spec...
            ...body...
          [CONTAINS
            ...members...]
          END FUNCTION [<name>]
        """
        header = self._construct_function_header(o, **kwargs)
        footer = self._construct_procedure_footer(o, **kwargs)

        style = self.style
        self.depth += style.procedure_spec_indent
        docstring = self.visit(o.docstring, **kwargs)
        spec = self.visit(o.spec, **kwargs)
        self.depth -= style.procedure_spec_indent

        self.depth += style.procedure_body_indent
        body = self.visit(o.body, **kwargs)
        self.depth -= style.procedure_body_indent

        self.depth += style.procedure_contains_indent
        contains = self.visit(o.contains, **kwargs)
        self.depth -= style.procedure_contains_indent

        # The contains part is only emitted when it rendered to something
        parts = [header, docstring, spec, body]
        if contains:
            parts.append(contains)
        parts.append(footer)
        return self.join_lines(*parts)

    def _construct_subroutine_header(self, o, **kwargs):
        """
        Build the opening line of a subroutine::

          [<prefix>] SUBROUTINE <name> (<args>) [BIND(c, name=...)]
        """
        prefix = self.join_items(o.prefix, sep=' ')
        if o.prefix:
            prefix += ' '
        arguments = self.join_items(o.argnames)

        bind_c = ''
        if isinstance(o.bind, str):
            bind_c = f' BIND(c, name="{o.bind}")'
        elif isinstance(o.bind, StringLiteral):
            bind_c = f' BIND(c, name={o.bind})'

        return self.format_line(prefix, 'SUBROUTINE ', o.name, ' (', arguments, ')', bind_c)

    def visit_Subroutine(self, o, **kwargs):
        """
        Render a subroutine as::

          [<prefix>] SUBROUTINE <name> ([<args>]) [BIND(c, name=<name>)]
            ...docstring...
            ...spec...
            ...body...
          [CONTAINS
            ...members...]
          END SUBROUTINE [<name>]
        """
        header = self._construct_subroutine_header(o, **kwargs)
        footer = self._construct_procedure_footer(o, **kwargs)

        style = self.style
        self.depth += style.procedure_spec_indent
        docstring = self.visit(o.docstring, **kwargs)
        spec = self.visit(o.spec, **kwargs)
        self.depth -= style.procedure_spec_indent

        self.depth += style.procedure_body_indent
        body = self.visit(o.body, **kwargs)
        self.depth -= style.procedure_body_indent

        self.depth += style.procedure_contains_indent
        contains = self.visit(o.contains, **kwargs)
        self.depth -= style.procedure_contains_indent

        # The contains part is only emitted when it rendered to something
        parts = [header, docstring, spec, body]
        if contains:
            parts.append(contains)
        parts.append(footer)
        return self.join_lines(*parts)

    # Handler for AST base nodes

    def visit_Node(self, o, **kwargs):
        """
        Fallback for IR nodes without a dedicated handler. The node's repr
        is emitted inside a Fortran comment: ``! <repr>``.
        """
        text = repr(o)
        return self.format_line('! <', text, '>')

    def visit_tuple(self, o, **kwargs):
        """
        Render each item of a sequence as its own line(s). An item that
        carries a ``label`` attribute gets that label applied.
        """
        rendered = [
            self.apply_label(self.visit(item, **kwargs), getattr(item, 'label', None))
            for item in o
        ]
        return self.join_lines(*rendered)

    visit_list = visit_tuple

    def visit_str(self, o, **kwargs):
        """Plain strings are passed through verbatim."""
        text = o
        return text

    # Handler for IR nodes

    def visit_Intrinsic(self, o, **kwargs):
        """
        Render an intrinsic (verbatim text) statement. Its original leading
        whitespace is dropped before it is passed to :meth:`format_line`.
        """
        text = str(o.text)
        return self.format_line(text.lstrip())

    def visit_RawSource(self, o, **kwargs):
        """Emit the raw source text exactly as stored, without any formatting."""
        raw = o.text
        return raw

    def visit_Comment(self, o, **kwargs):
        """
        Render a comment. If the node has no text, fall back to its source
        string, or an empty string without source. Comments are never wrapped.
        """
        text = o.text or (o.source.string if o.source else '')
        return self.format_line(str(text).lstrip(), no_wrap=True)

    def visit_Pragma(self, o, **kwargs):
        """
        Render a pragma.

        A pragma with content is rebuilt from its parsed parameters as
        ``!$<keyword> <key> <key>( <value tokens> ) ...``. It is wrapped with
        ``!$<keyword> &`` continuation lines where it exceeds the style's line
        width.

        A pragma without content is emitted from its original source string.
        If it has no source either, a bare ``!$<keyword>`` is emitted.
        """
        if o.content is not None:
            # Continuation lines must repeat the pragma sentinel
            line_cont = f' &\n!${o.keyword} & '
            items = [f'!${o.keyword}']
            for k, v in get_pragma_parameters(o, only_loki_pragmas=False).items():
                if v:
                    # Filter out old line-continuation markers. Split on any
                    # whitespace so that newlines/tabs left by a removed
                    # continuation and runs of blanks do not produce empty or
                    # multi-line tokens.
                    # v can be a list if the key occurs more than once.
                    values = [i.replace('&', '').split() for i in as_tuple(v)]
                    items += flatten([(k + '(', *i, ')') for i in values])
                else:
                    items += [k]

            # Ensure '!$ &' line continuation in final string
            return str(JoinableStringList(
                items, sep=' ', width=self.style.linewidth, cont=line_cont, separable=True
            ))
        if o.source is not None:
            return o.source.string
        # No content and no source: emit the bare sentinel instead of failing
        return f'!${o.keyword}'

    def visit_CommentBlock(self, o, **kwargs):
        """Render every comment of the block in order, one after another."""
        return self.join_lines(*self.visit_all(o.comments, **kwargs))

    def visit_PreprocessorDirective(self, o, **kwargs):
        """
        Render a preprocessor directive with neither indentation nor line
        wrapping applied.
        """
        directive = str(o.text).lstrip()
        return self.format_line(directive, no_wrap=True, no_indent=True)

    def _construct_type_attributes(self, o, **kwargs):
        """
        Build the attribute list placed to the left of ``::`` in a
        :class:`VariableDeclaration`.

        All declared symbols are asserted to share one data type (for
        statement functions, their return type), and the type of the first
        symbol is rendered. If the first symbol is a function-typed statement
        function, its rendered return type comes first. A ``DIMENSION(...)``
        attribute is appended when the declaration carries shared dimensions.

        Parameters
        ----------
        o : :any:`VariableDeclaration`
            The declaration node; must declare at least one symbol.

        Returns
        -------
        list
            The rendered attributes, in output order.
        """
        attributes = []
        assert len(o.symbols) > 0
        types = [v.type for v in o.symbols]

        # Ensure all variable types are equal, except for shape and dimension
        # TODO: Should extend to deeper recursion of `variables` if
        # the symbol has a known derived type
        # NOTE(review): `ignore` is only consulted by the statement-function
        # return-type comparison below; the final check compares dtypes alone
        ignore = ['shape', 'dimensions', 'variables', 'source', 'initial']

        # Statement functions can share declarations with scalars, so we collect the variable types here
        # (a statement function contributes its return dtype in place of its ProcedureType)
        _var_types = [t.dtype.return_type.dtype if isinstance(t.dtype, ProcedureType) else t.dtype for t in types]
        _procedure_types = [t for t in types if isinstance(t.dtype, ProcedureType)]

        if _procedure_types:
            # Statement functions are the only symbol with ProcedureType that should appear
            # in a VariableDeclaration as all other forms of procedure declarations (bindings,
            # pointers, EXTERNAL statements) are handled by ProcedureDeclaration.
            # However, the fact that statement function declarations can appear mixed with actual
            # variable declarations forbids this in this case.
            assert all(t.is_stmt_func for t in _procedure_types)
            # TODO: We can't fully compare statement functions, yet but we can make at least sure
            # other declared attributes are compatible and that all have the same return type
            ignore += ['dtype']
            assert all(t.dtype.return_type == _procedure_types[0].dtype.return_type or
                       t.dtype.return_type.compare(_procedure_types[0].dtype.return_type, ignore=ignore)
                       for t in _procedure_types)

        # Every declared entity must have the same (return) dtype
        assert all((t == _var_types[0]) for t in _var_types)

        # Only the first symbol decides whether a return type is emitted
        is_function = isinstance(types[0].dtype, ProcedureType) and types[0].dtype.is_function
        if is_function:
            assert types[0].is_stmt_func
            # Return type of a function (great syntax there, Fortran!)
            attributes = [self.visit(types[0].dtype.return_type, **kwargs)]

        # Declaration type and attributes (skipped if it renders empty)
        dtype = self.visit(types[0], **kwargs)
        if str(dtype):
            attributes += [dtype]

        # Dimensions specification shared by all declared symbols
        if o.dimensions:
            attributes += [f'DIMENSION({", ".join(self.visit_all(o.dimensions, **kwargs))})']

        return attributes

    def _construct_decl_variables(self, o, **kwargs):
        """
        Render the entities to the right of ``::`` in a variable declaration.
        Each entity is ``name``, ``name = initial``, or ``name => initial``
        for pointers.
        """
        entities = []
        for sym in o.symbols:
            # With a shared DIMENSION attribute only the bare name may appear;
            # otherwise the symbol's own array dimensions would be printed too
            if o.dimensions is None:
                name = self.visit(sym, **kwargs)
            else:
                name = sym.basename
            if sym.type.initial is None:
                entities.append(f'{name}')
            else:
                assign = '=>' if sym.type.pointer else '='
                entities.append(f'{name} {assign} {self.visit(sym.type.initial, **kwargs)}')
        return entities

    def visit_VariableDeclaration(self, o, **kwargs):
        """
        Render a variable declaration as::

          <type>[, <attributes>][, DIMENSION(...)] :: var [= initial][, var [= initial]] ...

        followed by an optional in-line comment.
        """
        # Left and right of `::`
        attributes = self._construct_type_attributes(o, **kwargs)
        entities = self._construct_decl_variables(o, **kwargs)

        comment = str(self.visit(o.comment, **kwargs)) if o.comment else None

        lhs = self.join_items(attributes)
        rhs = self.join_items(entities)
        return self.format_line(lhs, ' :: ', rhs, comment=comment)

    def visit_ProcedureDeclaration(self, o, **kwargs):
        """
        Format procedure declaration as
          [PROCEDURE[(<interface>)]] [, POINTER] [, INTENT(...)] [, ...] :: var [=> initial] [, var [=> initial] ] ...
        or
          [<return type>] [, ...] EXTERNAL :: var [, ...]
        or
          [MODULE] PROCEDURE [, ...] :: var [, ...]
        or
          GENERIC [, PUBLIC|PRIVATE] :: var => bind_name [, bind_name [, ...]]
        or
          FINAL :: var [, var [, ...]]

        All declared symbols must have a :class:`ProcedureType` and compatible
        declared attributes (checked by assertions).
        """
        assert len(o.symbols) > 0
        types = [v.type for v in o.symbols]

        # Ensure all symbol types are equal, except for shape and dimension
        # TODO: We can't fully compare procedure types, yet, but we can make at least sure
        # names match and other declared attributes are compatible
        ignore = ['dtype', 'shape', 'dimensions', 'symbols', 'source', 'initial']
        assert all(isinstance(t.dtype, ProcedureType) for t in types)
        assert all(t.compare(types[0], ignore=ignore) for t in types)
        # An interface given as a DataType must match every symbol's return type;
        # any other interface must match every symbol's procedure name
        if isinstance(o.interface, DataType):
            assert all(t.dtype.return_type.dtype == o.interface for t in types)
        elif o.interface is not None:
            assert all(t.dtype.name == o.interface for t in types)

        # Choose the leading keyword(s) according to the kind of declaration
        if o.external:
            # This is an EXTERNAL statement (i.e., a kind of forward declaration)
            assert o.interface is None
            # Functions and subroutines cannot be mixed in one EXTERNAL statement
            assert all(t.dtype.is_function for t in types) or all(not t.dtype.is_function for t in types)
            if types[0].dtype.is_function:
                # EXTERNAL statement for functions must include return_type
                assert all(t.dtype.return_type.compare(types[0].dtype.return_type) for t in types)
                attributes = [self.visit(types[0].dtype.return_type, **kwargs)]
            else:
                attributes = []
            # NB: no need to provide EXTERNAL here, as EXTERNAL is specified by visit_SymbolAttributes
        elif o.interface:
            # This is a PROCEDURE declaration with interface provided
            attributes = [f'PROCEDURE({self.visit(o.interface, **kwargs)})']
        elif o.module:
            attributes = ['MODULE PROCEDURE']
        elif o.generic:
            attributes = ['GENERIC']
        elif o.final:
            attributes = ['FINAL']
        else:
            # This is a PROCEDURE declaration without interface provided
            # (as they can appear in a derived type component declaration)
            attributes = ['PROCEDURE']

        # Remaining declared attributes of the first symbol (skipped if empty)
        decl_attrs = self.visit(types[0], **kwargs)
        if str(decl_attrs):
            attributes += [decl_attrs]

        # Right-hand side of `::`. An initial value takes precedence over
        # binding names; binding names are only emitted without an interface.
        symbols = []
        for v in o.symbols:
            var = self.visit(v, **kwargs)
            if v.type.initial is not None:
                symbols += [f'{var} => {self.visit(v.type.initial, **kwargs)}']
            elif v.type.bind_names is not None and o.interface is None:
                bind_names = [self.visit(n, **kwargs) for n in v.type.bind_names]
                symbols += [f'{var} => {self.join_items(bind_names)}']
            else:
                symbols += [var]

        comment = None
        if o.comment:
            comment = str(self.visit(o.comment, **kwargs))

        return self.format_line(
            self.join_items(attributes), ' :: ', self.join_items(symbols),
            comment=comment
        )

    def visit_DataDeclaration(self, o, **kwargs):
        """
        Render a DATA statement as ``DATA <variable> / <values> /``.
        """
        values = self.join_items(self.visit_all(o.values, **kwargs))
        target = self.visit(o.variable, **kwargs)
        return self.format_line('DATA ', target, ' / ', values, ' /')

    def visit_StatementFunction(self, o, **kwargs):
        """
        Render a statement function definition as ``<name>(<args>) = <rhs>``.
        """
        name = self.visit(o.variable, **kwargs)
        args = self.join_items(self.visit_all(o.arguments, **kwargs))
        rhs = self.visit(o.rhs, **kwargs)
        return self.format_line(name, '(', args, ') = ', rhs)

    def visit_Import(self, o, **kwargs):
        """
        Render an import according to its kind as one of::

          #include "<module>"
          include "<module>"
          IMPORT [<symbol>[, <symbol> ...]]
          USE[, <nature> ::] <module>
          USE[, <nature> ::] <module>, <local> => <use>[, ...]
          USE[, <nature> ::] <module>, ONLY: <symbol>[, ...]

        Symbols carrying a ``use_name`` are rendered as ``<local> => <use_name>``.
        """
        if o.c_import:
            return f'#include "{o.module}"'
        if o.f_include:
            return self.format_line('include "', o.module, '"')

        def render_symbols():
            # `local => use_name` for renamed symbols, plain name otherwise
            return [
                f'{self.visit(s, **kwargs)} => {s.type.use_name}' if s.type.use_name
                else self.visit(s, **kwargs)
                for s in o.symbols
            ]

        if o.f_import:
            # Handle IMPORT before the USE-specific branches: a bare IMPORT
            # (no entity list) must not fall through and be rendered as a
            # `USE <module>` statement
            if not o.symbols:
                return self.format_line('IMPORT')
            return self.format_line('IMPORT ', self.join_items(render_symbols()))

        if o.nature:
            use_stmt = f'USE, {o.nature} :: '
        else:
            use_stmt = 'USE '

        if o.rename_list:
            rename_list = [f'{self.visit(local, **kwargs)} => {use}' for use, local in o.rename_list]
            return self.format_line(use_stmt, o.module, ', ', self.join_items(rename_list))
        if not o.symbols:
            return self.format_line(use_stmt, o.module)

        return self.format_line(use_stmt, o.module, ', ONLY: ', self.join_items(render_symbols()))

    def visit_Interface(self, o, **kwargs):
        """
        Render an interface block as::

          [ABSTRACT] INTERFACE [<generic-spec>]
            ...body...
          END INTERFACE [<generic-spec>]

        The generic spec is only used for non-abstract interfaces.
        """
        if not o.abstract and o.spec:
            generic_spec = self.visit(o.spec, **kwargs)
            header = self.format_line('INTERFACE ', generic_spec)
            footer = self.format_line('END INTERFACE ', generic_spec)
        else:
            header = self.format_line('ABSTRACT INTERFACE' if o.abstract else 'INTERFACE')
            footer = self.format_line('END INTERFACE')

        self.depth += self.style.indent_default
        body = self.visit(o.body, **kwargs)
        self.depth -= self.style.indent_default
        return self.join_lines(header, body, footer)

    def visit_Loop(self, o, **kwargs):
        """
        Render a loop with an explicit range as::

          [<name>:] DO [<label>] <var>=<bounds>
            ...body...
          END DO [<name>]

        Any pragmas attached before/after the loop are emitted around it. The
        closing line is only emitted when ``o.has_end_do`` is set, and it
        carries the loop label if there is one.
        """
        pragma = self.visit(o.pragma, **kwargs)
        pragma_post = self.visit(o.pragma_post, **kwargs)

        variable = self.visit(o.variable, **kwargs)
        bounds = self.visit(o.bounds, **kwargs)
        prefix = f'{o.name}: ' if o.name else ''
        label = f'{o.loop_label} ' if o.loop_label else ''
        header = self.format_line(prefix, 'DO ', label, f'{variable}={bounds}')

        footer = None
        if o.has_end_do:
            keyword = 'END DO' if self.style.loop_end_space else 'ENDDO'
            suffix = f' {o.name}' if o.name else ''
            footer = self.apply_label(self.format_line(keyword, suffix), o.loop_label)

        self.depth += self.style.loop_indent
        body = self.visit(o.body, **kwargs)
        self.depth -= self.style.loop_indent
        return self.join_lines(pragma, header, body, footer, pragma_post)

    def visit_WhileLoop(self, o, **kwargs):
        """
        Render a loop without an explicit range as::

          [<name>:] DO [<label>] [WHILE (<condition>)]
            ...body...
          END DO [<name>]

        Any pragmas attached before/after the loop are emitted around it. The
        closing line is only emitted when ``o.has_end_do`` is set, and it
        carries the loop label if there is one.
        """
        pragma = self.visit(o.pragma, **kwargs)
        pragma_post = self.visit(o.pragma_post, **kwargs)

        control = '' if o.condition is None else f' WHILE ({self.visit(o.condition, **kwargs)})'
        prefix = f'{o.name}: ' if o.name else ''
        label = f' {o.loop_label}' if o.loop_label else ''
        header = self.format_line(prefix, 'DO', label, control)

        footer = None
        if o.has_end_do:
            keyword = 'END DO' if self.style.loop_end_space else 'ENDDO'
            suffix = f' {o.name}' if o.name else ''
            footer = self.apply_label(self.format_line(keyword, suffix), o.loop_label)

        self.depth += self.style.loop_indent
        body = self.visit(o.body, **kwargs)
        self.depth -= self.style.loop_indent
        return self.join_lines(pragma, header, body, footer, pragma_post)

    def visit_Conditional(self, o, **kwargs):
        """
        Format conditional as
          IF (<condition>) <statement>
        or
          [name:] IF (<condition>) THEN
            ...body...
          [ELSE IF (<condition>) THEN [name]]
            [...body...]
          [ELSE [name]]
            [...body...]
          END IF [name]

        ELSE IF branches are stored as a nested conditional in ``else_body``.
        They are rendered by recursing with ``is_elseif=True`` and the
        construct ``name`` passed through ``kwargs``.
        """
        if o.inline:
            # No indentation and only a single body node
            cond = self.visit(o.condition, **kwargs)
            d = self.depth
            self.depth = 0
            body = self.visit(o.body, **kwargs)
            self.depth = d
            # Undo the indentation, so that we may re-format and re-indent
            line = f'IF ({cond}) ' + ''.join(body.lstrip().split('&\n&'))
            return self.format_line(line)

        # Pop the recursion-only kwargs so they are not forwarded to body visits.
        # `name` carries a leading blank (' <name>') or is ''.
        name = kwargs.pop('name', f' {o.name}' if o.name else '')
        is_elseif = kwargs.pop('is_elseif', False)

        if is_elseif:
            header = self.format_line('ELSE IF', ' (', self.visit(o.condition, **kwargs), ') THEN', name)
        else:
            # Strip the leading blank from `name` for the `<name>: IF` prefix
            header = f'{name[1:]}: IF' if name else 'IF'
            header = self.format_line(header, ' (', self.visit(o.condition, **kwargs), ') THEN')

        self.depth += self.style.conditional_indent
        body = self.visit(o.body, **kwargs)
        if o.has_elseif:
            # Restore the depth first: the nested ELSE IF conditional is rendered
            # at this construct's level and emits the closing END IF itself
            self.depth -= self.style.conditional_indent
            else_body = [self.visit(o.else_body, is_elseif=True, name=name, **kwargs)]
        else:
            # The ELSE body is rendered while still indented
            else_body = [self.visit(o.else_body, **kwargs)]
            self.depth -= self.style.conditional_indent
            if o.else_body:
                else_body = [self.format_line('ELSE', name)] + else_body
            endif = 'END IF' if self.style.conditional_end_space else 'ENDIF'
            else_body += [self.format_line(endif, name)]

        return self.join_lines(header, body, *else_body)

    def visit_MultiConditional(self, o, **kwargs):
        """
        Render a SELECT CASE construct as::

          [<name>:] SELECT CASE (<expr>)
          CASE (<values>) [<name>]
            ...body...
          [CASE (<values>) [<name>]
            ...body...]
          [CASE DEFAULT [<name>]
            ...body...]
          END SELECT [<name>]
        """
        prefix = f'{o.name}: ' if o.name else ''
        header = self.format_line(prefix, 'SELECT CASE (', self.visit(o.expr, **kwargs), ')')
        suffix = f' {o.name}' if o.name else ''

        guards = [
            self.format_line('CASE (', self.join_items(self.visit_all(as_tuple(value), **kwargs)), ')', suffix)
            for value in o.values
        ]
        if o.else_body:
            guards.append(self.format_line('CASE DEFAULT', suffix))
        footer = self.format_line('END SELECT', suffix)

        self.depth += self.style.indent_default
        bodies = self.visit_all(*o.bodies, o.else_body, **kwargs)
        self.depth -= self.style.indent_default

        # Interleave guards with bodies; an unused default body is dropped by zip
        interleaved = []
        for guard, body in zip(guards, bodies):
            interleaved += [guard, body]
        return self.join_lines(header, *interleaved, footer)

    def visit_TypeConditional(self, o, **kwargs):
        """
        Render a SELECT TYPE construct as::

          [<name>:] SELECT TYPE (<expr>)
          [CLASS IS (<type>) [<name>]
            ...body...]
          [TYPE IS (<type>) [<name>]
            ...body...]
          [CLASS DEFAULT [<name>]
            ...body...]
          END SELECT [<name>]

        Each entry of ``o.values`` is a pair of (type, is-class flag).
        """
        prefix = f'{o.name}: ' if o.name else ''
        header = self.format_line(prefix, 'SELECT TYPE (', self.visit(o.expr, **kwargs), ')')
        suffix = f' {o.name}' if o.name else ''

        guards = []
        for value in o.values:
            type_name = self.visit(value[0], **kwargs)
            keyword = 'CLASS' if value[1] else 'TYPE'
            guards.append(self.format_line(keyword, ' IS (', type_name, ')', suffix))
        if o.else_body:
            guards.append(self.format_line('CLASS DEFAULT', suffix))
        footer = self.format_line('END SELECT', suffix)

        self.depth += self.style.indent_default
        bodies = self.visit_all(*o.bodies, o.else_body, **kwargs)
        self.depth -= self.style.indent_default

        # Interleave guards with bodies; an unused default body is dropped by zip
        interleaved = []
        for guard, body in zip(guards, bodies):
            interleaved += [guard, body]
        return self.join_lines(header, *interleaved, footer)

    def visit_Assignment(self, o, **kwargs):
        """
        Render an assignment as ``<lhs> = <rhs>``. When ``o.ptr`` is set, it
        is rendered as a pointer assignment ``<lhs> => <rhs>``. An attached
        comment is appended in-line.
        """
        lhs = self.visit(o.lhs, **kwargs)
        rhs = self.visit(o.rhs, **kwargs)
        comment = f'  {self.visit(o.comment, **kwargs)}' if o.comment else None
        operator = ' => ' if o.ptr else ' = '
        return self.format_line(lhs, operator, rhs, comment=comment)

    def visit_MaskedStatement(self, o, **kwargs):
        """
        Render a masked assignment, either as the single-line statement::

          WHERE (<condition>) <assignment>

        or as the construct::

          WHERE (<condition>)
            ...body...
          [ELSEWHERE (<condition>)
            ...body...]
          [ELSEWHERE
            ...body...]
          END WHERE
        """
        if o.inline:
            condition = self.visit(o.conditions[0], **kwargs)
            statement = self.visit(o.bodies[0][0], **kwargs).strip()
            return self.format_line('WHERE (', condition, ') ', statement)

        guards = [self.format_line('WHERE (', self.visit(o.conditions[0], **kwargs), ')')]
        guards.extend(
            self.format_line('ELSEWHERE (', self.visit(cond, **kwargs), ')') for cond in o.conditions[1:]
        )
        if o.default:
            guards.append(self.format_line('ELSEWHERE'))
        footer = self.format_line('END WHERE')

        self.depth += self.style.indent_default
        bodies = self.visit_all(*o.bodies, o.default, **kwargs)
        self.depth -= self.style.indent_default

        # Interleave guards with bodies; an unused default body is dropped by zip
        interleaved = []
        for guard, body in zip(guards, bodies):
            interleaved.extend((guard, body))
        return self.join_lines(*interleaved, footer)

    def visit_Forall(self, o, **kwargs):
        """
        Render a FORALL in one of two forms:

          1) Single-line FORALL statement (inlined):
            FORALL (<var> = <bounds>[, <var> = <bounds>] ... [, <mask>]) <assignment>
          2) Multi-line FORALL construct:
            [<name>:] FORALL (<var> = <bounds>[, <var> = <bounds>] ... [, <mask>])
                ...body...
            END FORALL [<name>]

        The variable bounds, plus the optional mask condition, make up the
        triplet specification list.
        """
        specs = []
        for variable, bound in o.named_bounds:
            specs.append(f"{self.visit(variable, **kwargs)} = {self.visit(bound, **kwargs)}")
        if o.mask is not None:
            specs.append(self.visit(o.mask, **kwargs))
        label = "" if o.name is None else f"{o.name}: "
        header = self.format_line(label, "FORALL(", ", ".join(specs), ")")

        if o.inline:
            # Single-line statement: the header followed by its one assignment
            return f"{header} {self.visit(o.body[0], **kwargs).lstrip()}"

        suffix = "" if o.name is None else f" {o.name}"
        footer = self.format_line('END FORALL', suffix)
        self.depth += self.style.indent_default
        body = self.visit(o.body, **kwargs)
        self.depth -= self.style.indent_default
        return self.join_lines(header, body, footer)

    def visit_Section(self, o, **kwargs):
        """
        Render a section by rendering its body; the section itself adds no text.
        """
        contents = o.body
        return self.visit(contents, **kwargs)

    def visit_Associate(self, o, **kwargs):
        """
        Render an ASSOCIATE construct as
          ASSOCIATE (<name>=><selector>[, <name>=><selector>...])
            ...body...
          END ASSOCIATE

        Each entry of ``o.associations`` is indexed as ``(selector, name)``.
        The body is indented by ``style.associate_indent``.
        """
        rendered = []
        for assoc in o.associations:
            alias = self.visit(assoc[1], **kwargs)
            selector = self.visit(assoc[0], **kwargs)
            rendered.append(f'{alias}=>{selector}')
        header = self.format_line('ASSOCIATE (', self.join_items(rendered), ')')
        footer = self.format_line('END ASSOCIATE')
        shift = self.style.associate_indent
        self.depth += shift
        body = self.visit(o.body, **kwargs)
        self.depth -= shift
        return self.join_lines(header, body, footer)

    def visit_CallStatement(self, o, **kwargs):
        """
        Render a subroutine call, preceded by its pragma (if any), as
          CALL <name>(<arg>[, <arg>...][, <kw>=<value>...])

        Positional arguments come first, followed by ``o.kwarguments`` pairs
        rendered as ``<kw>=<value>``.
        """
        pragma = self.visit(o.pragma, **kwargs)
        callee = self.visit(o.name, **kwargs)
        positional = self.visit_all(o.arguments, **kwargs)
        keyword = tuple(
            f'{self.visit(pair[0], **kwargs)}={self.visit(pair[1], **kwargs)}'
            for pair in (o.kwarguments or ())
        )
        call = self.format_line('CALL ', callee, '(', self.join_items(positional + keyword), ')')
        return self.join_lines(pragma, call)

    def visit_Allocation(self, o, **kwargs):
        """
        Render an allocation statement as
          ALLOCATE (<var>[, <var>...][, SOURCE=<source>][, STAT=<stat>])
        """
        entries = list(self.visit_all(o.variables, **kwargs))
        if o.data_source is not None:
            entries.append(f'SOURCE={self.visit(o.data_source, **kwargs)}')
        if o.status_var is not None:
            entries.append(f'STAT={self.visit(o.status_var, **kwargs)}')
        return self.format_line('ALLOCATE (', self.join_items(entries), ')')

    def visit_Deallocation(self, o, **kwargs):
        """
        Render a deallocation statement as
          DEALLOCATE (<var>[, <var>...][, STAT=<stat>])
        """
        entries = list(self.visit_all(o.variables, **kwargs))
        if o.status_var is not None:
            entries.append(f'STAT={self.visit(o.status_var, **kwargs)}')
        return self.format_line('DEALLOCATE (', self.join_items(entries), ')')

    def visit_Nullify(self, o, **kwargs):
        """
        Render a pointer nullification as
          NULLIFY (<pointer>[, <pointer>...])
        """
        pointers = self.join_items(self.visit_all(o.variables, **kwargs))
        return self.format_line('NULLIFY (', pointers, ')')

    def visit_SymbolAttributes(self, o, **kwargs):
        """
        Render the attribute list of a declaration as
          [<typename>[(LEN=<len>, KIND=<kind>)]][, <attribute>...]

        Procedure types contribute no type name. Derived types render as
        ``TYPE(<name>)``, or as ``CLASS(<name>)`` when polymorphic. Attributes
        are emitted in a fixed order: plain attributes, INTENT, access
        specifiers, and finally type-bound procedure binding attributes.
        """
        # Leading type specification
        if isinstance(o.dtype, ProcedureType):
            typename = ''
        elif isinstance(o.dtype, DerivedType):
            type_kw = 'CLASS' if o.polymorphic else 'TYPE'
            typename = f'{type_kw}({o.dtype.name})'
        else:
            typename = self.visit(o.dtype)

        selector = []
        if o.length:
            selector.append(f'LEN={self.visit(o.length, **kwargs)}')
        if o.kind:
            selector.append(f'KIND={self.visit(o.kind, **kwargs)}')
        if selector:
            typename += '(' + self.join_items(selector) + ')'

        attributes = [typename] if typename else []

        # Plain flag attributes, in canonical output order
        for enabled, label in (
                (o.external, 'EXTERNAL'), (o.save, 'SAVE'),
                (o.allocatable, 'ALLOCATABLE'), (o.pointer, 'POINTER'),
                (o.value, 'VALUE'), (o.optional, 'OPTIONAL'),
                (o.parameter, 'PARAMETER'), (o.target, 'TARGET'),
                (o.contiguous, 'CONTIGUOUS'),
        ):
            if enabled:
                attributes.append(label)
        if o.intent:
            attributes.append(f'INTENT({o.intent.upper()})')

        # Access specifiers
        for enabled, label in ((o.private, 'PRIVATE'), (o.public, 'PUBLIC'), (o.protected, 'PROTECTED')):
            if enabled:
                attributes.append(label)

        # Binding attributes: ``pass_attr`` is True, False, an argument name, or None
        if o.pass_attr is True:
            attributes.append('PASS')
        elif o.pass_attr is False:
            attributes.append('NOPASS')
        elif o.pass_attr is not None:
            attributes.append(f'PASS({o.pass_attr!s})')
        for enabled, label in ((o.non_overridable, 'NON_OVERRIDABLE'), (o.deferred, 'DEFERRED')):
            if enabled:
                attributes.append(label)

        return self.join_items(attributes)

    def visit_TypeDef(self, o, **kwargs):
        """
        Render a derived type definition as
          TYPE[, <attr>[, <attr>...] ::] <name>
            ...body...
          END TYPE <name>

        Recognised attributes, in output order: ABSTRACT, EXTENDS(<parent>),
        BIND(C), PRIVATE, PUBLIC. Without any attribute the ``::`` is omitted.
        """
        candidates = (
            (o.abstract, 'ABSTRACT'),
            (o.extends, f'EXTENDS({o.extends})'),
            (o.bind_c, 'BIND(C)'),
            (o.private, 'PRIVATE'),
            (o.public, 'PUBLIC'),
        )
        specs = [text for enabled, text in candidates if enabled]
        attr_str = f', {self.join_items(specs)} ::' if specs else ''

        header = self.format_line('TYPE', attr_str, ' ', o.name)
        footer = self.format_line('END TYPE ', o.name)
        step = self.style.indent_default
        self.depth += step
        body = self.visit(o.body, **kwargs)
        self.depth -= step
        return self.join_lines(header, body, footer)

    def visit_BasicType(self, o, **kwargs):
        """
        Return the Fortran keyword for an intrinsic type.

        ``BasicType.DEFERRED`` maps to an empty string. Any other value raises ``KeyError``.
        """
        keywords = {
            BasicType.DEFERRED: '',
            BasicType.LOGICAL: 'LOGICAL',
            BasicType.INTEGER: 'INTEGER',
            BasicType.REAL: 'REAL',
            BasicType.CHARACTER: 'CHARACTER',
            BasicType.COMPLEX: 'COMPLEX',
        }
        return keywords[o]

    def visit_DerivedType(self, o, **kwargs):
        """Return the bare name of the derived type."""
        type_name = o.name
        return type_name

    def visit_ProcedureType(self, o, **kwargs):
        """Return the bare name of the procedure type."""
        proc_name = o.name
        return proc_name

    def visit_Enumeration(self, o, **kwargs):
        """
        Render an enumeration as
          ENUM, BIND(C)
            ENUMERATOR :: <name>[ = <value>]
            ...
          END ENUM

        The value comes from each symbol's ``type.initial`` and is omitted when that is ``None``.
        """
        header = self.format_line('ENUM, BIND(C)')
        footer = self.format_line('END ENUM')
        step = self.style.indent_default
        self.depth += step
        enumerators = []
        for symbol in o.symbols:
            label = self.visit(symbol, **kwargs)
            value = symbol.type.initial
            init_str = '' if value is None else f' = {self.visit(value, **kwargs)}'
            enumerators.append(self.format_line('ENUMERATOR :: ', label, init_str))
        self.depth -= step
        return self.join_lines(header, *enumerators, footer)


def fgen(ir, style=None, depth=0, conservative=False):
    """
    Generate standardized Fortran code from one or many IR objects/trees.

    :param ir: IR node, tree or sequence of nodes to render
    :param style: output style; a new :class:`FortranStyle` is used when falsy
    :param depth: initial indentation depth
    :param conservative: render with :class:`FortranCodegenConservative`
        instead of :class:`FortranCodegen`
    :returns: the generated source, or ``''`` if the visitor produced nothing
    """
    if not style:
        style = FortranStyle()

    if conservative:
        # pylint: disable=import-outside-toplevel,cyclic-import
        from loki.backend.fgencon import FortranCodegenConservative
        codegen_cls = FortranCodegenConservative
    else:
        codegen_cls = FortranCodegen

    return codegen_cls(style=style, depth=depth).visit(ir) or ''


"""
Expose the expression generator for testing purposes.
"""
fexprgen = FCodeMapper()
loki-ecmwf-0.3.6/loki/backend/cgen.py0000664000175000017500000005263115167130205017563 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from operator import gt
from pymbolic.mapper.stringifier import (
    PREC_UNARY, PREC_LOGICAL_OR, PREC_LOGICAL_AND, PREC_NONE, PREC_CALL
)

from loki.backend.pprint import Stringifier
from loki.backend.style import DefaultStyle

from loki.expression import (
    symbols as sym, LokiStringifyMapper, Array, symbolic_op, Literal
)
from loki.ir import Import, FindNodes, FindVariables, FindRealLiterals
from loki.logging import warning
from loki.tools import as_tuple
from loki.types import BasicType, SymbolAttributes, DerivedType

__all__ = ['cgen', 'CCodegen', 'CCodeMapper', 'IntrinsicTypeC']


class IntrinsicTypeC:
    """
    Map the intrinsic type of a Fortran symbol to the spelling of the
    corresponding C type.

    Instances are callable; calling one is equivalent to calling
    :meth:`get_str_from_symbol_attr`.
    """

    # pylint: disable=unused-argument

    def __call__(self, _type, *args, **kwargs):
        return self.get_str_from_symbol_attr(_type, *args, **kwargs)

    def get_str_from_symbol_attr(self, _type, *args, **kwargs):
        """
        Return the C type name for the symbol attributes ``_type``.

        ``LOGICAL`` and ``INTEGER`` map to ``int``. ``REAL`` maps to ``float``
        when its kind is ``real32``, and to ``double`` otherwise.

        :raises ValueError: for any other type
        """
        dtype = _type.dtype
        if dtype in (BasicType.LOGICAL, BasicType.INTEGER):
            return 'int'
        if dtype == BasicType.REAL:
            return 'float' if str(_type.kind) == 'real32' else 'double'
        raise ValueError(str(_type))


c_intrinsic_type = IntrinsicTypeC()

class CCodeMapper(LokiStringifyMapper):
    """
    A :class:`StringifyMapper`-derived visitor for Pymbolic expression trees that converts an
    expression to a string adhering to standardized C.

    :param intrinsic_type_mapper: callable mapping :class:`SymbolAttributes` to the name
        of a C type; used to render kinds of literals and casts
    """

    # pylint: disable=abstract-method, unused-argument

    # Fortran allows ``d``/``D`` as the exponent letter of real literals (e.g. ``1.0d0``),
    # which is not a valid C floating constant; C only accepts ``e``/``E``.
    _fortran_exponent_to_c = str.maketrans('dD', 'eE')

    def __init__(self, intrinsic_type_mapper, *args, **kwargs):
        super().__init__()
        self.intrinsic_type_mapper = intrinsic_type_mapper

    def map_logic_literal(self, expr, enclosing_prec, *args, **kwargs):
        # Lower-case the generic rendering of the logical literal
        return super().map_logic_literal(expr, enclosing_prec, *args, **kwargs).lower()

    def map_float_literal(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a real literal, prefixed with a C cast when a kind is given.
        Fortran ``d`` exponents are rewritten to ``e``.
        """
        value = str(expr.value).translate(self._fortran_exponent_to_c)
        if expr.kind is not None:
            _type = SymbolAttributes(BasicType.REAL, kind=expr.kind)
            return f'({self.intrinsic_type_mapper(_type)}) {value}'
        return value

    def map_int_literal(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render an integer literal, prefixed with a C cast when a kind is given.
        """
        if expr.kind is not None:
            _type = SymbolAttributes(BasicType.INTEGER, kind=expr.kind)
            return f'({self.intrinsic_type_mapper(_type)}) {str(expr.value)}'
        return str(expr.value)

    def map_string_literal(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a character literal as a double-quoted C string literal.
        """
        # An embedded double quote would otherwise terminate the C string early
        value = str(expr.value).replace('"', '\\"')
        return f'"{value}"'

    def map_cast(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a type conversion intrinsic as a C cast ``(<type>) (<expr>)``.
        """
        _type = SymbolAttributes(BasicType.from_fortran_type(expr.name), kind=expr.kind)
        expression = self.parenthesize_if_needed(
            self.join_rec('', expr.parameters, PREC_NONE, *args, **kwargs),
            PREC_CALL, PREC_NONE)
        return self.parenthesize_if_needed(
            self.format('(%s) %s', self.intrinsic_type_mapper(_type), expression), enclosing_prec, PREC_CALL)

    def map_variable_symbol(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render a variable; derived-type members use C ``parent.member`` syntax.
        """
        if expr.parent is not None:
            parent = self.rec(expr.parent, PREC_NONE, *args, **kwargs)
            return self.format('%s.%s', parent, expr.basename)
        return self.format('%s', expr.name)

    def map_meta_symbol(self, expr, enclosing_prec, *args, **kwargs):
        return self.rec(expr._symbol, enclosing_prec, *args, **kwargs)

    map_scalar = map_meta_symbol
    map_array = map_meta_symbol

    def map_array_subscript(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render an array subscript as ``name[i][j]...``.

        Empty index renderings are skipped. A leading ``*`` on pointer
        variables is dropped, because subscripting already dereferences.
        """
        name_str = self.rec(expr.aggregate, PREC_NONE, *args, **kwargs)
        if expr.aggregate.type is not None:
            if expr.aggregate.type.pointer and name_str.startswith('*'):
                # Strip the pointer '*' because subscript dereference
                name_str = name_str[1:]
            index_str = ''
            for index in expr.index_tuple:
                # Use the rendered index verbatim: routing it through ``self.format`` would
                # apply %-formatting and raise on indices containing '%' (e.g. the integer
                # MOD rendering ``(a)%(b)`` produced by ``map_inline_call``)
                d = self.rec(index, PREC_NONE, *args, **kwargs)
                if d:
                    index_str += self.format('[%s]', d)
            return self.format('%s%s', name_str, index_str)
        return self.format('%s', name_str)

    map_string_subscript = map_array_subscript

    def map_logical_not(self, expr, enclosing_prec, *args, **kwargs):
        return self.parenthesize_if_needed(
            "!" + self.rec(expr.child, PREC_UNARY, *args, **kwargs),
            enclosing_prec, PREC_UNARY)

    def map_logical_or(self, expr, enclosing_prec, *args, **kwargs):
        return self.parenthesize_if_needed(
            self.join_rec(" || ", expr.children, PREC_LOGICAL_OR, *args, **kwargs),
            enclosing_prec, PREC_LOGICAL_OR)

    def map_logical_and(self, expr, enclosing_prec, *args, **kwargs):
        return self.parenthesize_if_needed(
            self.join_rec(" && ", expr.children, PREC_LOGICAL_AND, *args, **kwargs),
            enclosing_prec, PREC_LOGICAL_AND)

    def map_range_index(self, expr, enclosing_prec, *args, **kwargs):
        # Only the upper bound is rendered; an absent upper bound renders as ''
        return self.rec(expr.upper, enclosing_prec, *args, **kwargs) if expr.upper else ''

    def map_power(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render exponentiation via ``pow(base, exponent)`` from ``math.h``.
        """
        return self.parenthesize_if_needed(
            self.format('pow(%s, %s)', self.rec(expr.base, PREC_NONE, *args, **kwargs),
                        self.rec(expr.exponent, PREC_NONE, *args, **kwargs)),
            enclosing_prec, PREC_NONE)

    def map_c_reference(self, expr, enclosing_prec, *args, **kwargs):
        return self.format(' (&%s)', self.rec(expr.expression, PREC_NONE, *args, **kwargs))

    def map_c_dereference(self, expr, enclosing_prec, *args, **kwargs):
        return self.format(' (*%s)', self.rec(expr.expression, PREC_NONE, *args, **kwargs))

    def map_inline_call(self, expr, enclosing_prec, *args, **kwargs):
        """
        Render inline calls, with special handling for ``MOD`` and ``PRESENT``.
        """
        if expr.function.name.lower() == 'mod':
            parameters = [self.rec(param, PREC_NONE, *args, **kwargs) for param in expr.parameters]
            # TODO: this check is not quite correct, as it should evaluate the
            #  expression(s) of both arguments/parameters and choose the integer version of modulo ('%')
            #  instead of the floating-point version ('fmod')
            #  whenever the mentioned evaluations result in being of kind 'integer' ...
            #  as an example: 'ceiling(3.1415)' contains a floating-point value, however it evaluates to
            #  an integer; in that case the wrong modulo function/operation is chosen
            if any(var.type.dtype != BasicType.INTEGER for var in FindVariables().visit(expr.parameters)) or\
                    FindRealLiterals().visit(expr.parameters):
                return f'fmod({parameters[0]}, {parameters[1]})'
            return f'({parameters[0]})%({parameters[1]})'

        if expr.function.name.lower() == 'present':
            return self.format('true /*ATTENTION: present({%s})*/', expr.parameters[0].name)

        return super().map_inline_call(expr, enclosing_prec, *args, **kwargs)


class CCodegen(Stringifier):
    """
    Tree visitor to generate standardized C code from IR.

    Subroutines are emitted through a set of ``_subroutine_*`` hook methods
    (header, argument keyword/pass-by/optional handling, declaration, body and
    footer), so derived generators can override individual parts.
    """
    # pylint: disable=abstract-method, unused-argument

    standard_imports = ['stdio.h', 'stdbool.h', 'float.h', 'math.h']

    def __init__(self, style, depth=0, **kwargs):
        symgen = kwargs.get('symgen', CCodeMapper(c_intrinsic_type))
        line_cont = kwargs.get('line_cont', '\n{}  '.format)
        super().__init__(style=style, depth=depth, line_cont=line_cont, symgen=symgen)

    # Handler for outer objects

    def visit_Sourcefile(self, o, **kwargs):
        """
        Format as
          ...modules...
          ...subroutines...
        """
        return self.visit(o.ir, **kwargs)

    def visit_Module(self, o, **kwargs):
        """
        Format a module as its spec followed by its routines.
        """
        # Assuming this will be put in header files...
        spec = self.visit(o.spec, **kwargs)
        routines = self.visit(o.routines, **kwargs)
        return self.join_lines(spec, routines)

    def _subroutine_header(self, o, **kwargs):
        """
        Helper function/header for :func:`~loki.backend.CCodegen.visit_Subroutine`.

        Returns a list of lines: one ``#include <...>`` per entry of
        :attr:`standard_imports`, followed by the rendered imports of the spec.
        """
        # Some boilerplate imports...
        header = [self.format_line('#include <', name, '>') for name in self.standard_imports]
        # ...and imports from the spec
        spec_imports = FindNodes(Import).visit(o.spec)
        header += [self.visit(spec_imports, **kwargs)]
        return header

    def _subroutine_argument_keyword(self, a):
        """
        Return the qualifier placed before an argument's type; none in plain C.
        """
        return ''

    def _subroutine_argument_pass_by(self, a):
        """
        Return the declarator prefix that decides how argument ``a`` is passed.

        Arrays become ``restrict`` pointers. Derived-type, pointer and optional
        arguments become plain pointers. Everything else is passed by value.
        """
        if isinstance(a, Array):
            return '* restrict '
        if isinstance(a.type.dtype, DerivedType):
            return '*'
        if a.type.pointer:
            return '*'
        if a.type.optional:
            return '*'
        return ''

    def _subroutine_optional_args(self, a):
        """
        Warn about optional arguments, which are not supported; contributes no text.
        """
        if a.type.optional:
            warning(f'Argument "{a}" is optional! No support for optional arguments in {self.__class__.__name__}.')
        return ''

    def _subroutine_declaration(self, o, **kwargs):
        """
        Helper function/function declaration part for :func:`~loki.backend.CCodegen.visit_Subroutine`.

        The declaration ends in ``{`` for a definition, or in ``;`` when the
        ``header`` keyword option is set.
        """
        arguments = [
            (f'{self._subroutine_argument_keyword(a)}{self.visit(a.type, **kwargs)} '
            f'{self._subroutine_argument_pass_by(a)}{a.name}{self._subroutine_optional_args(a)}')
            for a in o.arguments
        ]
        opt_header = kwargs.get('header', False)
        end = ' {' if not opt_header else ';'
        # check whether to return something and define function return type accordingly
        if o.is_function:
            return_type = self.symgen.intrinsic_type_mapper(o.return_type)
        else:
            return_type = 'void'
        declaration = [self.format_line(f'{return_type} ', o.name, '(', self.join_items(arguments), ')', end)]
        return declaration

    def _subroutine_body(self, o, **kwargs):
        """
        Helper function/body for :func:`~loki.backend.CCodegen.visit_Subroutine`.
        """
        self.depth += 1

        # ...and generate the spec without imports and argument declarations
        body = [self.visit(o.spec, skip_imports=True, skip_argument_declarations=True, **kwargs)]

        # Fill the body
        body += [self.visit(o.body, **kwargs)]

        # if something to be returned, add 'return <var>' statement
        if o.is_function and o.result_name is not None:
            body += [self.format_line(f'return {o.result_name.lower()};')]

        # Close everything off
        self.depth -= 1
        return body

    def _subroutine_footer(self, o, **kwargs):
        """
        Helper function/footer for :func:`~loki.backend.CCodegen.visit_Subroutine`.
        """
        footer = [self.format_line('}')]
        return footer

    def visit_Interface(self, o, **kwargs):
        # Interface blocks produce no C output
        return None

    def visit_Subroutine(self, o, **kwargs):
        """
        Format as:

          ...imports...
          <return_type> <name>(<args>) {
            ...spec without imports and argument declarations...
            ...body...
          }

        Keyword options: ``header`` emits only the declaration, terminated
        by ``;``. ``guards`` wraps the output in ``#ifndef``/``#define``
        guards. ``guard_name`` sets the guard macro, which defaults to ``<NAME>_H``.
        """
        opt_header = kwargs.get('header', False)
        opt_guards = kwargs.get('guards', False)
        opt_guard_name = kwargs.get('guard_name', None)

        header = self._subroutine_header(o, **kwargs)
        declaration = self._subroutine_declaration(o, **kwargs)
        body = self._subroutine_body(o, **kwargs) if not opt_header else []
        footer = self._subroutine_footer(o, **kwargs) if not opt_header else []

        if opt_guards:
            guard_name = f'{o.name.upper()}_H' if opt_guard_name is None else opt_guard_name
            header = [self.format_line(f'#ifndef {guard_name}'), self.format_line(f'#define {guard_name}\n\n')] + header
            footer += ['\n#endif']

        return self.join_lines(*header, '\n', *declaration, *body, *footer)

    visit_Function = visit_Subroutine

    # Handler for AST base nodes

    def visit_Node(self, o, **kwargs):
        """
        Format non-supported nodes as
          // <repr(node)>
        """
        return self.format_line('// <', repr(o), '>')

    # Handler for IR nodes

    def visit_Intrinsic(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Format intrinsic nodes.
        """
        return self.format_line(str(o.text).lstrip())

    def visit_Comment(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Format comments, turning the first Fortran ``!`` into C ``//``.
        """
        text = o.text
        if text is None and o.source:
            text = o.source.string
        if text is None:
            # Nothing to render; avoid emitting the literal string 'None'
            text = ''
        text = str(text).lstrip().replace('!', '//', 1)
        return self.format_line(text, no_wrap=True)

    def visit_CommentBlock(self, o, **kwargs):
        """
        Format comment blocks.
        """
        comments = self.visit_all(o.comments, **kwargs)
        return self.join_lines(*comments)

    def visit_VariableDeclaration(self, o, **kwargs):
        """
        Format declaration as
          <type> <name>[ = <initial>][, <name>[ = <initial>]...];

        Pointer and allocatable variables are declared as C pointers. With the
        ``skip_argument_declarations`` option, variables that have an intent
        are skipped. ``None`` is returned if nothing remains.
        """
        # Check for an empty declaration before indexing the symbol list
        assert len(o.symbols) > 0
        types = [v.type for v in o.symbols]
        # Ensure all variable types are equal, except for shape and dimension
        ignore = ['shape', 'dimensions', 'source']
        assert all(t.compare(types[0], ignore=ignore) for t in types)
        dtype = self.visit(types[0], **kwargs)
        variables = []
        for v in o.symbols:
            if kwargs.get('skip_argument_declarations') and v.type.intent:
                continue
            var = self.visit(v, **kwargs)
            initial = ''
            if v.initial is not None:
                initial = f' = {self.visit(v.initial, **kwargs)}'
            if v.type.pointer or v.type.allocatable:
                var = '*' + var
            variables += [f'{var}{initial}']
        if not variables:
            return None
        comment = None
        if o.comment:
            comment = str(self.visit(o.comment, **kwargs))
        return self.format_line(dtype, ' ', self.join_items(variables), ';', comment=comment)

    def visit_Import(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Format C imports as
          #include "<module>"

        Non-C imports, and all imports when ``skip_imports`` is set, produce ``None``.
        """
        if not kwargs.get('skip_imports') and o.c_import:
            return self.format_line('#include "', str(o.module), '"')
        return None

    def visit_Loop(self, o, **kwargs):
        """
        Format loop with explicit range as
          for (<var>=<start>; <var> <crit> <end>; <var> += <incr>) {
            ...body...
          }

        ``<crit>`` is ``<=`` unless the step is present and not symbolically
        positive, in which case it is ``>=``. The increment defaults to 1.
        """
        control = 'for ({var} = {start}; {var} {crit} {end}; {var} += {incr})'.format(
            var=self.visit(o.variable, **kwargs), start=self.visit(o.bounds.start, **kwargs),
            end=self.visit(o.bounds.stop, **kwargs),
            crit='<=' if not o.bounds.step or symbolic_op(o.bounds.step, gt, Literal(0)) else '>=',
            incr=self.visit(o.bounds.step, **kwargs) if o.bounds.step else 1)
        header = self.format_line(control, ' {')
        footer = self.format_line('}')
        self.depth += 1
        body = self.visit(o.body, **kwargs)
        self.depth -= 1
        return self.join_lines(header, body, footer)

    def visit_WhileLoop(self, o, **kwargs):
        """
        Format loop as
          while (<condition>) {
            ...body...
          }

        A missing condition is rendered as ``1`` (an endless loop).
        """
        if o.condition is not None:
            condition = self.visit(o.condition, **kwargs)
        else:
            condition = '1'
        header = self.format_line('while (', condition, ') {')
        footer = self.format_line('}')
        self.depth += 1
        body = self.visit(o.body, **kwargs)
        self.depth -= 1
        return self.join_lines(header, body, footer)

    def visit_Conditional(self, o, **kwargs):
        """
        Format conditional as
          if (<condition>) {
            ...body...
          [ } else if (<condition>) { ]
            [...body...]
          [ } else { ]
            [...body...]
          }
        """
        is_elseif = kwargs.pop('is_elseif', False)
        if is_elseif:
            header = self.format_line('} else if (', self.visit(o.condition, **kwargs), ') {')
        else:
            header = self.format_line('if (', self.visit(o.condition, **kwargs), ') {')
        self.depth += 1
        body = self.visit(o.body, **kwargs)
        if o.has_elseif:
            # The nested conditional renders its own '} else if' line and the closing brace
            self.depth -= 1
            else_body = [self.visit(o.else_body, is_elseif=True, **kwargs)]
        else:
            else_body = [self.visit(o.else_body, **kwargs)]
            self.depth -= 1
            if o.else_body:
                else_body = [self.format_line('} else {')] + else_body
            else_body += [self.format_line('}')]
        return self.join_lines(header, body, *else_body)

    def visit_Assignment(self, o, **kwargs):
        """
        Format statement as
          <lhs> = <rhs>; [<comment>]
        """
        lhs = self.visit(o.lhs, **kwargs)
        rhs = self.visit(o.rhs, **kwargs)
        comment = None
        if o.comment:
            comment = f'  {self.visit(o.comment, **kwargs)}'
        return self.format_line(lhs, ' = ', rhs, ';', comment=comment)

    def visit_Section(self, o, **kwargs):
        """
        Format the section's body.
        """
        return self.visit(o.body, **kwargs)

    def visit_CallStatement(self, o, **kwargs):
        """
        Format call statement as
          <name>(<args>);

        Keyword arguments are not supported and trigger an assertion.
        """
        args = self.visit_all(o.arguments, **kwargs)
        assert not o.kwarguments
        return self.format_line(str(o.name), '(', self.join_items(args), ');')

    def visit_SymbolAttributes(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Render a type as ``struct <name>`` for derived types, or via the intrinsic type mapper.
        """
        if isinstance(o.dtype, DerivedType):
            return f'struct {o.dtype.name}'
        return self.symgen.intrinsic_type_mapper(o)

    def visit_TypeDef(self, o, **kwargs):
        """
        Format type definition/struct as
          struct <name> {
            ...declarations...
          };
        """
        header = self.format_line('struct ', o.name.lower(), ' {')
        footer = self.format_line('};')
        self.depth += 1
        decls = self.visit(o.declarations, **kwargs)
        self.depth -= 1
        return self.join_lines(header, decls, footer)

    def visit_MultiConditional(self, o, **kwargs):
        """
        Format as
          switch (<expr>) {
          case <value>:
          {
            ...body...
            break;
          }
          [case <value>:]
          {
            [...body...]
            break;
          }
          [default:] {
            [...body...]
            break;
          }
          }

        Ranges are expanded into individual ``case`` labels. They must have
        integer-literal bounds, and an open lower bound is taken to be 0.
        E.g., the following

        select case (in)
            case (:2)
                out = 1
            case (4, 5, 7:9)
                out = 2
            case (6)
                out = 3
            case default
                out = 4
        end select

        becomes

        switch (in) {
            case 0:
            case 1:
            case 2:
            {
              out = 1;
              break;
            }
            case 4:
            case 5:
            case 7:
            case 8:
            case 9:
            {
              out = 2;
              break;
            }
            case 6:
            {
              out = 3;
              break;
            }
            default:
            {
              out = 4;
              break;
            }
        }
        """
        header = self.format_line('switch (', self.visit(o.expr, **kwargs), ') {')
        cases = []
        end_cases = []
        for value in o.values:
            sub_cases = []
            for val in value:
                if not isinstance(val, sym.RangeIndex):
                    sub_cases.append(self.visit(val, **kwargs))
                else:
                    assert (val.lower is None or isinstance(val.lower, sym.IntLiteral))\
                            and isinstance(val.upper, sym.IntLiteral)
                    # NOTE: an open lower bound is assumed to start at 0
                    lower = val.lower.value if val.lower is not None else 0
                    sub_cases.extend([str(v) for v in list(range(lower, val.upper.value + 1))])
            case = ()
            for sub_case in sub_cases:
                case += (self.format_line('case ', self.join_items(as_tuple(sub_case)), ':'),)
            cases.append(self.join_lines(*case, self.format_line('{')))
            end_cases.append(self.join_lines(self.format_line('break;'), self.format_line('}')))
        if o.else_body:
            cases.append(self.join_lines(self.format_line('default: '), self.format_line('{')))
            end_cases.append(self.join_lines(self.format_line('break;'), self.format_line('}')))
        footer = self.format_line('}')
        self.depth += 1
        bodies = self.visit_all(*o.bodies, o.else_body, **kwargs)
        self.depth -= 1
        branches = [item for branch in zip(cases, bodies, end_cases) for item in branch]
        return self.join_lines(header, *branches, footer)


def cgen(ir, **kwargs):
    """
    Generate standardized C code from one or many IR objects/trees.

    :param ir: IR node, tree or sequence of nodes to render
    :param kwargs: ``style`` (default: a new :class:`DefaultStyle`) and
        ``depth`` (default: 0) configure :class:`CCodegen`. All remaining
        keyword arguments are forwarded to its ``visit`` call.
    """
    codegen = CCodegen(style=kwargs.pop('style', DefaultStyle()), depth=kwargs.pop('depth', 0))
    return codegen.visit(ir, **kwargs)
loki-ecmwf-0.3.6/loki/backend/cudagen.py0000664000175000017500000001225015167130205020246 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.backend.cppgen import CppCodegen, CppCodeMapper, IntrinsicTypeCpp
from loki.backend.style import DefaultStyle

from loki.ir import Import, FindNodes
from loki.expression import Array

__all__ = ['cudagen', 'CudaCodegen', 'CudaCodeMapper']


class IntrinsicTypeCuda(IntrinsicTypeCpp):
    """
    Map Fortran types to the corresponding CUDA type names.

    This class adds nothing to :class:`IntrinsicTypeCpp`; every mapping is inherited.
    """

cuda_intrinsic_type = IntrinsicTypeCuda()


class CudaCodeMapper(CppCodeMapper):
    """
    A :class:`StringifyMapper`-derived visitor for Pymbolic expression trees that converts an
    expression to a string adhering to standardized CUDA.

    All expression formatting is currently inherited from :class:`CppCodeMapper`.
    """
    # pylint: disable=abstract-method, unused-argument


class CudaCodegen(CppCodegen):
    """
    Tree visitor to generate standardized CUDA code from IR.

    On top of :class:`CppCodegen` this adds:

    * CUDA execution-space qualifiers (``__global__``/``__device__``)
      derived from the routine prefix,
    * ``__restrict__`` pointers for array arguments,
    * kernel launch configurations (``<<<...>>>``) on calls that carry a
      chevron, and
    * optional ``extern "C"`` wrapping (``extern=True`` keyword).
    """

    standard_imports = ['stdio.h', 'stdbool.h', 'float.h',
            'math.h', 'cuda.h', 'cuda_runtime.h']

    def __init__(self, depth=0, indent='  ', linewidth=90, **kwargs):
        symgen = kwargs.pop('symgen', CudaCodeMapper(cuda_intrinsic_type))
        super().__init__(depth=depth, indent=indent, linewidth=linewidth,
                         symgen=symgen, **kwargs)

    @staticmethod
    def _prefix_contains(o, keyword):
        """
        Return ``True`` if any entry of the routine prefix contains the
        lower-case ``keyword`` (compared case-insensitively).

        All prefix entries are inspected, so a qualifier such as a CUDA
        Fortran ``ATTRIBUTES(GLOBAL)`` is recognised even when it is not the
        first prefix (e.g. when it follows ``RECURSIVE`` or ``PURE``).
        """
        return any(keyword in str(entry).lower() for entry in (o.prefix or ()))

    def _subroutine_header(self, o, **kwargs):
        """
        Generate the ``#include`` lines that precede the routine.

        Implementation files get the standard C/CUDA headers plus the imports
        found in the spec. Header files (``header=True``) and ``extern``
        wrappers (``extern=True``) get none of these. Kernels (a ``global``
        prefix) additionally include their own ``.h`` and ``_launch.h`` files
        unless a header file is being generated.
        """
        opt_header = kwargs.get('header', False)
        opt_extern = kwargs.get('extern', False)
        if opt_header or opt_extern:
            header = []
        else:
            # Some boilerplate imports...
            header = [self.format_line('#include <', name, '>') for name in self.standard_imports]
            # ...and imports from the spec
            spec_imports = FindNodes(Import).visit(o.spec)
            header += [self.visit(spec_imports, **kwargs)]
        if self._prefix_contains(o, 'global'):
            # include launcher and header file
            header += [self.format_line('')]
            if not opt_header:
                header += [self.format_line('#include "', o.name, '.h', '"')]
                header += [self.format_line('#include "', o.name, '_launch.h', '"')]
        return header

    def _subroutine_argument_pass_by(self, a):
        """
        Pass arrays as non-aliasing pointers (``* __restrict__``); defer all
        other arguments to the C++ backend.
        """
        if isinstance(a, Array):
            return '* __restrict__ '
        return super()._subroutine_argument_pass_by(a)

    def _subroutine_declaration(self, o, **kwargs):
        """
        Generate the routine signature line, optionally wrapped in an opening
        ``extern "C" {`` block.

        The signature ends in ``;`` for header files (``header=True``) and in
        `` {`` otherwise.
        """
        arguments = [
            (f'{self._subroutine_argument_keyword(a)}{self.visit(a.type, **kwargs)} '
             f'{self._subroutine_argument_pass_by(a)}{a.name}')
            for a in o.arguments
        ]
        opt_header = kwargs.get('header', False)
        # Header files only get a prototype; implementation files open the body
        end = ';' if opt_header else ' {'
        # ``device`` takes precedence over ``global`` if both are present
        if self._prefix_contains(o, 'device'):
            prefix = '__device__ '
        elif self._prefix_contains(o, 'global'):
            prefix = '__global__ '
        else:
            prefix = ''
        if o.is_function:
            return_type = self.symgen.intrinsic_type_mapper(o.return_type)
        else:
            return_type = 'void'
        opt_extern = kwargs.get('extern', False)
        declaration = [self.format_line('extern "C" {\n')] if opt_extern else []
        declaration += [self.format_line(prefix, f'{return_type} ', o.name, '(', self.join_items(arguments), ')', end)]
        return declaration

    def _subroutine_body(self, o, **kwargs):
        """
        Generate the indented routine body: local spec (without imports and
        argument declarations), executable body, a ``cudaDeviceSynchronize()``
        in ``extern`` mode, and a ``return`` statement for functions.
        """
        self.depth += 1
        # ...and generate the spec without imports and argument declarations
        body = [self.visit(o.spec, skip_imports=True, skip_argument_declarations=True, **kwargs)]
        # Fill the body
        body += [self.visit(o.body, **kwargs)]
        opt_extern = kwargs.get('extern', False)
        if opt_extern:
            body += [self.format_line('cudaDeviceSynchronize();')]
        # if something to be returned, add 'return ' statement
        if o.is_function and o.result_name is not None:
            body += [self.format_line(f'return {o.result_name.lower()};')]
        # Close everything off
        self.depth -= 1
        return body

    def _subroutine_footer(self, o, **kwargs):
        """
        Close the routine body and, in ``extern`` mode, the ``extern "C"`` block.
        """
        opt_extern = kwargs.get('extern', False)
        footer = [self.format_line('}'), self.format_line('')]
        if opt_extern:
            footer += [self.format_line('\n} // extern')]
        return footer

    def visit_CallStatement(self, o, **kwargs):
        """
        Format a call statement, inserting the ``<<<...>>>`` launch
        configuration when the call carries a chevron.

        Raises
        ------
        RuntimeError
            If the call uses keyword arguments, which cannot be expressed in CUDA.
        """
        # Validate before doing any work on the arguments
        if o.kwarguments:
            raise RuntimeError(f'Keyword arguments in call to {o.name} not supported in CUDA code.')
        args = self.visit_all(o.arguments, **kwargs)
        chevron = f'<<<{",".join(str(elem) for elem in o.chevron)}>>>' if o.chevron is not None else ''
        return self.format_line(str(o.name), chevron, '(', self.join_items(args), ');')


def cudagen(ir, **kwargs):
    """
    Render one or many IR objects/trees as standardized CUDA code.

    The optional ``style`` (default: a fresh :class:`DefaultStyle`) and
    ``depth`` (default: ``0``) keyword arguments configure the
    :class:`CudaCodegen` instance; every other keyword argument (e.g.
    ``header`` or ``extern``) is forwarded to its ``visit`` call.
    """
    generator = CudaCodegen(
        style=kwargs.pop('style', DefaultStyle()),
        depth=kwargs.pop('depth', 0),
    )
    return generator.visit(ir, **kwargs)
loki-ecmwf-0.3.6/loki/subroutine.py0000664000175000017500000004114515167130205017455 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.expression import symbols as sym
from loki.frontend import (
    parse_omni_ast, parse_fparser_ast, get_fparser_node,
    parse_regex_source
)
from loki.ir import (
    nodes as ir, FindNodes, Transformer, ExpressionTransformer,
    pragmas_attached
)
from loki.logging import debug
from loki.program_unit import ProgramUnit
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import BasicType, ProcedureType, SymbolAttributes



__all__ = ['Subroutine']


class Subroutine(ProgramUnit):
    """
    Class to handle and manipulate a single subroutine.

    Parameters
    ----------
    name : str
        Name of the subroutine.
    args : iterable of str, optional
        The names of the dummy args. They are stored lower-cased and define
        the order of :attr:`arguments`.
    docstring : tuple of :any:`Node`, optional
        The subroutine docstring in the original source.
    spec : :any:`Section`, optional
        The spec of the subroutine.
    body : :any:`Section`, optional
        The body of the subroutine. Anything that is not already a
        :any:`Section` (e.g. a tuple of nodes) is wrapped in one.
    contains : :any:`Section`, optional
        The internal-subprogram part following a ``CONTAINS`` statement
        declaring member procedures
    prefix : iterable, optional
        Prefix specifications for the procedure
    bind : optional
        Bind information (e.g., for Fortran ``BIND(C)`` annotation).
    ast : optional
        Frontend node for this subroutine (from parse tree of the frontend).
    source : :any:`Source`
        Source object representing the raw source string information from the
        read file.
    parent : :any:`Scope`, optional
        The enclosing parent scope of the subroutine, typically a :any:`Module`
        or :any:`Subroutine` object. Declarations from the parent scope remain
        valid within the subroutine's scope (unless shadowed by local
        declarations).
    rescope_symbols : bool, optional
        Ensure that the type information for all :any:`TypedSymbol` in the
        subroutine's IR exist in the subroutine's scope or the scope's parents.
        Defaults to `False`.
    symbol_attrs : :any:`SymbolTable`, optional
        Symbol attributes to copy into this subroutine's symbol table. The
        entries are merged via ``update``; the provided object itself is not
        adopted as the subroutine's table.
    incomplete : bool, optional
        Mark the object as incomplete, i.e. only partially parsed. This is
        typically the case when it was instantiated using the :any:`Frontend.REGEX`
        frontend and a full parse using one of the other frontends is pending.
    parser_classes : :any:`RegexParserClass`, optional
        Provide the list of parser classes used during incomplete regex parsing
    """

    # A plain subroutine has no return value; code generators query this flag
    # (e.g. to choose between a ``void`` and a typed C/CUDA signature)
    is_function = False

    def __init__(self, *args, parent=None, symbol_attrs=None, **kwargs):
        """
        Set up the scope via :class:`ProgramUnit`, merge ``symbol_attrs`` into
        this subroutine's symbol table and delegate all remaining arguments to
        :meth:`__initialize__`.
        """
        super().__init__(parent=parent)

        if symbol_attrs:
            self.symbol_attrs.update(symbol_attrs)

        self.__initialize__(*args, **kwargs)

    def __initialize__(
            self, name, docstring=None, spec=None, contains=None,
            ast=None, source=None, rescope_symbols=False, incomplete=False,
            parser_classes=None, body=None, args=None, prefix=None, bind=None,
    ):
        """
        Store the Subroutine-specific properties (dummy argument names,
        prefix, bind, body) and then initialize the generic
        :class:`ProgramUnit` components. See the class docstring for the
        meaning of each parameter.
        """
        # First, store additional Subroutine-specific properties
        self._dummies = as_tuple(a.lower() for a in as_tuple(args))  # Order of dummy arguments
        self.prefix = as_tuple(prefix)
        self.bind = bind

        # Additional IR components
        if body is not None and not isinstance(body, ir.Section):
            body = ir.Section(body=body)
        self.body = body

        super().__initialize__(
            name=name, docstring=docstring, spec=spec, contains=contains,
            ast=ast, source=source, rescope_symbols=rescope_symbols,
            incomplete=incomplete, parser_classes=parser_classes
        )

    def __getstate__(self):
        """
        Return the pickling state, omitting the frontend AST and the parent scope.
        """
        _ignore = ('_ast', '_parent')
        return dict((k, v) for k, v in self.__dict__.items() if k not in _ignore)

    def __setstate__(self, s):
        """
        Restore from a pickled state: reset the AST, re-register member
        procedures in the symbol table and re-scope all symbols to ``self``.
        """
        self.__dict__.update(s)

        self._ast = None

        # Re-register all encapsulated member procedures
        for member in self.members:
            self.symbol_attrs[member.name] = SymbolAttributes(ProcedureType(procedure=member))

        # Ensure that we are attaching all symbols to the newly created ``self``.
        self.rescope_symbols()

    @classmethod
    def from_omni(cls, ast, raw_source, definitions=None, parent=None, type_map=None):
        """
        Create :any:`Subroutine` from :any:`OMNI` parse tree

        Parameters
        ----------
        ast :
            The OMNI parse tree. If it is not itself an ``FfunctionDefinition``
            node, the first ``globalDeclarations/FfunctionDefinition`` below it
            is used.
        raw_source : str
            Fortran source string
        definitions : list
            List of external :any:`Module` to provide derived-type and procedure declarations
        parent : :any:`Scope`, optional
            The enclosing parent scope of the subroutine, typically a :any:`Module`.
        type_map : dict, optional
            A mapping from type hash identifiers to type definitions, as provided in
            OMNI's ``typeTable`` parse tree node
        """
        type_map = type_map or {}
        if ast.tag != 'FfunctionDefinition':
            ast = ast.find('globalDeclarations/FfunctionDefinition')
        return parse_omni_ast(
            ast=ast, definitions=definitions, raw_source=raw_source,
            type_map=type_map, scope=parent
        )

    @classmethod
    def from_fparser(cls, ast, raw_source, definitions=None, pp_info=None, parent=None):
        """
        Create :any:`Subroutine` from :any:`FP` parse tree

        Parameters
        ----------
        ast :
            The FParser parse tree. If it is not itself a subroutine or
            function subprogram node, the matching node is looked up below it.
        raw_source : str
            Fortran source string
        definitions : list
            List of external :any:`Module` to provide derived-type and procedure declarations
        pp_info :
            Preprocessing info as obtained by :any:`sanitize_input`
        parent : :any:`Scope`, optional
            The enclosing parent scope of the subroutine, typically a :any:`Module`.
        """
        if ast.__class__.__name__ not in ('Subroutine_Subprogram', 'Function_Subprogram'):
            ast = get_fparser_node(ast, ('Subroutine_Subprogram', 'Function_Subprogram'))
        # Note that our Fparser interface returns a tuple with the
        # Subroutine object always last but potentially containing
        # comments before the Subroutine object
        return parse_fparser_ast(
            ast, pp_info=pp_info, definitions=definitions,
            raw_source=raw_source, scope=parent
        )[-1]

    @classmethod
    def from_regex(cls, raw_source, parser_classes=None, parent=None):
        """
        Create :any:`Subroutine` from source regex'ing

        Parameters
        ----------
        raw_source : str
            Fortran source string
        parser_classes : :any:`RegexParserClass`, optional
            The regex parser classes to use for the (incomplete) parse
        parent : :any:`Scope`, optional
            The enclosing parent scope of the subroutine, typically a :any:`Module`.

        Returns
        -------
        :any:`Subroutine`
            The first object of this class found in the parsed IR. An
            ``IndexError`` is raised if there is none.
        """
        ir_ = parse_regex_source(raw_source, parser_classes=parser_classes, scope=parent)
        return [node for node in ir_.body if isinstance(node, cls)][0]

    def register_in_parent_scope(self):
        """
        Insert the type information for this object in the parent's symbol table

        If :attr:`parent` is `None`, this does nothing.
        """
        if self.parent:
            self.parent.symbol_attrs[self.name] = SymbolAttributes(self.procedure_type)

    def clone(self, **kwargs):
        """
        Create a copy of the subroutine with the option to override individual
        parameters.

        Parameters
        ----------
        **kwargs :
            Any parameters from the constructor of :any:`Subroutine`.

        Returns
        -------
        :any:`Subroutine`
            The cloned subroutine object.
        """
        # Collect all properties bespoke to Subroutine
        if self.argnames and 'args' not in kwargs:
            kwargs['args'] = self.argnames
        if self.body and 'body' not in kwargs:
            kwargs['body'] = self.body
        if self.prefix and 'prefix' not in kwargs:
            kwargs['prefix'] = self.prefix
        if self.bind and 'bind' not in kwargs:
            kwargs['bind'] = self.bind

        # Rebuild body (other IR components are taken care of in super class)
        if 'body' in kwargs:
            kwargs['body'] = Transformer({}, rebuild_scopes=True).visit(kwargs['body'])

        # Escalate to parent class
        return super().clone(**kwargs)

    @property
    def _canonical(self):
        """
        Base definition for comparing :any:`Subroutine` objects.
        """
        return (
            self.name, self._dummies, self.prefix, self.bind,
            self.docstring, self.spec, self.body, self.contains,
            self.symbol_attrs
        )

    def __eq__(self, other):
        """
        Compare two subroutines by their :attr:`_canonical` tuple; other
        types fall back to the parent class comparison.
        """
        if isinstance(other, Subroutine):
            return self._canonical == other._canonical
        return super().__eq__(other)

    def __hash__(self):
        """
        Hash consistently with :meth:`__eq__` (based on :attr:`_canonical`).
        """
        return hash(self._canonical)

    @property
    def procedure_symbol(self):
        """
        Return the procedure symbol for this subroutine
        """
        return sym.Variable(name=self.name, type=SymbolAttributes(self.procedure_type), scope=self.parent)

    @property
    def procedure_type(self):
        """
        Return the :any:`ProcedureType` of this subroutine
        """
        return ProcedureType(procedure=self)

    # Re-expose the parent's property so that a Subroutine-specific setter can be attached
    variables = ProgramUnit.variables

    @variables.setter
    def variables(self, variables):
        """
        Set the variables property and ensure that the internal declarations match.

        Note that arguments also count as variables and therefore any
        removal from this list will also remove arguments from the subroutine signature.
        """
        # Use the parent's property setter
        ProgramUnit.variables.__set__(self, variables) # pylint: disable=unnecessary-dunder-call,no-member

        # Filter the dummy list in case we removed an argument
        varnames = [str(v.name).lower() for v in variables]
        self._dummies = as_tuple(arg for arg in self._dummies if str(arg).lower() in varnames)

    @property
    def arguments(self):
        """
        Return arguments in order of the defined signature (dummy list).

        Dummy arguments without an entry in :attr:`symbol_map` are returned
        as bare :any:`Variable` objects created from their name.
        """

        #Load symbol_map
        #Note that if the map is not loaded, Python will recreate it for every argument,
        #resulting in a large overhead.
        symbol_map = self.symbol_map
        return as_tuple(symbol_map.get(arg, sym.Variable(name=arg)) for arg in self._dummies)

    @arguments.setter
    def arguments(self, arguments):
        """
        Set the arguments property and ensure that internal declarations and signature match.

        Arguments without an existing declaration must carry an ``intent``
        and get a new declaration appended to the end of the spec.

        Note that removing arguments from this property does not actually remove declarations.
        """
        # FIXME: This will fail if one of the argument is declared via an interface!

        # First map variables to existing declarations
        declarations = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(self.spec)
        decl_map = dict((v, decl) for decl in declarations for v in decl.symbols)

        arguments = as_tuple(arguments)
        for arg in arguments:
            if arg not in decl_map:
                # By default, append new variables to the end of the spec
                assert arg.type.intent is not None
                # NOTE(review): elsewhere in this class the ProcedureType lives in
                # ``symbol.type.dtype``, so ``arg.type`` itself is presumably never a
                # ProcedureType and this branch may be unreachable (every new argument
                # then gets a VariableDeclaration). Confirm whether ``arg.type.dtype``
                # was intended before changing it.
                if isinstance(arg.type, ProcedureType):
                    new_decl = ir.ProcedureDeclaration(symbols=(arg, ))
                else:
                    new_decl = ir.VariableDeclaration(symbols=(arg, ))
                self.spec.append(new_decl)

        # Set new dummy list according to input
        self._dummies = as_tuple(arg.name.lower() for arg in arguments)

    @property
    def argnames(self):
        """
        Return names of arguments in order of the defined signature (dummy list)
        """
        return [a.name for a in self.arguments]

    # Member (internal) procedures are the subroutines listed under CONTAINS
    members = ProgramUnit.subroutines

    @property
    def ir(self):
        """
        All components of the intermediate representation in this subroutine
        """
        return (self.docstring, self.spec, self.body, self.contains)

    @property
    def interface(self):
        """
        Interface object that defines the `Subroutine` signature in header files.

        The returned :any:`Interface` wraps a new body-less :any:`Subroutine`
        whose spec keeps only the argument declarations (re-scoped to the new
        routine); local variable declarations are removed.
        """

        # Remove all local variable declarations from interface routine spec
        # and duplicate all argument symbols within a new subroutine scope
        arg_names = [arg.name for arg in self.arguments]
        routine = Subroutine(name=self.name, args=arg_names, spec=None, body=None)
        decl_map = {}
        for decl in FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(self.spec):
            if any(v.name in arg_names for v in decl.symbols):
                assert all(v.name in arg_names and v.type.intent is not None for v in decl.symbols), \
                    "Declarations must have intents and dummy and local arguments cannot be mixed."
                # Replicate declaration with re-scoped variables
                variables = as_tuple(v.clone(scope=routine) for v in decl.symbols)
                decl_map[decl] = decl.clone(symbols=variables)
            else:
                decl_map[decl] = None  # Remove local variable declarations
        routine.spec = Transformer(decl_map).visit(self.spec)
        return ir.Interface(body=(routine,))

    def enrich(self, definitions, recurse=False):
        """
        Apply :any:`ProgramUnit.enrich` and expand enrichment to calls declared
        via interfaces

        Parameters
        ----------
        definitions : list of :any:`ProgramUnit`
            A list of all available definitions
        recurse : bool, optional
            Enrich contained scopes
        """
        # First, enrich imported symbols
        super().enrich(definitions, recurse=recurse)

        # Secondly, take care of procedures that are declared via interface block includes
        # and therefore are not discovered via module imports
        definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions))
        with pragmas_attached(self, ir.CallStatement, attach_pragma_post=False):
            for call in FindNodes(ir.CallStatement).visit(self.body):

                # Clone symbol to ensure Deferred symbols are
                # recognised ProcedureSymbols
                symbol = call.name.clone()
                routine = definitions_map.get(symbol.name)

                if not routine and symbol.parent:
                    # Type-bound procedure: try to obtain procedure from typedef
                    if (dtype := symbol.parent.type.dtype) is not BasicType.DEFERRED:
                        if (typedef := dtype.typedef) is not BasicType.DEFERRED:
                            if proc_symbol := typedef.variable_map.get(symbol.name_parts[-1]):
                                if (dtype := proc_symbol.type.dtype) is not BasicType.DEFERRED:
                                    if dtype.procedure is not BasicType.DEFERRED:
                                        routine = dtype.procedure

                is_not_enriched = (
                    symbol.scope is None or                         # No scope attached
                    symbol.type.dtype is BasicType.DEFERRED or      # Wrong datatype
                    symbol.type.dtype.procedure is not routine      # ProcedureType not linked
                )

                # Always update the call symbol to ensure it is up-to-date
                call._update(name=symbol)

                # Skip already enriched symbols and routines without definitions
                # (the debug message is emitted in both cases)
                if not (routine and is_not_enriched):
                    debug('Cannot enrich call to %s', symbol)
                    continue

                # Remove existing symbol from symbol table if defined in interface block
                for node in [node for intf in self.interfaces for node in intf.body]:
                    if getattr(node, 'name', None) == symbol:
                        if node.parent == self:
                            node.parent = None

                # Need to update the call's symbol to establish link to routine
                symbol = symbol.clone(scope=self, type=symbol.type.clone(dtype=routine.procedure_type))
                call._update(name=symbol)

        # Rebuild local symbols to ensure correct symbol types
        self.body = ExpressionTransformer(inplace=True).visit(self.body)

    def __repr__(self):
        """ String representation """
        return f'Subroutine:: {self.name}'
loki-ecmwf-0.3.6/loki/logging.py0000664000175000017500000001235015167130205016700 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Loki's logger classes and logging utilities.
"""

import logging
import sys


__all__ = ['logger', 'log_levels', 'set_log_level', 'FileLogger',
           'debug', 'detail', 'perf', 'info', 'warning', 'error', 'log']


def FileLogger(name, filename, level=None, file_level=None, fmt=None,
               mode='a'):
    """
    Configure the logger ``name`` to write to the log file ``filename``.

    Parameters
    ----------
    name : str
        Logger name passed to :func:`logging.getLogger`.
    filename : str or path-like
        Log file to write to (converted with ``str``).
    level : int, optional
        Console log level; defaults to ``INFO``.
    file_level : int, optional
        Level of the file handler; defaults to ``level``.
    fmt : str, optional
        Format string for the file handler.
    mode : str, optional
        Open mode for the log file; ``'a'`` (append) by default.

    Returns
    -------
    logging.Logger
        The configured logger. Its own level is the more verbose of the two
        levels so that both outputs receive their records. If the optional
        ``coloredlogs`` package can be imported, it is installed on the
        logger at the console level; otherwise no extra handler is added.
    """
    console_level = level or INFO
    handler_level = file_level or console_level

    target = logging.getLogger(name)
    # The logger has to let through whatever either output wants to see
    target.setLevel(min(console_level, handler_level))

    log_format = fmt or '%(asctime)s %(name)s[%(process)d] %(levelname)s %(message)s'
    file_handler = logging.FileHandler(str(filename), mode=mode)
    file_handler.setFormatter(logging.Formatter(log_format))
    file_handler.setLevel(handler_level)
    target.addHandler(file_handler)

    # Coloured console output is optional; quietly skip it if unavailable
    try:
        import coloredlogs  # pylint: disable=import-outside-toplevel
        coloredlogs.install(level=console_level, logger=target)
    except ImportError:
        pass

    # TODO: For concurrent file writes, initialize queue and
    # main logging thread.

    return target


# Initialize base logger: all Loki output goes through the 'Loki' logger,
# which gets a plain StreamHandler (writing to stderr by default)
logger = logging.getLogger('Loki')
stream_handler = logging.StreamHandler()
logger.addHandler(stream_handler)

# This one is primarily used by loki.build
default_logger = logger

# Note: this is a remnant from loki.build.logging, which not only adds
# colour, but also adds hostname and timestamps, etc. to the log line.
# We might want to re-enable this under some specific logging options.

# coloredlogs.install(level=default_level, logger=logger)


# Define available log levels
DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
# Custom Loki levels placed between DEBUG (10) and INFO (20). They are not
# registered with ``logging.addLevelName`` here, so unless that happens
# elsewhere their level name in records is presumably "Level 15"/"Level 12".
PERF = 15
DETAIL = 12

# Internally accepted log levels: maps level names (upper- and lower-case)
# and the numeric levels themselves to the numeric level
log_levels = {
    'DEBUG': DEBUG,
    'DETAIL': DETAIL,
    'PERF': PERF,
    'INFO': INFO,
    'WARNING': WARNING,
    'ERROR': ERROR,
    # Lower case keywords for env variables
    'debug': DEBUG,
    'detail': DETAIL,
    'perf': PERF,
    'info': INFO,
    'warning': WARNING,
    'error': ERROR,
    # Enum keys for idempotence
    DEBUG: DEBUG,
    DETAIL: DETAIL,
    PERF: PERF,
    INFO: INFO,
    WARNING: WARNING,
    ERROR: ERROR,
}

# Internally used log colours (in simple mode): ANSI escape templates, each
# with a single ``%s`` placeholder for the message, keyed by numeric level
NOCOLOR = '%s'
RED = '\033[1;37;31m%s\033[0m'
BLUE = '\033[1;37;34m%s\033[0m'
GREEN = '\033[1;37;32m%s\033[0m'
colors = {
    DEBUG: NOCOLOR,
    DETAIL: GREEN,
    PERF: GREEN,
    INFO: GREEN,
    WARNING: BLUE,
    ERROR: RED,
}

def set_log_level(level):
    """
    Set the log level for the Loki logger.

    Parameters
    ----------
    level : int or str
        Either a numeric level (e.g. ``logging.DEBUG`` or one of Loki's custom
        levels such as ``PERF``) or a level name listed in ``log_levels``
        (e.g. ``'DEBUG'`` or ``'debug'``).

    Raises
    ------
    ValueError
        If ``level`` is not a key of ``log_levels``.
    """
    # ``log_levels`` maps names *and* the numeric levels themselves to the
    # numeric level, so a single lookup accepts both forms
    try:
        resolved = log_levels[level]
    except (KeyError, TypeError):
        raise ValueError(f'Illegal logging level {level}') from None

    logger.setLevel(resolved)


def log(msg, level, *args, **kwargs):
    """
    Wrapper of the main Python's logging function. Print ``msg % args`` with
    the severity ``level`` on the Loki logger.

    When both stdout and stderr are terminals, ``msg`` is wrapped in the
    colour template registered for ``level`` in ``colors``. Levels without a
    registered colour are logged uncoloured instead of raising ``KeyError``.

    Parameters
    ----------
    msg : str
        The message (format string) to be printed.
    level : int
        Numeric logging level.
    *args, **kwargs :
        Forwarded to :meth:`logging.Logger.log`.
    """
    if sys.stdout.isatty() and sys.stderr.isatty():
        color = colors.get(level, '%s')
    else:
        color = '%s'
    # Wrap ``msg`` in a 1-tuple so a tuple-valued message is not unpacked by ``%``
    logger.log(level, color % (msg,), *args, **kwargs)


def debug(msg, *args, **kwargs):
    """
    Log ``msg`` at the most verbose level, :any:`DEBUG`.

    Parameters
    ----------
    msg : str
        Message (format string) to log at :any:`DEBUG` level.
    *args, **kwargs :
        Passed through unchanged to :any:`log`.
    """
    return log(msg, DEBUG, *args, **kwargs)

def detail(msg, *args, **kwargs):
    """
    Log detailed, per-file information at :any:`DETAIL` level.

    Intended for timings and other fine-grained output produced for every
    individual file, which can get verbose.

    Parameters
    ----------
    msg : str
        Message (format string) to log at :any:`DETAIL` level.
    *args, **kwargs :
        Passed through unchanged to :any:`log`.
    """
    return log(msg, DETAIL, *args, **kwargs)

def perf(msg, *args, **kwargs):
    """
    Logger method for performance information.

    This level should be used for timing individual processes at a
    global granularity during batch-processing.

    Parameters
    ----------
    msg : str
        Message to log at :any:`PERF` level.
    *args, **kwargs :
        Forwarded to :any:`log`.
    """
    log(msg, PERF, *args, **kwargs)

def info(msg, *args, **kwargs):
    """
    Log high-level progress information at :any:`INFO` level.

    This is the default output level; use it only for messages at a global
    granularity during batch-processing.

    Parameters
    ----------
    msg : str
        Message (format string) to log at :any:`INFO` level.
    *args, **kwargs :
        Passed through unchanged to :any:`log`.
    """
    return log(msg, INFO, *args, **kwargs)

def warning(msg, *args, **kwargs):
    """
    Logger method for warnings.

    This level should be used for potentially dangerous, but not fatal
    information.

    Parameters
    ----------
    msg : str
        Message to log at :any:`WARNING` level.
    *args, **kwargs :
        Forwarded to :any:`log`.
    """
    log(msg, WARNING, *args, **kwargs)

def error(msg, *args, **kwargs):
    """
    Logger method for error information.

    This level should be used to provide additional information in
    case of failures.

    Parameters
    ----------
    msg : str
        Message to log at :any:`ERROR` level.
    *args, **kwargs :
        Forwarded to :any:`log`.
    """
    log(msg, ERROR, *args, **kwargs)
loki-ecmwf-0.3.6/loki/transformations/0000775000175000017500000000000015167130205020130 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/extract/0000775000175000017500000000000015167130205021602 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/extract/__init__.py0000664000175000017500000000611615167130205023717 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Transformations sub-package that provides various forms of
source-code extraction into standalone :any:`Subroutine` objects.

The various extraction mechanisms are provided as standalone utility
methods, or via the :any:`ExtractTransformation` class for batch
processing.

These utilities represent the conceptual inverse operation to
"inlining", as done by the :any:`InlineTransformation`.
"""

from loki.transformations.extract.internal import * # noqa
from loki.transformations.extract.outline import * # noqa

from loki.batch import Transformation


__all__ = ['ExtractTransformation']


class ExtractTransformation(Transformation):
    """
    :any:`Transformation` class to apply several types of source
    extraction when batch-processing large source trees via the
    :any:`Scheduler`.

    Parameters
    ----------
    extract_internals : bool
        Extract internal procedures (see :any:`extract_internal_procedures`);
        default: False.
    outline_regions : bool
        Outline pragma-annotated code regions to :any:`Subroutine` objects
        (see :any:`outline_pragma_regions`); default: True.
    """
    def __init__(self, extract_internals=False, outline_regions=True):
        self.extract_internals = extract_internals
        self.outline_regions = outline_regions

    def _extract_into(self, get_routines, add_routines):
        """
        Run each enabled extraction pass over the routines returned by
        ``get_routines`` and hand every batch of new routines to ``add_routines``.

        Internal-procedure extraction runs first. ``get_routines`` is
        re-evaluated for each pass, so routines created by the first pass
        are also scanned for pragma-marked regions.
        """
        passes = []
        if self.extract_internals:
            passes.append(extract_internal_procedures)
        if self.outline_regions:
            passes.append(outline_pragma_regions)

        for extract in passes:
            for routine in get_routines():
                add_routines(extract(routine))

    def transform_module(self, module, **kwargs):
        """
        Extract internal procedures and pragma-marked regions from all
        subroutines of the given :any:`Module` and append the resulting
        routines to its ``contains`` section.
        """
        # Lambdas defer the attribute lookups to the moment they are needed
        self._extract_into(
            lambda: module.subroutines,
            lambda routines: module.contains.append(routines)  # pylint: disable=unnecessary-lambda
        )

    def transform_file(self, sourcefile, **kwargs):
        """
        Extract internal procedures and pragma-marked regions from all
        subroutines of the given :any:`Sourcefile` and append the resulting
        routines to its top-level IR.
        """
        # Lambdas defer the attribute lookups to the moment they are needed
        self._extract_into(
            lambda: sourcefile.subroutines,
            lambda routines: sourcefile.ir.append(routines)  # pylint: disable=unnecessary-lambda
        )
loki-ecmwf-0.3.6/loki/transformations/extract/tests/0000775000175000017500000000000015167130205022744 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/extract/tests/__init__.py0000664000175000017500000000057015167130205025057 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/extract/tests/test_extract_transformation.py0000664000175000017500000000775615167130205031174 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Module, Sourcefile
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes

from loki.transformations.extract import ExtractTransformation


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('outline_regions', [False, True])
@pytest.mark.parametrize('extract_internals', [False, True])
def test_extract_transformation_module(extract_internals, outline_regions, frontend, tmp_path):
    """
    Test basic subroutine extraction from marker pragmas in modules.
    """
    fcode = """
module test_extract_mod
implicit none
contains

subroutine outer(n, a, b)
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a, b(n)
  real(kind=8) :: x(n), y(n, n+1)
  integer :: i, j

  x(:) = a
  do i=1, n
    y(i,:) = b(i)
  end do

  !$loki outline name(test1)
  do i=1, n
    do j=1, n+1
      x(i) = x(i)  + y(i, j)
    end do
  end do
  !$loki end outline

  do i=1, n
    call plus_one(x, i=i)
  end do

contains
  subroutine plus_one(f, i)
    real(kind=8), intent(inout) :: f(:)
    integer, intent(in) :: i

    f(i) = f(i) + 1.0
  end subroutine plus_one
end subroutine outer
end module test_extract_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    transformation = ExtractTransformation(
        extract_internals=extract_internals, outline_regions=outline_regions
    )
    transformation.apply(module)

    # The original routine plus one new routine per enabled extraction
    expected_count = 1 + int(extract_internals) + int(outline_regions)
    module_routines = [node for node in module.contains.body if isinstance(node, Subroutine)]
    assert len(module_routines) == expected_count
    assert ('plus_one' in module) == extract_internals
    assert ('test1' in module) == outline_regions

    outer = module['outer']
    # Outlining replaces the marked region with an additional call
    calls = FindNodes(ir.CallStatement).visit(outer.body)
    assert len(calls) == (2 if outline_regions else 1)
    internal_routines = [node for node in outer.contains.body if isinstance(node, Subroutine)]
    assert len(internal_routines) == (0 if extract_internals else 1)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('outline_regions', [False, True])
@pytest.mark.parametrize('extract_internals', [False, True])
def test_extract_transformation_sourcefile(extract_internals, outline_regions, frontend):
    """
    Check that :any:`ExtractTransformation` applied to a sourcefile adds extracted
    internal procedures and outlined regions as top-level routines of the file.
    """
    fcode = """
subroutine outer(n, a, b)
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a, b(n)
  real(kind=8) :: x(n), y(n, n+1)
  integer :: i, j

  x(:) = a
  do i=1, n
    y(i,:) = b(i)
  end do

  !$loki outline name(test1)
  do i=1, n
    do j=1, n+1
      x(i) = x(i)  + y(i, j)
    end do
  end do
  !$loki end outline

  do i=1, n
    call plus_one(x, i=i)
  end do

contains
  subroutine plus_one(f, i)
    real(kind=8), intent(inout) :: f(:)
    integer, intent(in) :: i

    f(i) = f(i) + 1.0
  end subroutine plus_one
end subroutine outer
"""
    sourcefile = Sourcefile.from_source(fcode, frontend=frontend)

    transformation = ExtractTransformation(
        extract_internals=extract_internals, outline_regions=outline_regions
    )
    transformation.apply(sourcefile)

    # 'outer' itself, plus one routine per enabled extraction mechanism
    expected_count = 1 + int(extract_internals) + int(outline_regions)
    file_routines = [node for node in sourcefile.ir.body if isinstance(node, Subroutine)]
    assert len(file_routines) == expected_count
    assert ('plus_one' in sourcefile) == extract_internals
    assert ('test1' in sourcefile) == outline_regions

    outer = sourcefile['outer']
    # An outlined region adds one extra call in the body of 'outer'
    calls = FindNodes(ir.CallStatement).visit(outer.body)
    assert len(calls) == (2 if outline_regions else 1)
    remaining_internals = [node for node in outer.contains.body if isinstance(node, Subroutine)]
    assert len(remaining_internals) == (0 if extract_internals else 1)
loki-ecmwf-0.3.6/loki/transformations/extract/tests/test_extract_internal.py0000664000175000017500000004754015167130205027735 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki.frontend import available_frontends, OMNI
from loki.ir import CallStatement, Import, FindNodes, FindInlineCalls
from loki.sourcefile import Sourcefile
from loki.subroutine import Subroutine

from loki.transformations.extract import extract_internal_procedures


@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_basic_scalar(frontend):
    """
    Check that a scalar from the parent scope becomes an argument of the extracted `inner`.
    """
    fcode = """
        subroutine outer()
            implicit none
            integer :: x
            x = 42
            call inner()
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                z = x + y
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    assert extracted[0].name == "inner"
    inner, outer = extracted[0], source.routines[0]
    assert 'x' in inner.arguments

    # The call site in the parent passes the new argument by keyword
    first_call = FindNodes(CallStatement).visit(outer.body)[0]
    keyword_names = [kw_name for kw_name, _ in first_call.kwarguments]
    assert 'x' in keyword_names

@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_contains_emptied(frontend):
    """
    Check that no subroutine or function remains in the parent's contains section after extraction.
    """
    fcode = """
        subroutine outer()
            implicit none
            integer :: x
            x = 42
            call inner()
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                z = x + y
            end subroutine inner
            function f() result(res)
                integer :: y
                integer :: z
                integer :: res
                y = 1
                z = y
                res = 2 * z
            end function f
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    parent = source.routines[0]
    extract_internal_procedures(parent)
    # Loki represents functions as Subroutine objects too, so one check covers both kinds
    leftovers = [node for node in parent.contains.body if isinstance(node, Subroutine)]
    assert not leftovers

@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_basic_array(frontend):
    """
    Check that a parent-scope array and a scalar both become arguments of the extracted `inner`.
    """
    fcode = """
        subroutine outer()
            implicit none
            integer :: x
            real :: arr(3)
            arr = 71.0
            x = 42
            call inner()
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                z = x + y + arr(1)
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    inner, outer = extracted[0], source.routines[0]
    # The array argument keeps the declared shape from the parent
    for expected_arg in ('x', 'arr(3)'):
        assert expected_arg in inner.arguments

    # Each new argument is passed by keyword under its own name
    first_call = FindNodes(CallStatement).visit(outer.body)[0]
    keyword_args = dict(first_call.kwarguments)
    for name in ('x', 'arr'):
        assert keyword_args[name] == name

@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_existing_call_args(frontend):
    """
    Check argument resolution when the parent already calls the extracted
    procedure with positional and/or keyword arguments. New arguments must be
    appended as keyword arguments, and existing ones must be left as they are.
    """

    fcode = """
        subroutine outer()
            implicit none
            integer :: x
            integer :: y
            integer :: z
            real :: arr(3)
            arr = 71.0
            x = 42
            y = 1
            call inner(x, y)
            call inner(x, y = 1)
            ! Note, 'call inner(y = 1, x)' is disallowed by Fortran and not tested.
            call inner(x = 1, y = 1)
            call inner(y = 1, x = 1)
            contains
            subroutine inner(x, y)
                integer, intent(in) :: x
                integer, intent(in) :: y
                z = x + y + arr(1)
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    parent = source.routines[0]
    extract_internal_procedures(parent)
    all_calls = FindNodes(CallStatement).visit(parent.body)

    # Every call site gains the parent-scope globals as keyword arguments
    for call in all_calls:
        added = dict(call.kwarguments)
        assert added['arr'] == 'arr'
        assert added['z'] == 'z'

    # call inner(x, y): purely positional, left untouched
    assert tuple(all_calls[0].arguments) == ('x', 'y')

    # call inner(x, y = 1): mixed positional / keyword
    assert tuple(all_calls[1].arguments) == ('x',)
    assert 'y' in tuple(kw_name for kw_name, _ in all_calls[1].kwarguments)

    # call inner(x = 1, y = 1) and call inner(y = 1, x = 1): purely keyword
    for idx in (2, 3):
        keyword_call = all_calls[idx]
        assert len(keyword_call.arguments) == 0
        added = dict(keyword_call.kwarguments)
        assert added['x'] == 1
        assert added['y'] == 1

@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parser fails on missing constants module')]))
def test_extract_internal_procedures_basic_import(frontend):
    """
    Check that an imported symbol used in the contained subroutine is re-imported
    there rather than being passed as an argument.
    """

    fcode = """
        subroutine outer()
            use constants, only: c1, c2
            implicit none
            integer :: x
            x = 42 + c1
            call inner()
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                z = x + y + c2
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    inner = extracted[0]
    # Only the symbol actually referenced inside `inner` is imported there
    imported = inner.import_map
    assert "c2" in imported
    assert "c1" not in imported
    # Imported symbols do not become dummy arguments
    assert 'c2' not in inner.arguments

@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parser fails on missing type_mod module')]))
def test_extract_internal_procedures_recursive_definition(frontend):
    """
    Check that a global which depends on other globals (e.g. through its shape)
    brings those other globals in as arguments as well, even if the contained
    subroutine does not reference them directly.
    """
    fcode = """
        subroutine outer(klon, klev, mt)
            use type_mod, only: mytype
            implicit none
            integer, intent(in) :: klon
            integer, intent(in) :: klev
            type(mytype), intent(in) :: mt
            integer :: somearr(klon, mt%a%b)
            integer :: x(klon)
            integer :: somevar(klon, klev + 1)

            x(klon - 1) = 42
            call inner()
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                z = x(1) + y + somevar(1, 1) - somearr(1, 1)
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    parent = source.routines[0]
    inner = extracted[0]

    member_refs = ('mt%a', 'mt%a%b')

    # Arrays keep their parent shapes; the shape symbols and 'mt' become arguments,
    # but derived-type members do not
    for expected_arg in ('x(klon)', 'somevar(klon, klev + 1)', 'klon', 'klev', 'mt'):
        assert expected_arg in inner.arguments
    for member in member_refs:
        assert member not in inner.arguments

    # The call site passes every new argument by keyword under its own name
    the_call = FindNodes(CallStatement).visit(parent.body)[0]
    keyword_args = dict(the_call.kwarguments)
    for name in ('x', 'klon', 'somearr', 'somevar', 'klev', 'mt'):
        assert keyword_args[name] == name
    for member in member_refs:
        assert member not in keyword_args

    # ...and none of them appears as a positional argument
    for name in ('x', 'klon', 'somearr', 'somevar', 'klev', 'mt') + member_refs:
        assert name not in the_call.arguments

    # 'klon' and 'klev' keep the intent(in) declared in the parent
    for name in ('klon', 'klev'):
        assert inner.variable_map[name].type.intent == "in"

@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parser fails on missing parkind1 module')]))
def test_extract_internal_procedures_recursive_definition_import(frontend):
    """
    Check that globals whose declarations depend on imported kind parameters
    become arguments, and that those kind parameters are imported into the
    contained subroutine.
    """
    fcode = """
        subroutine outer()
            use parkind1, only: jprb, jpim
            implicit none
            real(kind=jprb) :: x(3)
            integer(kind=jpim) :: ii(30)
            ii = 72
            x(1) = 42
            call inner()
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                ii(4) = 2
                z = x(1) + y
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    parent = source.routines[0]
    inner = extracted[0]
    for expected_arg in ('x(3)', 'ii(30)'):
        assert expected_arg in inner.arguments

    keyword_args = dict(FindNodes(CallStatement).visit(parent.body)[0].kwarguments)
    for name in ('x', 'ii'):
        assert keyword_args[name] == name

    # Collect the modules and symbols imported by `inner`
    inner_imports = FindNodes(Import).visit(inner.spec)
    modules = {imp.module for imp in inner_imports}
    symbols = {sym for imp in inner_imports for sym in imp.symbols}
    assert "parkind1" in modules
    assert len(modules) == 1
    assert "jprb" in symbols
    assert "jpim" in symbols
    assert len(symbols) == 2

@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parser fails on missing parkind1 module')]))
def test_extract_internal_procedures_kind_resolution(frontend):
    """
    Check that a kind parameter used in the contained subroutine, but imported
    only in the parent, gets imported into the extracted routine.
    """
    fcode = """
        subroutine outer()
            use parkind1, only: jpim
            implicit none
            call inner()
            contains
            subroutine inner()
                integer(kind = jpim) :: y
                integer(kind=8) :: z
                z = y
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted_inner = extract_internal_procedures(source.routines[0])[0]
    assert "jpim" in extracted_inner.import_map

@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parser fails on missing stuff module')]))
def test_extract_internal_procedures_derived_type_resolution(frontend):
    """
    Check that a derived type used in the contained subroutine, but imported
    only in the parent, gets imported into the extracted routine.
    """
    fcode = """
        subroutine outer()
            use stuff, only: mytype
            implicit none
            call inner()
            contains
            subroutine inner()
                type(mytype) :: y
                integer :: z
                z = y%a
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted_inner = extract_internal_procedures(source.routines[0])[0]
    assert "mytype" in extracted_inner.import_map

@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parser fails on missing types module')]))
def test_extract_internal_procedures_derived_type_field(frontend):
    """
    Test that when a derived type field, e.g. 'a%b', is a global in the scope of the
    contained subroutine, the derived type variable itself, i.e. 'a', is introduced
    as an argument by the transformation. The corresponding type imports must also be
    replicated in the extracted routine.
    """
    fcode = """
        subroutine outer()
            use types, only: my_type, your_type
            implicit none
            type(my_type) :: xtyp
            type(your_type) :: ytyp
            call inner()
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                xtyp%a = 40
                ytyp%val%b = 10.0
                z = y + ytyp%something_else
            end subroutine inner
        end subroutine outer
    """
    src = Sourcefile.from_source(fcode, frontend=frontend)
    routines = extract_internal_procedures(src.routines[0])
    outer = src.routines[0]
    inner = routines[0]
    # The parent derived-type variables become arguments, not their fields
    assert 'xtyp' in inner.arguments
    assert 'ytyp' in inner.arguments

    call = FindNodes(CallStatement).visit(outer.body)[0]
    kwargdict = dict(call.kwarguments)
    assert kwargdict['xtyp'] == 'xtyp'
    assert kwargdict['ytyp'] == 'ytyp'

    # Both type definitions must be imported into the extracted routine
    imports = FindNodes(Import).visit(inner.spec)
    modules = set()
    symbols = set()
    for imp in imports:
        modules.add(imp.module)
        for sym in imp.symbols:
            symbols.add(sym)
    assert "types" in modules
    assert len(modules) == 1
    assert "my_type" in symbols
    assert "your_type" in symbols
    assert len(symbols) == 2

@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_intent(frontend):
    """
    Document the current intent handling for globals turned into arguments.
    They get intent 'inout' by default, or inherit the intent declared in the
    parent procedure. The parent's own declarations must stay unchanged.
    """
    fcode = """
        subroutine outer(v, p)
            implicit none
            integer, intent(in) :: v
            integer, intent(out) :: p
            integer :: x(3)
            x = 4
            call inner()
            p = 400
            contains
            subroutine inner()
                integer :: y
                integer :: z
                y = 1
                z = x(1) + v + y + p
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    parent = source.routines[0]
    inner = extracted[0]

    # 'v' and 'p' inherit their parent intent; 'x' has none there and falls back to 'inout'
    for name, expected_intent in (('v', "in"), ('x', "inout"), ('p', "out")):
        assert inner.variable_map[name].type.intent == expected_intent

    # The parent's own declarations keep their original intents
    assert parent.variable_map['v'].type.intent == "in"
    assert parent.variable_map['x'].type.intent is None
    assert parent.variable_map['p'].type.intent == "out"

@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Parser fails on undefined symbols')]))
def test_extract_internal_procedures_undefined_in_parent(frontend):
    """
    Document the current behaviour: extraction raises a ``RuntimeError`` when a
    global used in the contained procedure has no declaration in the parent.
    """
    fcode = """
        subroutine outer()
            implicit none
            integer :: x
            x = 42
            call inner()
            contains
            subroutine inner()
                integer :: y
                y = 1
                z = x + y + g + f ! 'z', 'g', 'f' undefined in contained subroutine and parent.
            end subroutine inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    # 'z', 'g' and 'f' cannot be resolved anywhere, so extraction must fail
    with pytest.raises(RuntimeError):
        extract_internal_procedures(source.routines[0])

@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_multiple_internal_procedures(frontend):
    """
    Check that several contained procedures are extracted independently. Each
    one gets its own globals as arguments.
    """
    fcode = """
        subroutine outer()
            implicit none
            integer :: x, gx
            x = 42
            gx = 10
            call inner1()
            call inner2()
            contains
            subroutine inner1()
                integer :: y
                integer :: z
                y = 1
                z = x + y
            end subroutine inner1
            subroutine inner2()
                integer :: gy
                integer :: gz
                gy = 1
                gz = gx + gy
            end subroutine inner2
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    # Extraction preserves the order of the contained procedures
    assert [proc.name for proc in extracted] == ["inner1", "inner2"]
    parent = source.routines[0]
    first_inner, second_inner = extracted
    assert 'x' in first_inner.arguments
    assert 'gx' in second_inner.arguments

    # Each call site receives only its callee's globals as keyword arguments
    for callee, global_name in (("inner1", 'x'), ("inner2", 'gx')):
        site = next(c for c in FindNodes(CallStatement).visit(parent.body) if c.name == callee)
        assert global_name in [kw_name for kw_name, _ in site.kwarguments]

@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_basic_scalar_function(frontend):
    """
    Check scalar globals when the contained procedure is a function called
    inline from a parent subroutine.
    """
    fcode = """
        subroutine outer()
            implicit none
            integer :: x
            integer :: y
            x = 42
            y = inner()
            contains
            function inner() result(z)
                integer :: y
                integer :: z
                y = 1
                z = x + y
            end function inner
        end subroutine outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    assert extracted[0].name == "inner"
    parent = source.routines[0]
    assert 'x' in extracted[0].arguments

    # The inline function call passes the new argument by keyword
    inline_call = next(iter(FindInlineCalls().visit(parent.body)))
    assert 'x' in inline_call.kw_parameters

@pytest.mark.parametrize('frontend', available_frontends())
def test_extract_internal_procedures_basic_scalar_function_both(frontend):
    """
    Check scalar globals when both the parent and the contained procedure are functions.
    """
    fcode = """
        function outer() result(outer_res)
            implicit none
            integer :: x
            integer :: outer_res
            x = 42
            outer_res = inner()

            contains

            function inner() result(z)
                integer :: y
                integer :: z
                y = 1
                z = x + y
            end function inner
        end function outer
    """
    source = Sourcefile.from_source(fcode, frontend=frontend)
    extracted = extract_internal_procedures(source.routines[0])
    assert len(extracted) == 1
    assert extracted[0].name == "inner"
    parent = source.routines[0]
    assert 'x' in extracted[0].arguments

    # The inline function call passes the new argument by keyword
    inline_call = next(iter(FindInlineCalls().visit(parent.body)))
    assert 'x' in inline_call.kw_parameters
loki-ecmwf-0.3.6/loki/transformations/extract/tests/test_outline.py0000664000175000017500000004101115167130205026031 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Module, Subroutine
from loki.jit_build import jit_compile_lib, Builder, Obj
from loki.frontend import available_frontends
from loki.ir import FindNodes, Assignment, CallStatement
from loki.types import BasicType

from loki.transformations.extract.outline import outline_pragma_regions


@pytest.fixture(scope='function', name='builder')
def fixture_builder(tmp_path):
    """
    Yield a :any:`Builder` that reads sources from ``tmp_path`` and builds into
    ``tmp_path/'build'``. The :any:`Obj` cache is cleared once the test is done.
    """
    build_dir = tmp_path/'build'
    yield Builder(source_dirs=tmp_path, build_dir=build_dir)
    # Teardown: reset the cached build objects
    Obj.clear_cache()


def assignment_symbols(node):
    """Return a list of ``(lhs, rhs)`` pairs, one per :any:`Assignment` found in ``node``."""
    pairs = []
    for assignment in FindNodes(Assignment).visit(node):
        pairs.append((assignment.lhs, assignment.rhs))
    return pairs


def call_symbols(node):
    """Return a list of ``(name, arguments)`` pairs, one per :any:`CallStatement` found in ``node``."""
    pairs = []
    for call_stmt in FindNodes(CallStatement).visit(node):
        pairs.append((call_stmt.name, call_stmt.arguments))
    return pairs


def argument_symbols(routine):
    """Return the dummy arguments of ``routine`` as a tuple."""
    return (*routine.arguments,)


def argument_intents(routine):
    """Map each dummy argument name of ``routine`` to its declared intent."""
    return dict((arg.name, arg.type.intent) for arg in routine.arguments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_outline_pragma_regions(frontend):
    """
    Minimal :any:`outline_pragma_regions` check: one unnamed region with explicit
    ``in``/``out`` annotations is outlined into a new routine that gets the default name.
    """
    fcode = """
subroutine test_outline(a, b, c)
  integer, intent(out) :: a, b, c

  a = 5
  a = 1

!$loki outline in(a) out(b)
  b = a
!$loki end outline

  c = a + b
end subroutine test_outline
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Before outlining: four assignments and no calls
    assert len(FindNodes(Assignment).visit(routine.body)) == 4
    assert len(FindNodes(CallStatement).visit(routine.body)) == 0
    assert assignment_symbols(routine.body) == [('a', 5), ('a', 1), ('b', 'a'), ('c', 'a + b')]

    outlined = outline_pragma_regions(routine)
    default_name = f'{routine.name}_outlined_0'
    assert len(outlined) == 1 and outlined[0].name == default_name
    new_routine = outlined[0]

    # After outlining: the marked assignment moved into the new routine, replaced by a call
    assert len(FindNodes(Assignment).visit(routine.body)) == 3
    assert len(FindNodes(Assignment).visit(new_routine.body)) == 1
    assert len(FindNodes(CallStatement).visit(routine.body)) == 1
    assert assignment_symbols(routine.body) == [('a', 5), ('a', 1), ('c', 'a + b')]
    assert assignment_symbols(new_routine.body) == [('b', 'a')]
    assert call_symbols(routine.body) == [(default_name, ('a', 'b'))]
    assert argument_symbols(new_routine) == ('a', 'b')
    assert argument_intents(new_routine) == {'a': 'in', 'b': 'out'}


@pytest.mark.parametrize('frontend', available_frontends())
def test_outline_pragma_regions_multiple(frontend):
    """
    Test outlining of multiple regions in one routine. The first region has an
    explicit ``name`` and an ``inout`` annotation; the others use default names
    and explicit ``in``/``out`` annotations.
    """
    fcode = """
subroutine test_outline_mult(a, b, c)
  integer, intent(out) :: a, b, c

  a = 1
  a = a + 1
  a = a + 1
!$loki outline name(oiwjfklsf) inout(a)
  a = a + 1
!$loki end outline
  a = a + 1

!$loki outline in(a) out(b)
  b = a
!$loki end outline

!$loki outline in(a,b) out(c)
  c = a + b
!$loki end outline
end subroutine test_outline_mult
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Before outlining: all seven assignments live in the original routine
    assert len(FindNodes(Assignment).visit(routine.body)) == 7
    assert len(FindNodes(CallStatement).visit(routine.body)) == 0
    assert assignment_symbols(routine.body) == [
        ('a', 1), ('a', 'a + 1'), ('a', 'a + 1'), ('a', 'a + 1'),
        ('a', 'a + 1'), ('b', 'a'), ('c', 'a + b')
    ]

    routines = outline_pragma_regions(routine)
    assert len(routines) == 3
    # An explicit name is honoured; unnamed regions are numbered by their position
    assert routines[0].name == 'oiwjfklsf'
    assert all(routines[i].name == f'{routine.name}_outlined_{i}' for i in (1,2))

    # After outlining: each region became one routine and one call
    assert len(FindNodes(Assignment).visit(routine.body)) == 4
    assert all(len(FindNodes(Assignment).visit(r.body)) == 1 for r in routines)
    assert len(FindNodes(CallStatement).visit(routine.body)) == 3
    assert assignment_symbols(routine.body) == [('a', 1), ('a', 'a + 1'), ('a', 'a + 1'), ('a', 'a + 1')]
    assert [assignment_symbols(r.body) for r in routines] == [[('a', 'a + 1')], [('b', 'a')], [('c', 'a + b')]]
    assert call_symbols(routine.body) == [
        ('oiwjfklsf', ('a',)),
        (f'{routine.name}_outlined_1', ('a', 'b')),
        (f'{routine.name}_outlined_2', ('a', 'b', 'c'))
    ]
    assert argument_symbols(routines[0]) == ('a',)
    assert argument_intents(routines[0]) == {'a': 'inout'}
    assert argument_symbols(routines[1]) == ('a', 'b')
    assert argument_intents(routines[1]) == {'a': 'in', 'b': 'out'}
    assert argument_symbols(routines[2]) == ('a', 'b', 'c')
    assert argument_intents(routines[2]) == {'a': 'in', 'b': 'in', 'c': 'out'}


@pytest.mark.parametrize('frontend', available_frontends())
def test_outline_pragma_regions_arguments(frontend):
    """
    Test outlining of multiple named regions with automatic derivation of the
    outlined routines' arguments and intents. The test also covers partially
    overriding the derived intents through a pragma annotation (``inout(b)``).
    """
    fcode = """
subroutine test_outline_args(a, b, c)
  integer, intent(out) :: a, b, c

  a = 1
  a = a + 1
  a = a + 1
!$loki outline name(func_a)
  a = a + 1
!$loki end outline
  a = a + 1

!$loki outline name(func_b)
  b = a
!$loki end outline

! partially override arguments
!$loki outline name(func_c) inout(b)
  c = a + b
!$loki end outline
end subroutine test_outline_args
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(Assignment).visit(routine.body)) == 7
    assert len(FindNodes(CallStatement).visit(routine.body)) == 0
    assert assignment_symbols(routine.body) == [
        ('a', 1), ('a', 'a + 1'), ('a', 'a + 1'), ('a', 'a + 1'),
        ('a', 'a + 1'), ('b', 'a'), ('c', 'a + b')
    ]

    routines = outline_pragma_regions(routine)
    assert len(routines) == 3
    assert [r.name for r in routines] == ['func_a', 'func_b', 'func_c']

    # func_a: 'a' is both read and written, so it is derived as inout
    assert len(routines[0].arguments) == 1
    assert routines[0].arguments[0] == 'a' and routines[0].arguments[0].type.intent == 'inout'

    # func_b: derived intents only
    assert set(argument_symbols(routines[1])) == {'a', 'b'}
    assert routines[1].variable_map['a'].type.intent == 'in'
    assert routines[1].variable_map['b'].type.intent == 'out'

    # func_c: 'b' is overridden to inout by the pragma, the rest is derived
    assert set(argument_symbols(routines[2])) == {'a', 'b', 'c'}
    assert routines[2].variable_map['a'].type.intent == 'in'
    assert routines[2].variable_map['b'].type.intent == 'inout'
    assert routines[2].variable_map['c'].type.intent == 'out'

    assert len(FindNodes(Assignment).visit(routine.body)) == 4
    assert all(len(FindNodes(Assignment).visit(r.body)) == 1 for r in routines)
    assert len(FindNodes(CallStatement).visit(routine.body)) == 3
    assert assignment_symbols(routine.body) == [('a', 1), ('a', 'a + 1'), ('a', 'a + 1'), ('a', 'a + 1')]
    assert [assignment_symbols(r.body) for r in routines] == [[('a', 'a + 1')], [('b', 'a')], [('c', 'a + b')]]
    assert call_symbols(routine.body) == [('func_a', ('a',)), ('func_b', ('a', 'b')), ('func_c', ('a', 'b', 'c'))]


@pytest.mark.parametrize('frontend', available_frontends())
def test_outline_pragma_regions_arrays(frontend):
    """
    Test outlining of regions that operate on array variables. Array arguments
    keep their declared shapes, the shape symbols are passed along, and the
    dimension symbols are rescoped to the outlined routine.
    """
    fcode = """
subroutine test_outline_arr(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: j

!$loki outline
  do j=1,n
    a(j) = j
  end do
!$loki end outline

!$loki outline
  do j=1,n
    b(j) = j
  end do
!$loki end outline

!$loki outline
  do j=1,n-1
    b(j) = b(j+1) - a(j)
  end do
  b(n) = 1
!$loki end outline
end subroutine test_outline_arr
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(Assignment).visit(routine.body)) == 4
    assert len(FindNodes(CallStatement).visit(routine.body)) == 0
    assert assignment_symbols(routine.body) == [
        ('a(j)', 'j'), ('b(j)', 'j'), ('b(j)', 'b(j + 1) - a(j)'), ('b(n)', 1)
    ]

    routines = outline_pragma_regions(routine)

    # Every assignment has been moved into an outlined routine
    assert len(FindNodes(Assignment).visit(routine.body)) == 0
    assert len(FindNodes(CallStatement).visit(routine.body)) == 3

    assert len(routines) == 3

    assert {(arg, arg.type.intent) for arg in routines[0].arguments} == {('a(n)', 'out'), ('n', 'in')}
    assert {(arg, arg.type.intent) for arg in routines[1].arguments} == {('b(n)', 'out'), ('n', 'in')}
    assert {(arg, arg.type.intent) for arg in routines[2].arguments} == {('a(n)', 'in'), ('b(n)', 'inout'), ('n', 'in')}
    # Shape symbols must refer to the outlined routine, not the original one
    assert routines[0].variable_map['a'].dimensions[0].scope is routines[0]
    assert [assignment_symbols(r.body) for r in routines] == [
        [('a(j)', 'j')], [('b(j)', 'j')], [('b(j)', 'b(j + 1) - a(j)'), ('b(n)', 1)]
    ]
    assert call_symbols(routine.body) == [
        (f'{routine.name}_outlined_0', ('a', 'n')),
        (f'{routine.name}_outlined_1', ('b', 'n')),
        (f'{routine.name}_outlined_2', ('a', 'b', 'n'))
    ]


@pytest.mark.parametrize('frontend', available_frontends())
def test_outline_pragma_regions_imports(tmp_path, builder, frontend):
    """
    Test hoisting with correct treatment of imports.

    Symbols imported from ``outline_mod`` (``param``, ``arr1``, ``arr2``) must not
    become arguments of the outlined routines; the code is compiled and run both
    before and after outlining and must produce the same results.
    """
    fcode_module = """
module outline_mod
  implicit none
  integer, parameter :: param = 1
  integer :: arr1(10)
  integer :: arr2(10)
end module outline_mod
    """.strip()

    fcode = """
module test_outline_imps_mod
  implicit none
contains
  subroutine test_outline_imps(a, b)
    use outline_mod, only: param, arr1, arr2
    integer, intent(out) :: a(10), b(10)
    integer :: j

!$loki outline
    do j=1,10
      a(j) = param
    end do
!$loki end outline

!$loki outline
    do j=1,10
      arr1(j) = j+1
    end do
!$loki end outline

    arr2(:) = arr1(:)

!$loki outline
    do j=1,10
      b(j) = arr2(j) - a(j)
    end do
!$loki end outline
  end subroutine test_outline_imps
end module test_outline_imps_mod
"""
    ext_module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    module = Module.from_source(fcode, frontend=frontend, definitions=ext_module, xmods=[tmp_path])
    # Compile the untransformed code as a reference
    refname = f'ref_{module.name}_{frontend}'
    reference = jit_compile_lib([module, ext_module], path=tmp_path, name=refname, builder=builder)
    function = getattr(getattr(reference, module.name), module.subroutines[0].name)

    # Test the reference solution
    a = np.zeros(shape=(10,), dtype=np.int32)
    b = np.zeros(shape=(10,), dtype=np.int32)
    function(a, b)
    assert np.all(a == [1] * 10)
    assert np.all(b == range(1,11))
    # Remove the reference source file before the transformed module is compiled
    (tmp_path/f'{module.name}.f90').unlink()

    assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 4
    assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 0

    # Apply transformation
    routines = outline_pragma_regions(module.subroutines[0])

    # Only `arr2(:) = arr1(:)`, which is outside any outline region, stays behind
    assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 1
    assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 3

    assert len(routines) == 3

    # Imported symbols (param, arr1, arr2) must not appear as arguments
    assert {(str(a), a.type.intent) for a in routines[0].arguments} == {('a(10)', 'out')}
    assert {(str(a), a.type.intent) for a in routines[1].arguments} == set()
    assert {(str(a), a.type.intent) for a in routines[2].arguments} == {('a(10)', 'in'), ('b(10)', 'out')}

    # Insert created routines into module
    module.contains.append(routines)

    obj = jit_compile_lib([module, ext_module], path=tmp_path, name=f'{module.name}_{frontend}', builder=builder)
    mod_function = getattr(getattr(obj, module.name), module.subroutines[0].name)

    # Test transformation
    a = np.zeros(shape=(10,), dtype=np.int32)
    b = np.zeros(shape=(10,), dtype=np.int32)
    mod_function(a, b)
    assert np.all(a == [1] * 10)
    assert np.all(b == range(1,11))


@pytest.mark.parametrize('frontend', available_frontends())
def test_outline_pragma_regions_derived_args(tmp_path, builder, frontend):
    """
    Test subroutine extraction with derived-type arguments.

    Component accesses (``dave%a``, ``dave%b``) inside the outlined region must
    result in a single argument for the root derived-type variable ``dave``.
    """

    fcode = """
module test_outline_dertype_mod
  implicit none

  type rick
    integer :: a(10), b(10)
  end type rick
contains

  subroutine test_outline_imps(a, b)
    integer, intent(out) :: a(10), b(10)
    type(rick) :: dave
    integer :: j

    dave%a(:) = a(:)
    dave%b(:) = b(:)

!$loki outline
    do j=1,10
      dave%a(j) = j + 1
    end do

    dave%b(:) = dave%b(:) + 42
!$loki end outline

    a(:) = dave%a(:)
    b(:) = dave%b(:)
  end subroutine test_outline_imps
end module test_outline_dertype_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    # Compile the untransformed code as a reference
    refname = f'ref_{module.name}_{frontend}'
    reference = jit_compile_lib([module], path=tmp_path, name=refname, builder=builder)
    function = getattr(getattr(reference, module.name), module.subroutines[0].name)

    # Test the reference solution
    a = np.zeros(shape=(10,), dtype=np.int32)
    b = np.zeros(shape=(10,), dtype=np.int32)
    function(a, b)
    assert np.all(a == range(2,12))
    assert np.all(b == 42)
    # Remove the reference source file before the transformed module is compiled
    (tmp_path/f'{module.name}.f90').unlink()

    assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 6
    assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 0

    # Apply transformation
    routines = outline_pragma_regions(module.subroutines[0])

    # The two assignments inside the region are moved out, four copies remain
    assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 4
    assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 1

    # Check for a single derived-type argument
    assert len(routines) == 1
    assert len(routines[0].arguments) == 1
    assert routines[0].arguments[0] == 'dave'
    assert routines[0].arguments[0].type.dtype.name == 'rick'
    assert routines[0].arguments[0].type.intent == 'inout'

    # Insert created routines into module
    module.contains.append(routines)

    obj = jit_compile_lib([module], path=tmp_path, name=f'{module.name}_{frontend}', builder=builder)
    mod_function = getattr(getattr(obj, module.name), module.subroutines[0].name)

    # Test the transformed module solution
    a = np.zeros(shape=(10,), dtype=np.int32)
    b = np.zeros(shape=(10,), dtype=np.int32)
    mod_function(a, b)
    assert np.all(a == range(2,12))
    assert np.all(b == 42)


@pytest.mark.parametrize('frontend', available_frontends())
def test_outline_pragma_regions_associates(tmp_path, builder, frontend):
    """
    Test subroutine extraction inside an ASSOCIATE block.

    The associate names ``c`` and ``d`` (mapped to components of the derived-type
    variable ``dave``) must become integer array arguments of the outlined routine
    and be passed as such in the generated call.
    """

    fcode = """
module test_outline_assoc_mod
  implicit none

  type rick
    integer :: a(10), b(10)
  end type rick
contains

  subroutine test_outline_imps(a, b)
    integer, intent(out) :: a(10), b(10)
    type(rick) :: dave
    integer :: j

    associate(c=>dave%a, d=>dave%b)

    c(:) = a(:)
    d(:) = b(:)

!$loki outline
    do j=1,10
      c(j) = j + 1
    end do

    d(:) = d(:) + 42
!$loki end outline

    a(:) = c(:)
    b(:) = d(:)
    end associate
  end subroutine test_outline_imps
end module test_outline_assoc_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module.subroutines[0]
    # Compile the untransformed code as a reference
    refname = f'ref_{module.name}_{frontend}'
    reference = jit_compile_lib([module], path=tmp_path, name=refname, builder=builder)
    function = getattr(getattr(reference, module.name), routine.name)

    # Test the reference solution
    a = np.zeros(shape=(10,), dtype=np.int32)
    b = np.zeros(shape=(10,), dtype=np.int32)
    function(a, b)
    assert np.all(a == range(2,12))
    assert np.all(b == 42)
    # Remove the reference source file before the transformed module is compiled
    (tmp_path/f'{module.name}.f90').unlink()

    assert len(FindNodes(Assignment).visit(routine.body)) == 6
    assert len(FindNodes(CallStatement).visit(routine.body)) == 0

    # Apply transformation
    outlined = outline_pragma_regions(routine)

    # The call passes the associate names, not the derived-type components
    assert len(FindNodes(Assignment).visit(routine.body)) == 4
    calls = FindNodes(CallStatement).visit(routine.body)
    assert len(calls) == 1
    assert calls[0].arguments == ('c', 'd')

    # Check that the two associated integer arrays (not the derived type) become arguments
    assert len(outlined) == 1
    assert len(outlined[0].arguments) == 2
    assert outlined[0].arguments[0].name == 'c'
    assert outlined[0].arguments[0].type.shape == (10,)
    assert outlined[0].arguments[0].type.dtype == BasicType.INTEGER
    assert outlined[0].arguments[0].type.intent == 'out'
    assert outlined[0].arguments[1].name == 'd'
    assert outlined[0].arguments[1].type.shape == (10,)
    assert outlined[0].arguments[1].type.dtype == BasicType.INTEGER
    assert outlined[0].arguments[1].type.intent == 'inout'

    # Insert created routines into module
    module.contains.append(outlined)

    obj = jit_compile_lib(
        [module], path=tmp_path, name=f'{module.name}_{frontend}', builder=builder
    )
    mod_function = getattr(getattr(obj, module.name), routine.name)
    a = np.zeros(shape=(10,), dtype=np.int32)
    b = np.zeros(shape=(10,), dtype=np.int32)
    mod_function(a, b)
    assert np.all(a == range(2,12))
    assert np.all(b == 42)
loki-ecmwf-0.3.6/loki/transformations/extract/internal.py0000664000175000017500000002327515167130205024001 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.subroutine import Subroutine
from loki.expression import DeferredTypeSymbol, Array
from loki.ir import (
    CallStatement, Transformer, FindNodes, FindVariables,
    FindInlineCalls, SubstituteExpressions
)
from loki.tools import OrderedSet
from loki.types import DerivedType


__all__ = [
    'extract_contained_procedures', 'extract_contained_procedure',
    'extract_internal_procedures', 'extract_internal_procedure'
]


def extract_internal_procedures(procedure):
    """
    This transform creates "standalone" :any:`Subroutine`s
    from the internal procedures (subroutines or functions) of ``procedure``.

    A list of :any:`Subroutine`s corresponding to each internal subroutine of
    ``procedure`` is returned and ``procedure`` itself is
    modified (see below).
    This function does the following transforms:
    1. all global bindings from the point of view of the internal procedure(s) are introduced
    as imports or dummy arguments to the modified internal procedure(s) to make them standalone.
    2. all calls or invocations of the internal procedures in parent are modified accordingly.
    3. All procedures are removed from the CONTAINS block of ``procedure``.

    If ``procedure`` has no CONTAINS section, it is left untouched and an
    empty list is returned.

    As a basic example of this transformation, the Fortran subroutine:

    .. code-block::

        subroutine outer()
            integer :: y
            integer :: o
            o = 0
            y = 1
            call inner(o)
            contains
            subroutine inner(o)
               integer, intent(inout) :: o
               integer :: x
               x = 4
               o = x + y ! Note, 'y' is "global" here!
            end subroutine inner
        end subroutine outer

    is modified to:

    .. code-block::

        subroutine outer()
            integer :: y
            integer :: o
            o = 0
            y = 1
            call inner(o, y) ! 'y' now passed as argument.
            contains
        end subroutine outer

    and the (modified) child:

    .. code-block::

        subroutine inner(o, y)
               integer, intent(inout) :: o
               integer, intent(inout) :: y
               integer :: x
               x = 4
               o = x + y ! Note, 'y' is no longer "global"
        end subroutine inner

    is returned.

    Parameters
    ----------
    procedure : :any:`Subroutine`
        The parent procedure whose internal procedures are extracted.

    Returns
    -------
    list of :any:`Subroutine`
        The extracted (and modified) internal procedures.
    """
    # Without a CONTAINS section there is nothing to extract, and
    # ``procedure.contains.body`` below would fail on ``None``.
    if procedure.contains is None:
        return []

    new_procedures = [extract_internal_procedure(procedure, r.name) for r in procedure.subroutines]

    # Remove all subroutines (or functions) from the CONTAINS section.
    newbody = tuple(r for r in procedure.contains.body if not isinstance(r, Subroutine))
    procedure.contains = procedure.contains.clone(body=newbody)
    return new_procedures


def extract_internal_procedure(procedure, name):
    """
    Extract a single internal procedure with name ``name`` from the parent procedure ``procedure``.

    This function does the following transforms:
    1. all global bindings from the point of view of the internal procedure are introduced
    as imports or dummy arguments to the modified internal procedure returned from this function.
    2. all calls or invocations of the internal procedure in the parent are modified accordingly.

    Imports added to the internal procedure are restricted to the symbols that are actually
    needed, with a single USE statement per module. Symbols that the internal procedure
    already imports itself are not looked up in (or re-imported from) the parent.

    See also the "driver" function ``extract_internal_procedures``, which applies this function to each
    internal procedure of a parent procedure and additionally empties the CONTAINS section of subroutines.

    Parameters
    ----------
    procedure : :any:`Subroutine`
        The parent procedure that contains the internal procedure.
    name : str
        Name of the internal procedure to extract.

    Returns
    -------
    :any:`Subroutine`
        The modified internal procedure (the same object held by ``procedure``).

    Raises
    ------
    KeyError
        If ``procedure`` has no internal procedure called ``name``.
    RuntimeError
        If variables used in the internal procedure are defined in neither scope.
    """
    inner = procedure.subroutine_map[name] # Fetch the subprocedure to extract (or crash with 'KeyError').

    # Check if there are variables that don't have a scope. This means that they are not defined anywhere
    # and execution cannot continue.
    undefined = tuple(v for v in FindVariables().visit(inner.body) if not v.scope)
    if undefined:
        msg = f"The following variables appearing in the internal procedure '{inner.name}' are undefined "
        msg += f"in both '{inner.name}' and the parent procedure '{procedure.name}': "
        for u in undefined:
            msg += f"{u.name}, "
        raise RuntimeError(msg)

    ## PRODUCING VARIABLES TO INTRODUCE AS DUMMY ARGUMENTS TO `inner`.
    # Produce a list of variables defined in the scope of `procedure` that need to be resolved in `inner`'s scope
    # by introducing them as dummy arguments.
    # The second line drops any derived type fields, don't want them, since want to resolve the derived type itself.
    vars_to_resolve = [v for v in FindVariables().visit(inner.body) if v.scope is procedure]
    vars_to_resolve = [v for v in vars_to_resolve if not v.parent]

    # Save any `DeferredTypeSymbol`s for later, they are in fact defined through imports in `procedure`,
    # and therefore not to be added as arguments to `inner`. (the next step removes them from `vars_to_resolve`)
    var_imports_to_add = tuple(v for v in vars_to_resolve if isinstance(v, DeferredTypeSymbol))

    # Lookup the definition of the variables in `vars_to_resolve` from the scope of `procedure`.
    # This provides maximal information on them.
    vars_to_resolve = [proc_var for v in vars_to_resolve if \
        (proc_var := procedure.variable_map.get(v.name))]

    # For each array in `vars_to_resolve`, append any non-literal shape variables to `vars_to_resolve`,
    # if not already there.
    arr_shapes = []
    for var in vars_to_resolve:
        if isinstance(var, Array):
            # Dropping variables with parents here to handle the case that the array dimension(s)
            # are defined through the field of a derived type.
            arr_shapes += list(v for v in FindVariables().visit(var.shape) if not v.parent)
    for v in arr_shapes:
        if v.name not in vars_to_resolve:
            vars_to_resolve.append(v)
    vars_to_resolve = tuple(vars_to_resolve)

    ## PRODUCING IMPORTS TO INTRODUCE TO `inner`.
    # Get all variables from `inner.spec`. Need to search them for resolving kinds and derived types for
    # variables that do not need resolution.
    inner_spec_vars = tuple(FindVariables().visit(inner.spec))

    # Produce derived types appearing in `vars_to_resolve` or in `inner.spec` that need to be resolved
    # from imports of `procedure`.
    dtype_imports_to_add = tuple(v.type.dtype for v in vars_to_resolve + inner_spec_vars \
        if isinstance(v.type.dtype, DerivedType))

    # Produce kinds appearing in `vars_to_resolve` or in `inner.spec` that need to be resolved
    # from imports of `procedure`.
    kind_imports_to_add = tuple(v.type.kind for v in vars_to_resolve + inner_spec_vars \
        if v.type.kind and hasattr(v.type.kind, 'scope') and v.type.kind.scope is procedure)

    # Produce all imports to add.
    # Here the imports are also tidied to only import what is strictly necessary, and with single
    # USE statements for each module.
    imports_to_add = []
    inner_import_map = inner.import_map
    to_lookup_from_imports = dtype_imports_to_add + kind_imports_to_add + var_imports_to_add
    for val in to_lookup_from_imports:
        if val.name in inner_import_map:
            # `inner` already imports this symbol itself (e.g. a derived type used only in
            # `inner.spec`); it need not, and may not be able to, be resolved from `procedure`.
            continue
        imp = procedure.import_map[val.name]
        # Only the symbol that is actually needed, never the full ONLY-list of `imp`
        newsyms = tuple(s for s in imp.symbols if s.name == val.name)
        matching_import = tuple(i for i in imports_to_add if i.module == imp.module)
        if matching_import:
            # Have already encountered module name, modify existing.
            matching_import = matching_import[0]
            imports_to_add.remove(matching_import)
            newimport = matching_import.clone(symbols=tuple(OrderedSet(matching_import.symbols + newsyms)))
        else:
            # Have not encountered the module yet, add new one.
            newimport = imp.clone(symbols=newsyms)
        imports_to_add.append(newimport)

    ## MAKING THE CHANGES TO `inner`
    # Change `inner` to take `vars_to_resolve` as dummy arguments and add all necessary imports.
    # Here also rescoping all variables to the scope of `inner` and specifying intent as "inout",
    # if not set in `procedure` scope.
    # Note: After these lines, `inner` should be self-contained or there is a bug.
    inner.arguments += tuple(
        v.clone(type=v.type.clone(intent=v.type.intent or 'inout'), scope=inner)
        for v in vars_to_resolve
    )
    inner.spec.prepend(imports_to_add)

    ## TRANSFORMING CALLS TO `inner` in `procedure`.
    # The resolved variables are all added as keyword arguments to each call.
    # (to avoid further modification of the call if it already happens to contain kwargs).
    # Here any dimensions in the variables are dropped, since they should not appear in the call.
    # Note that functions need different visitors and mappers than subroutines.
    call_map = {}
    if inner.is_function:
        for call in FindInlineCalls().visit(procedure.body):
            if call.routine is inner:
                newkwargs = tuple((v.name, v.clone(dimensions=None, scope=procedure)) for v in vars_to_resolve)
                call_map[call] = call.clone(kw_parameters=call.kwarguments + newkwargs)
        procedure.body = SubstituteExpressions(call_map).visit(procedure.body)
    else:
        for call in FindNodes(CallStatement).visit(procedure.body):
            if call.routine is inner:
                newkwargs = tuple((v.name, v.clone(dimensions=None, scope=procedure)) for v in vars_to_resolve)
                call_map[call] = call.clone(kwarguments=tuple(call.kwarguments) + newkwargs)
        procedure.body = Transformer(call_map).visit(procedure.body)

    return inner


# Backward-compatible aliases: "contained" is the original spelling of
# the public names for the internal-procedure extraction utilities.
extract_contained_procedures, extract_contained_procedure = (
    extract_internal_procedures, extract_internal_procedure
)
loki-ecmwf-0.3.6/loki/transformations/extract/outline.py0000664000175000017500000002255115167130205023640 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.analyse import dataflow_analysis_attached
from loki.expression import symbols as sym, Variable
from loki.ir import (
    CallStatement, PragmaRegion, Section, FindNodes,
    FindVariables, Transformer, is_loki_pragma,
    get_pragma_parameters, pragma_regions_attached
)
from loki.logging import info
from loki.subroutine import Subroutine
from loki.tools import as_tuple, OrderedSet
from loki.types import BasicType



__all__ = ['outline_region', 'outline_pragma_regions']


def order_variables_by_type(variables, imports=None):
    """
    Apply a default ordering to variables based on their type, so that
    their use in declaration lists is unified.

    Variables are first sorted lexicographically, then grouped as
    derived-type (or otherwise non-basic) variables, followed by arrays,
    followed by scalars.

    If ``imports`` are given, derived-type variables are additionally ordered
    by the position of their type name among the symbols of the (non-C)
    imports. Types that are not imported (e.g. defined in the enclosing
    module) are placed after all imported ones, keeping their lexicographic
    order, instead of raising ``ValueError``.

    Parameters
    ----------
    variables : iterable of symbols
        The variables to order.
    imports : tuple of :any:`Import`, optional
        Imports used to order derived-type variables.

    Returns
    -------
    tuple
        The ordered variables.
    """
    variables = sorted(variables, key=str)  # Lexicographical base order

    derived = tuple(
        v for v in variables
        if not isinstance(v, (sym.Scalar, sym.Array)) or not isinstance(v.type.dtype, BasicType)
    )

    if imports:
        # Order derived types by the order of their type in imports
        imported_symbols = tuple(s for i in imports for s in i.symbols if not i.c_import)

        def _import_position(var):
            # ``tuple.index`` raises for types that are not imported; sort those last.
            # ``sorted`` is stable, so the lexicographic order among them is kept.
            type_name = getattr(var.type.dtype, 'name', None)
            if type_name is None or type_name not in imported_symbols:
                return len(imported_symbols)
            return imported_symbols.index(type_name)

        derived = tuple(sorted(derived, key=_import_position))

    # Order declarations by type and put arrays before scalars
    non_derived = tuple(v for v in variables if v not in derived)
    arrays = tuple(v for v in non_derived if isinstance(v, sym.Array))
    scalars = tuple(v for v in non_derived if isinstance(v, sym.Scalar))
    assert len(derived) + len(arrays) + len(scalars) == len(variables)

    return derived + arrays + scalars


def outline_region(region, name, imports, intent_map=None):
    """
    Creates a new :any:`Subroutine` object from a given :any:`PragmaRegion`.

    Argument intents are derived from the dataflow analysis attached to
    ``region`` (``uses_symbols``/``defines_symbols``), ignoring imported
    symbols, and can be overridden via ``intent_map``.

    Parameters
    ----------
    region : :any:`PragmaRegion`
        The region that holds the body for which to create a subroutine.
    name : str
        Name of the new subroutine
    imports : tuple of :any:`Import` or None
        List of imports to replicate in the new subroutine
    intent_map : dict, optional
        Mapping of intent strings (``'in'``, ``'inout'``, ``'out'``) to lists of
        variables to override intents. Missing keys are treated as empty.

    Returns
    -------
    tuple of :any:`CallStatement` and :any:`Subroutine`
        The newly created call and respective subroutine.
    """
    intent_map = intent_map or {}
    imports = as_tuple(imports)
    imported_symbols = OrderedSet(var for imp in imports for var in imp.symbols)
    # Special-case for IFS-style C-imports
    imported_symbols |= OrderedSet(
        str(imp.module).split('.', maxsplit=1)[0] for imp in imports if imp.c_import
    )

    # Create the external subroutine containing the routine's imports and the region's body
    spec = Section(body=imports)
    body = Section(body=Transformer().visit(region.body))
    region_routine = Subroutine(name, spec=spec, body=body)

    # Filter derived-type component accesses and only use the root parent
    region_uses_symbols = OrderedSet(s.parents[0] if s.parent else s for s in region.uses_symbols)
    region_defines_symbols = OrderedSet(s.parents[0] if s.parent else s for s in region.defines_symbols)

    # Use dataflow analysis to find in, out and inout variables to that region
    # (ignoring any symbols that are external imports)
    region_in_args = region_uses_symbols - region_defines_symbols - imported_symbols
    region_inout_args = region_uses_symbols & region_defines_symbols - imported_symbols
    region_out_args = region_defines_symbols - region_uses_symbols - imported_symbols

    # Remove any parameters from in args
    region_in_args = OrderedSet(arg for arg in region_in_args if not arg.type.parameter)

    # Extract arguments given in pragma annotations; a missing intent key (including
    # the default ``intent_map=None``) means "no override" rather than a KeyError
    pragma_in_args = OrderedSet(v.clone(scope=region_routine) for v in intent_map.get('in', ()))
    pragma_inout_args = OrderedSet(v.clone(scope=region_routine) for v in intent_map.get('inout', ()))
    pragma_out_args = OrderedSet(v.clone(scope=region_routine) for v in intent_map.get('out', ()))

    # Override arguments according to pragma annotations
    region_in_args = (region_in_args - (pragma_inout_args | pragma_out_args)) | pragma_in_args
    region_inout_args = (region_inout_args - (pragma_in_args | pragma_out_args)) | pragma_inout_args
    region_out_args = (region_out_args - (pragma_in_args | pragma_inout_args)) | pragma_out_args

    # Now fix the order
    region_inout_args = as_tuple(region_inout_args)
    region_in_args = as_tuple(region_in_args)
    region_out_args = as_tuple(region_out_args)

    # Set the list of variables used in region routine (to create declarations)
    # and put all in the new scope
    region_routine_variables = tuple(
        v.clone(dimensions=v.type.shape or None, scope=region_routine)
        for v in FindVariables().visit(region.body)
        if v.clone(dimensions=None) not in imported_symbols
    )
    # Filter out derived-type component variables from declarations
    region_routine_variables = tuple(
        v.parents[0] if v.parent else v for v in region_routine_variables
    )

    # Build the call signature
    region_routine_var_map = {v.name: v for v in region_routine_variables}
    region_routine_arguments = []
    for intent, args in zip(('in', 'inout', 'out'), (region_in_args, region_inout_args, region_out_args)):
        for arg in args:
            local_var = region_routine_var_map.get(arg.name, arg)
            # Sanitise argument types
            local_var = local_var.clone(
                type=local_var.type.clone(intent=intent, allocatable=None, target=None),
                scope=region_routine
            )

            region_routine_var_map[arg.name] = local_var
            region_routine_arguments += [local_var]

    # Order the arguments and local declaration lists and put arguments first
    region_routine_locals = tuple(
        v for v in region_routine_variables if v not in region_routine_arguments
    )
    region_routine_arguments = order_variables_by_type(region_routine_arguments, imports=imports)
    region_routine_locals = order_variables_by_type(region_routine_locals, imports=imports)

    region_routine.variables = region_routine_arguments + region_routine_locals
    region_routine.arguments = region_routine_arguments

    # Ensure everything has been rescoped
    region_routine.rescope_symbols()

    # Create the call according to the wrapped code region
    call_arg_map = {v.name: v for v in region_in_args + region_inout_args + region_out_args}
    call_arguments = tuple(call_arg_map[a.name] for a in region_routine_arguments)
    call = CallStatement(name=Variable(name=name), arguments=call_arguments, kwarguments=())

    return call, region_routine


def outline_pragma_regions(routine):
    """
    Convert regions annotated with ``!$loki outline`` pragmas to subroutine calls.

    The pragma syntax for regions to convert to subroutines is
    ``!$loki outline [name(...)] [in(...)] [out(...)] [inout(...)]``
    and ``!$loki end outline``.

    A new subroutine is created with the provided name (or an auto-generated default name
    derived from the current subroutine name) and the content of the pragma region as body.

    Variables provided with the ``in``, ``out`` and ``inout`` options are used as
    arguments in the routine with the corresponding intent, all other variables used in this
    region are assumed to be local variables. The variable lists are comma-separated;
    whitespace around names (e.g. ``in(a, b)``) is ignored.

    The pragma region in the original routine is replaced by a call to the new subroutine.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine from which to extract marked pragma regions.

    Returns
    -------
    list of :any:`Subroutine`
        the list of newly created subroutines.
    """
    counter = 0
    routines = []
    imports = routine.imports
    parent_vmap = routine.variable_map
    mapper = {}
    with pragma_regions_attached(routine):
        with dataflow_analysis_attached(routine):
            for region in FindNodes(PragmaRegion).visit(routine.body):
                if not is_loki_pragma(region.pragma, starts_with='outline'):
                    continue

                # Name the external routine
                parameters = get_pragma_parameters(region.pragma, starts_with='outline')
                name = parameters.get('name', f'{routine.name}_outlined_{counter}')
                counter += 1

                # Extract explicitly requested symbols from context. Names are stripped so that
                # ``in(a, b)`` resolves ``b`` rather than failing on ``' b'``, and an option given
                # without a value is treated as an empty list.
                intent_map = {
                    intent: tuple(
                        parent_vmap[v.strip()]
                        for v in (parameters.get(intent) or '').split(',') if v.strip()
                    )
                    for intent in ('in', 'inout', 'out')
                }

                call, region_routine = outline_region(region, name, imports, intent_map=intent_map)

                # insert into list of new routines
                routines.append(region_routine)

                # Replace region by call in original routine
                mapper[region] = call

            routine.body = Transformer(mapper=mapper).visit(routine.body)
    info('%s: converted %d region(s) to calls', routine.name, counter)

    return routines
loki-ecmwf-0.3.6/loki/transformations/argument_shape.py0000664000175000017500000002310715167130205023507 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Inter-procedural analysis passes to derive and augment argument array shapes.

A pair of utility :any:`Transformation` classes that allow the shape
of array arguments with deferred dimensions to be derived from the
calling context via inter-procedural analysis.

To infer the declared dimensions of array arguments
:any:`ArgumentArrayShapeAnalysis` needs to be applied first to set the
``shape`` property on respective :any:`Array` symbols, before
:any:`ExplicitArgumentArrayShapeTransformation` can be applied in a
reverse traversal order to apply the necessary changes to argument
declarations and call signatures.
"""

from loki.batch import Transformation
from loki.expression import Array, symbols as sym, simplify
from loki.ir import (
    FindNodes, CallStatement, Transformer, FindVariables, SubstituteExpressions
)
from loki.tools import as_tuple, CaseInsensitiveDict, OrderedSet
from loki.types import BasicType


__all__ = ['ArgumentArrayShapeAnalysis', 'ExplicitArgumentArrayShapeTransformation']


class ArgumentArrayShapeAnalysis(Transformation):
    """
    Infer shape of array arguments with deferred shape.

    An inter-procedural analysis pass that passively infers the shape
    symbols for argument arrays from calling contexts and sets the
    ``shape`` attribute on :any:`Array` symbols accordingly.

    The shape information is propagated from a caller to the called
    subroutine in a forward traversal of the call-tree. If the
    call-side shape of an array argument is either set, or has already
    been derived (possibly with conflicting information), this
    transformation will have no effect.

    Note: This transformation does not affect the generated source
    code, as it only sets the ``shape`` property, which is ignored
    during the code generation step (:any:`fgen`). To actively change
    the argument array declarations and the resulting source code, the
    :any:`ExplicitArgumentArrayShapeTransformation` needs to be applied
    `after` this transformation.
    """

    @staticmethod
    def _passed_index_ranges(val):
        """
        Return the index ranges of the array section ``val`` that is passed in a call.

        Element indices (anything that is not a :any:`RangeIndex`, e.g. ``j`` or ``1``
        in ``val(:, j)``) are dropped. Unbounded range limits are filled in from the
        declared shape of ``val`` at the *same* dimension position: the declared lower
        bound if the declaration is a range, else ``1``; the declared upper bound if the
        declaration is a range, else the declared extent itself.
        """
        ranges = []
        for pos, dim in enumerate(val.dimensions):
            if not isinstance(dim, sym.RangeIndex):
                continue
            lower, upper = dim.lower, dim.upper
            if lower is None or upper is None:
                declared = val.shape[pos]
                if lower is None:
                    declared_lower = getattr(declared, 'lower', None)
                    lower = declared_lower if declared_lower is not None else sym.IntLiteral(1)
                if upper is None:
                    upper = getattr(declared, 'upper', declared)
            ranges.append(sym.RangeIndex((lower, upper)))
        return ranges

    def transform_subroutine(self, routine, **kwargs):  # pylint: disable=arguments-differ

        for call in FindNodes(CallStatement).visit(routine.body):

            # Skip if call-side info is not available or call is not active
            if call.routine is BasicType.DEFERRED or call.not_active:
                continue

            # The called routine gets updated; keep it separate from the caller ``routine``
            callee = call.routine

            # Create a variable map with new shape information from source
            vmap = {}
            for arg, val in call.arg_iter():
                if isinstance(arg, Array) and len(arg.shape) > 0:
                    # Only create new shapes for deferred dimension args
                    if all(d == ':' for d in arg.shape):
                        if len(val.shape) == len(arg.shape):
                            # We're passing the full value array, copy shape
                            vmap[arg] = arg.clone(type=arg.type.clone(shape=val.shape))
                        else:
                            # Passing a sub-array of val, find the right index
                            new_shape = [s for s, d in zip(val.shape, val.dimensions)
                                         if d == ':']
                            vmap[arg] = arg.clone(type=arg.type.clone(shape=new_shape))

                    elif str(arg.shape[-1])[-1] == '*':
                        expl_shape_len = len(arg.shape) - 1

                        dims = val.shape
                        if val.dimensions:
                            # sanitise unbounded ranges in argument, aligned with val's declaration
                            dims = self._passed_index_ranges(val)

                        # determine argument dimension sizes
                        sizes = [simplify(sym.Sum((getattr(d, 'upper', d),
                                          sym.Product((sym.IntLiteral(-1), getattr(d, 'lower', sym.IntLiteral(1)))),
                                          sym.IntLiteral(1)))) for d in dims]

                        # determine explicit size corresponding to assumed size
                        expl_size = simplify(sym.Product(tuple(s for s in sizes[expl_shape_len:])))

                        new_shape = list(arg.shape)
                        if isinstance(arg.shape[expl_shape_len], sym.RangeIndex):
                            lower = arg.shape[expl_shape_len].lower
                            upper = simplify(sym.Sum((expl_size, sym.IntLiteral(-1), lower)))
                            new_shape[expl_shape_len] = sym.RangeIndex((lower, upper))
                        else:
                            new_shape[expl_shape_len] = expl_size
                        vmap[arg] = arg.clone(type=arg.type.clone(shape=as_tuple(new_shape)))

            # Propagate the updated variables to variable definitions in the called routine
            callee.variables = [vmap.get(v, v) for v in callee.variables]

            # And finally propagate this to the variable instances
            vname_map = CaseInsensitiveDict((k.name, v) for k, v in vmap.items())
            vmap_body = {}
            for v in FindVariables(unique=False).visit(callee.body):
                if v.name in vname_map:
                    new_shape = vname_map[v.name].shape
                    vmap_body[v] = v.clone(type=v.type.clone(shape=new_shape))
            callee.body = SubstituteExpressions(vmap_body).visit(callee.body)


class ExplicitArgumentArrayShapeTransformation(Transformation):
    """
    Add dimensions to array arguments and adjust call signatures.

    Adjusts array argument declarations by inserting explicit shape
    variables according to the ``shape`` property of the :any:`Array`
    symbol. This property can be derived from the calling context via
    :any:`ArgumentArrayShapeAnalysis`.

    If the :any:`Scalar` symbol defining an array dimension is not yet
    known in the local :any:`Subroutine`, it gets added to the call
    signature. In the caller routine, the respective :any:`Scalar`
    argument is added to the :any:`CallStatement` via keyword-argument
    notation.

    Note: Since the :any:`CallStatement` needs updating after the called
    :any:`Subroutine` signature, this transformation has to be applied
    in reverse order (callees before callers). The class requests this
    traversal order via the ``reverse_traversal`` class attribute.
    """

    # We need to traverse call tree in reverse to ensure called
    # procedures are updated before callers.
    reverse_traversal = True

    def transform_subroutine(self, routine, **kwargs):  # pylint: disable=arguments-differ
        """
        Make array argument shapes explicit in ``routine`` and propagate
        the required integer dimension variables to its callees.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine to transform. Its spec, the argument lists of the
            routines it calls, and its :any:`CallStatement` nodes are
            updated in place.
        """

        def assumed(dims):
            # A shape counts as "assumed" if every dimension is a bare ``:``
            # (assumed shape) or the last dimension ends in ``*`` (assumed size)
            return all(d == ':' for d in dims) or str(dims[-1]).endswith('*')

        # First, replace assumed array shapes with concrete shapes for
        # all arguments if the shape is known.
        arg_map = {}
        for arg in routine.arguments:
            if isinstance(arg, Array):
                # Arrays without any shape information are skipped instead
                # of crashing on ``None`` in ``assumed``
                if arg.shape and not assumed(arg.shape) and assumed(arg.dimensions):
                    arg_map[arg] = arg.clone(dimensions=tuple(arg.shape))
        routine.spec = SubstituteExpressions(arg_map).visit(routine.spec)

        # We also need to ensure that all potential integer dimensions
        # are passed as arguments in deep subroutine call trees.
        call_map = {}
        for call in FindNodes(CallStatement).visit(routine.body):

            # Skip if call-side info is not available or call is not active
            if call.routine is BasicType.DEFERRED or call.not_active:
                continue

            callee = call.routine
            # Symbols imported by the callee (or its parent) must not become new dummy arguments
            imported_symbols = callee.imported_symbols
            if callee.parent is not None:
                imported_symbols += callee.parent.imported_symbols

            # Collect all potential dimension variables and filter for scalar integers
            dims = OrderedSet(
                d for arg in callee.arguments if isinstance(arg, Array) for d in as_tuple(arg.shape)
            )
            dim_vars = tuple(FindVariables().visit(as_tuple(dims)))

            # Add all new dimension arguments to the callee signature
            new_args = tuple(d for d in dim_vars if d not in callee.arguments)
            new_args = tuple(d for d in new_args if d.type.dtype == BasicType.INTEGER)
            new_args = tuple(d for d in new_args if d not in imported_symbols)
            new_args = tuple(d.clone(scope=routine, type=d.type.clone(intent='IN')) for d in new_args)
            callee.arguments += new_args

            # Map all local dimension args to unknown callee dimension args
            if len(callee.arguments) > len(list(call.arg_iter())):
                arg_keys = dict(call.arg_iter()).keys()
                missing = [a for a in callee.arguments if a not in arg_keys
                           and not a.type.optional and a in dim_vars]

                # Add missing dimension variables (integer scalars) as keyword arguments
                new_kwargs = tuple((str(m), m) for m in missing if m.type.dtype == BasicType.INTEGER)
                call_map[call] = call.clone(kwarguments=call.kwarguments + new_kwargs)

        # Replace all adjusted calls on the caller-side
        routine.body = Transformer(call_map).visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/build_system/0000775000175000017500000000000015167130205022633 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/build_system/module_wrap.py0000664000175000017500000001621715167130205025532 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation, SchedulerConfig
from loki.expression import Variable
from loki.ir import Import, Section, Interface, FindNodes, Transformer
from loki.module import Module
from loki.subroutine import Subroutine
from loki.tools import as_tuple


__all__ = ['ModuleWrapTransformation']


class ModuleWrapTransformation(Transformation):
    """
    Utility transformation that ensures all transformed kernel
    subroutines are wrapped in a module.

    The module name is derived from the subroutine name and :data:`module_suffix`.

    Any previous import of wrapped subroutines via interfaces or C-style header
    imports of interface blocks is replaced by a Fortran import (``USE``).

    Parameters
    ----------
    module_suffix : str
        Special suffix to signal module names like `_MOD`
    replace_ignore_items : bool
        Debug flag to toggle the replacement of calls to subroutines
        in the ``ignore`` list of the processed item. Default is ``True``.
    """

    # This transformation is applied over the file graph
    traverse_file_graph = True

    # This transformation recurses from the Sourcefile down
    recurse_to_modules = True
    recurse_to_procedures = True
    recurse_to_internal_procedures = False

    # This transformation changes the names of items and creates new items
    renames_items = True
    creates_items = True

    def __init__(self, module_suffix, replace_ignore_items=True):
        self.module_suffix = module_suffix
        self.replace_ignore_items = replace_ignore_items

    def transform_file(self, sourcefile, **kwargs):
        """
        For kernel routines, wrap each subroutine in the current file in a module.

        If scheduler ``items`` are passed, the file counts as a kernel file
        only when all of them have the ``kernel`` role. Otherwise the
        ``role`` keyword argument decides.
        """
        role = kwargs.get('role')

        if items := kwargs.get('items'):
            # We consider the sourcefile to be a "kernel" file if all items are kernels
            role = 'kernel' if all(item.role == 'kernel' for item in items) else 'driver'

        if role == 'kernel':
            self.module_wrap(sourcefile)

    def transform_module(self, module, **kwargs):
        """
        Update imports of wrapped subroutines
        """
        self.update_imports(module, imports=module.imports, **kwargs)

    def transform_subroutine(self, routine, **kwargs):
        """
        Update imports and interface blocks of wrapped subroutines.

        If a scheduler ``item`` is given and ``routine`` now has a parent whose
        name differs from the item's scope name (e.g. after :meth:`module_wrap`),
        the item is renamed to ``<parent>#<local_name>``.
        """
        if item := kwargs.get('item'):
            # Rename the item if it has suddenly a parent
            if routine.parent and routine.parent.name.lower() != item.scope_name:
                item.name = f'{routine.parent.name.lower()}#{item.local_name}'

        # Note, C-style imports can be in the body, so use whole IR
        imports = FindNodes(Import).visit(routine.ir)
        self.update_imports(routine, imports=imports, **kwargs)

        # Interface blocks can only be in the spec
        intfs = FindNodes(Interface).visit(routine.spec)
        self.replace_interfaces(routine, intfs=intfs, **kwargs)

    def module_wrap(self, sourcefile):
        """
        Wrap target subroutines in modules and replace in source file.

        Each free-standing subroutine is moved into a new module named
        ``<routine name><module_suffix>``. The module takes the routine's
        position in the file's IR.
        """
        for routine in sourcefile.subroutines:
            # Create wrapper module and insert into file, replacing the old
            # standalone routine
            modname = f'{routine.name}{self.module_suffix}'
            module = Module(name=modname, contains=Section(body=as_tuple(routine)))
            routine.parent = module
            sourcefile.ir._update(body=as_tuple(
                module if c is routine else c for c in sourcefile.ir.body
            ))

    def update_imports(self, source, imports, **kwargs):
        """
        Update imports of wrapped subroutines.

        C-style header imports of target routines (e.g. ``#include "kernel.intfb.h"``)
        are replaced by a ``USE`` of the wrapper module that imports only the
        target symbol. Headers with a ``func.h`` suffix are left untouched.

        Parameters
        ----------
        source : :any:`Module` or :any:`Subroutine`
            The program unit whose imports are updated in place.
        imports : tuple of :any:`Import`
            The import nodes to consider.
        targets : tuple of str, optional
            Names of the routines that are wrapped (given via ``kwargs``).
        item : optional
            The scheduler item of ``source`` (given via ``kwargs``). If
            ``replace_ignore_items`` is set, its ``ignore`` entries are treated
            as targets, and wrapped ignored items are added to its ``block`` list.
        """

        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets')))
        item = kwargs.get('item')
        if item and self.replace_ignore_items:
            targets += tuple(str(i).lower() for i in item.ignore)

        def _update_item(proc_name, module_name):
            if item and SchedulerConfig.match_item_keys(proc_name, item.ignore):
                # Add the module wrapped but ignored items to the block list because we won't be able to
                # find them as dependencies under their new name anymore. Add the module only once, even
                # if several ignore entries match or the routine is imported repeatedly.
                block = as_tuple(item.block)
                if module_name not in block:
                    item.config['block'] = block + (module_name,)

        # Transformer map to remove any outdated imports
        removal_map = {}

        # We go through the IR, as C-imports can be attributed to the body
        for im in imports:
            if not im.c_import:
                continue

            # ``target_symbol`` is already lower-case
            target_symbol, *suffixes = im.module.lower().split('.', maxsplit=1)
            if targets and target_symbol in targets and 'func.h' not in suffixes:
                # Create a new module import with explicitly qualified symbol
                modname = f'{target_symbol}{self.module_suffix}'
                _update_item(target_symbol, modname)
                new_symbol = Variable(name=target_symbol, scope=source)
                new_import = im.clone(module=modname, c_import=False, symbols=(new_symbol,))
                source.spec.prepend(new_import)

                # Mark current import for removal
                removal_map[im] = None

        # Apply any scheduled import removals to spec and body
        if removal_map:
            source.spec = Transformer(removal_map).visit(source.spec)
            if isinstance(source, Subroutine):
                source.body = Transformer(removal_map).visit(source.body)

    def replace_interfaces(self, source, intfs, **kwargs):
        """
        Update explicit interfaces to actively transformed subroutines.

        Every target subroutine declared in one of the interface blocks
        ``intfs`` is replaced by a ``USE`` import of its wrapper module. An
        interface block whose subroutines are all targets is removed.
        Otherwise only the target subroutines are dropped from it, so that
        the interfaces of the remaining procedures are preserved.

        Parameters
        ----------
        source : :any:`Subroutine`
            The program unit whose spec is updated in place.
        intfs : tuple of :any:`Interface`
            The interface blocks to consider.
        targets : tuple of str, optional
            Names of the routines that are wrapped (given via ``kwargs``).
        item : optional
            The scheduler item of ``source`` (given via ``kwargs``); its
            ``ignore`` entries are treated as targets if ``replace_ignore_items``
            is set.
        """
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets')))
        if self.replace_ignore_items and (item := kwargs.get('item')):
            targets += tuple(str(i).lower() for i in item.ignore)

        if not targets:
            return

        # Transformer map to remove or trim any outdated interfaces
        intf_map = {}

        for intf in intfs:
            routines = [b for b in intf.body if isinstance(b, Subroutine)]
            replaced = [r for r in routines if r.name.lower() in targets]
            if not replaced:
                continue

            for r in replaced:
                # Create a new module import with explicitly qualified symbol
                modname = f'{r.name}{self.module_suffix}'
                new_symbol = Variable(name=f'{r.name}', scope=source)
                new_import = Import(module=modname, c_import=False, symbols=(new_symbol,))
                source.spec.prepend(new_import)

            if len(replaced) == len(routines):
                # All declared procedures are now imported: drop the whole block
                intf_map[intf] = None
            else:
                # Keep the interfaces of procedures that are not wrapped
                replaced_ids = {id(r) for r in replaced}
                intf_map[intf] = intf.clone(body=tuple(b for b in intf.body if id(b) not in replaced_ids))

        # Apply any scheduled interface removals to spec
        if intf_map:
            source.spec = Transformer(intf_map).visit(source.spec)
loki-ecmwf-0.3.6/loki/transformations/build_system/__init__.py0000664000175000017500000000117415167130205024747 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.transformations.build_system.dependency import * # noqa
from loki.transformations.build_system.file_write import * # noqa
from loki.transformations.build_system.module_wrap import * # noqa
from loki.transformations.build_system.plan import * # noqa
loki-ecmwf-0.3.6/loki/transformations/build_system/tests/0000775000175000017500000000000015167130205023775 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/build_system/tests/__init__.py0000664000175000017500000000057015167130205026110 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/build_system/tests/test_dependency.py0000664000175000017500000011710115167130205027525 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import copy
import pytest

from loki import Sourcefile
from loki.batch import Scheduler, SchedulerConfig
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    FindNodes, CallStatement, Import, Interface, Intrinsic, FindInlineCalls
)

from loki.transformations import (
    DependencyTransformation, ModuleWrapTransformation
)


@pytest.fixture(scope='function', name='config')
def fixture_config():
    """
    Scheduler configuration shared by the tests in this module.

    By default every item is an expanded, strict ``kernel`` in ``idem``
    mode; only the ``driver`` routine is given the ``driver`` role.
    A fresh dictionary is built for every test.
    """
    default_options = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': True,
    }
    routine_options = {
        'driver': {'role': 'driver'},
    }
    return {'default': default_options, 'routines': routine_options}


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_globalvar_imports(frontend, use_scheduler, tmp_path, config):
    """
    Test that renaming a kernel redirects only the procedure import.

    The driver imports both the ``kernel`` routine and the module variable
    ``some_const`` from ``kernel_mod``. After the :any:`DependencyTransformation`
    the routine import refers to ``kernel_test_mod``, while the import of the
    global variable still refers to the original ``kernel_mod``.
    """

    kernel_fcode = """
MODULE kernel_mod
    INTEGER :: some_const
CONTAINS
    SUBROUTINE kernel(a, b, c)
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a, b, c

    a = 1
    b = 2
    c = 3
  END SUBROUTINE kernel
END MODULE kernel_mod
    """.strip()

    driver_fcode = """
SUBROUTINE driver(a, b, c)
    USE kernel_mod, only: kernel
    USE kernel_mod, only: some_const
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a, b, c

    CALL kernel(a, b ,c)
END SUBROUTINE driver
    """.strip()

    transformation = DependencyTransformation(suffix='_test', module_suffix='_mod')

    if not use_scheduler:
        kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path])

        kernel.apply(transformation, role='kernel')
        driver['driver'].apply(transformation, role='driver', targets=('kernel', 'kernel_mod'))
    else:
        kernel_path = tmp_path/'kernel_mod.F90'
        kernel_path.write_text(kernel_fcode)
        (tmp_path/'driver.F90').write_text(driver_fcode)
        scheduler = Scheduler(
            paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
        )
        scheduler.process(transformation)

        # The renamed module (for the subroutine) and the original module
        # (for the global variable) must both be in the scheduler graph
        for item_name in ('kernel_test_mod#kernel_test', 'kernel_mod'):
            assert item_name in scheduler.items

        kernel = scheduler['kernel_test_mod#kernel_test'].source
        driver = scheduler['#driver'].source

        # The module that was not renamed must be identical to the original one
        cached_kernel = scheduler.item_factory.item_cache[str(kernel_path)]
        cached_kernel.source.make_complete(frontend=frontend, xmods=[tmp_path])
        reference = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        assert reference.to_fortran() == cached_kernel.source.to_fortran()

    # The global variable declaration keeps its name
    assert kernel.modules[0].variables[0].name == 'some_const'

    driver_routine = driver['driver']

    # The call is diverted to the re-generated routine...
    calls = FindNodes(CallStatement).visit(driver_routine.body)
    assert len(calls) == 1
    assert calls[0].name == 'kernel_test'

    imports = FindNodes(Import).visit(driver_routine.spec)
    assert len(imports) == 2
    assert all(isinstance(imp, Import) for imp in imports)

    routine_import, globalvar_import = driver_routine.spec.body[:2]

    # ...and so is the matching import
    assert routine_import.module == 'kernel_test_mod'
    assert 'kernel_test' in [str(s) for s in routine_import.symbols]

    # The global variable import is left alone
    assert globalvar_import.module == 'kernel_mod'
    assert 'some_const' in [str(s) for s in globalvar_import.symbols]


@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'OMNI removes access specifiers ...')]))
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_access_spec_names(frontend, use_scheduler, tmp_path, config):
    """
    Test that renamed procedures are updated in the module's ``PUBLIC`` list.

    Procedure names in the access specification of the kernel module are
    replaced by their suffixed names. Global variables and derived types keep
    their names, and their imports in the driver still refer to the original
    module.
    """

    kernel_fcode = """
MODULE kernel_access_spec_mod

  INTEGER, PUBLIC :: some_const
  INTEGER :: another_const

  type :: t_type_1
    integer :: i_1
  end type

  type :: t_type_2
    integer :: i_2
  end type

PRIVATE
PUBLIC kernel, kernel_2, unused_kernel, another_const, t_type_2
CONTAINS
    SUBROUTINE kernel(a, b, c)
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a, b, c

    call kernel_2(a, b)
    call kernel_3(c)
  END SUBROUTINE kernel
  SUBROUTINE kernel_2(a, b)
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a, b

    a = 1
    b = 2
  END SUBROUTINE kernel_2
  SUBROUTINE kernel_3(a)
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a

    a = 3
  END SUBROUTINE kernel_3
  SUBROUTINE unused_kernel(a)
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a

    a = 3
  END SUBROUTINE unused_kernel
END MODULE kernel_access_spec_mod
    """.strip()

    driver_fcode = """
SUBROUTINE driver(a, b, c)
    USE kernel_access_spec_mod, only: kernel, another_const
    USE kernel_access_spec_mod, only: some_const
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a, b, c

    CALL kernel(a, b, c)
END SUBROUTINE driver
    """.strip()

    transformation = DependencyTransformation(suffix='_test', module_suffix='_mod')
    if not use_scheduler:
        kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path],
                                        definitions=kernel.definitions)

        kernel.apply(transformation, role='kernel')
        driver['driver'].apply(transformation, role='driver', targets=('kernel', 'kernel_access_spec_mod'))
    else:
        kernel_path = tmp_path/'kernel_access_spec_mod.F90'
        kernel_path.write_text(kernel_fcode)
        (tmp_path/'driver.F90').write_text(driver_fcode)
        scheduler = Scheduler(
            paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
        )
        scheduler.process(transformation)

        # The renamed module (for the subroutine) and the original module
        # (for the global variable) must both be in the scheduler graph
        for item_name in ('kernel_access_spec_test_mod#kernel_test', 'kernel_access_spec_mod'):
            assert item_name in scheduler.items

        kernel = scheduler['kernel_access_spec_test_mod#kernel_test'].source
        driver = scheduler['#driver'].source

        # The module that was not renamed must be identical to the original one
        cached_kernel = scheduler.item_factory.item_cache[str(kernel_path)]
        cached_kernel.source.make_complete(frontend=frontend, xmods=[tmp_path])
        reference = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        assert reference.to_fortran() == cached_kernel.source.to_fortran()

    kernel_module = kernel.modules[0]

    # Global variables keep their names
    assert kernel_module.variables[0].name == 'some_const'
    assert kernel_module.variables[1].name == 'another_const'

    # Derived types keep their names
    assert kernel_module.typedefs[0].name == 't_type_1'
    assert kernel_module.typedefs[1].name == 't_type_2'

    driver_routine = driver['driver']

    def symbol_names(node):
        return [str(s) for s in node.symbols]

    # The call and its import are diverted to the re-generated routine
    calls = FindNodes(CallStatement).visit(driver_routine.body)
    assert len(calls) == 1
    assert calls[0].name == 'kernel_test'

    imports = FindNodes(Import).visit(driver_routine.spec)
    assert len(imports) == 3
    assert all(isinstance(imp, Import) for imp in imports)

    spec_body = driver_routine.spec.body
    assert spec_body[0].module == 'kernel_access_spec_test_mod'
    assert 'kernel_test' in symbol_names(spec_body[0])

    # Imports of global variables still refer to the original module
    assert spec_body[1].module == 'kernel_access_spec_mod'
    assert 'another_const' in symbol_names(spec_body[1])
    assert 'some_const' in symbol_names(spec_body[2])

    # Renamed procedures appear under their new names in the PUBLIC list;
    # with the scheduler, ``unused_kernel`` does not appear in it at all
    expected_public = ['kernel_test', 'kernel_2_test']
    if not use_scheduler:
        expected_public.append('unused_kernel_test')
    expected_public += ['another_const', 't_type_2']
    assert kernel_module.public_access_spec == tuple(expected_public)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_globalvar_imports_driver_mod(frontend, use_scheduler, tmp_path, config):
    """
    Test that renaming a kernel redirects only the procedure import when the
    imports live in the spec of a driver *module* rather than the driver routine.

    The routine import in ``DRIVER_MOD`` must point to ``kernel_test_mod``,
    while the import of the global variable ``some_const`` still refers to
    ``kernel_mod``.
    """

    kernel_fcode = """
MODULE kernel_mod
    INTEGER :: some_const
CONTAINS
    SUBROUTINE kernel(a, b, c)
    INTEGER, INTENT(INOUT) :: a, b, c

    a = 1
    b = 2
    c = 3
  END SUBROUTINE kernel
END MODULE kernel_mod
    """.strip()

    driver_fcode = """
MODULE DRIVER_MOD
    USE kernel_mod, only: kernel
    USE kernel_mod, only: some_const
CONTAINS
SUBROUTINE driver(a, b, c)
    INTEGER, INTENT(INOUT) :: a, b, c

    CALL kernel(a, b ,c)
END SUBROUTINE driver
END MODULE DRIVER_MOD
    """.strip()

    transformation = DependencyTransformation(suffix='_test', module_suffix='_mod')

    if not use_scheduler:
        kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path])

        kernel.apply(transformation, role='kernel')
        driver.apply(transformation, role='driver', targets=('kernel', 'kernel_mod'))
    else:
        (tmp_path/'kernel_mod.F90').write_text(kernel_fcode)
        (tmp_path/'driver_mod.F90').write_text(driver_fcode)
        scheduler = Scheduler(
            paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
        )
        scheduler.process(transformation)

        kernel = scheduler['kernel_test_mod#kernel_test'].source
        driver = scheduler['driver_mod#driver'].source

    # The global variable declaration keeps its name
    assert kernel.modules[0].variables[0].name == 'some_const'

    # The call is diverted to the re-generated routine
    calls = FindNodes(CallStatement).visit(driver['driver'].body)
    assert len(calls) == 1
    assert calls[0].name == 'kernel_test'

    driver_module = driver['driver_mod']
    imports = FindNodes(Import).visit(driver_module.spec)
    assert len(imports) == 2
    assert all(isinstance(imp, Import) for imp in imports)

    routine_import, globalvar_import = driver_module.spec.body[:2]

    # The routine import in the module spec follows the renamed kernel
    assert routine_import.module == 'kernel_test_mod'
    assert 'kernel_test' in [str(s) for s in routine_import.symbols]

    # The global variable import is left alone
    assert globalvar_import.module == 'kernel_mod'
    assert 'some_const' in [str(s) for s in globalvar_import.symbols]


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'C-imports need pre-processing for OMNI')]))
def test_dependency_transformation_header_includes(tmp_path, frontend):
    """
    Test injection of suffixed kernels into unchanged driver
    routines via c-header includes.

    The driver's ``myfunc.intfb.h`` include must be redirected to a newly
    generated ``myfunc_test.intfb.h`` header, while the ``myfunc.func.h``
    include is kept unchanged.
    """

    driver = Sourcefile.from_source(source="""
SUBROUTINE driver(a, b, c)
  INTEGER, INTENT(INOUT) :: a, b, c

#include "myfunc.intfb.h"
#include "myfunc.func.h"

  CALL myfunc(a, b ,c)
END SUBROUTINE driver
""", frontend=frontend)

    kernel = Sourcefile.from_source(source="""
SUBROUTINE myfunc(a, b, c)
  INTEGER, INTENT(INOUT) :: a, b, c

  a = 1
  b = 2
  c = 3
END SUBROUTINE myfunc
""", frontend=frontend)

    # Ensure header file does not exist a-priori
    header_file = tmp_path/'myfunc_test.intfb.h'
    header_file.unlink(missing_ok=True)

    # Apply injection transformation via C-style includes by giving `include_path`
    transformation = DependencyTransformation(suffix='_test', include_path=tmp_path)
    kernel['myfunc'].apply(transformation, role='kernel')
    driver['driver'].apply(transformation, role='driver', targets='myfunc')

    # Check that the subroutine name in the kernel source has changed
    assert len(kernel.modules) == 0
    assert len(kernel.subroutines) == 1
    assert kernel.subroutines[0].name == 'myfunc_test'
    assert kernel['myfunc_test'] == kernel.all_subroutines[0]

    # Check that the driver is still a single standalone subroutine with its original name
    # (previously this section re-checked the kernel instead of the driver)
    assert len(driver.modules) == 0
    assert len(driver.subroutines) == 1
    assert driver.subroutines[0].name == 'driver'

    driver_fcode = driver.to_fortran()

    # Check that the import has been updated
    assert '#include "myfunc.intfb.h"' not in driver_fcode
    assert '#include "myfunc_test.intfb.h"' in driver_fcode

    # Check that imported function was not modified
    assert '#include "myfunc.func.h"' in driver_fcode

    # Check that header file was generated and clean up
    assert header_file.exists()
    header_file.unlink()


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'C-imports need pre-processing for OMNI')]))
@pytest.mark.parametrize('use_scheduler', [False, True])
@pytest.mark.parametrize('replace_ignore_items', [False, True])
def test_dependency_transformation_module_wrap(frontend, use_scheduler, replace_ignore_items, tmp_path, config):
    """
    Test injection of suffixed kernels into unchanged driver routines,
    combined with automatic module wrapping of the kernels.

    The driver pulls in both kernels via C-style ``.intfb.h`` headers. After
    :any:`ModuleWrapTransformation` and :any:`DependencyTransformation` these
    become ``USE`` imports of the renamed wrapper modules, while the
    ``kernel.func.h`` header is kept. With the scheduler and
    ``replace_ignore_items`` disabled, ``kernel`` is blocked and stays unchanged.
    """

    driver_fcode = """
SUBROUTINE driver(a, b, c)
  INTEGER, INTENT(INOUT) :: a, b, c

#include "kernel.func.h"
#include "kernel.intfb.h"
#include "other_kernel.intfb.h"

  CALL kernel(a, b ,c)
  CALL other_kernel(a, b ,c)
END SUBROUTINE driver
    """.strip()

    kernel_fcode = """
SUBROUTINE kernel(a, b, c)
  INTEGER, INTENT(INOUT) :: a, b, c

  a = 1
  b = 2
  c = 3
END SUBROUTINE kernel
    """.strip()

    other_kernel_fcode = """
SUBROUTINE other_kernel(a, b, c)
  INTEGER, INTENT(INOUT) :: a, b, c

  a = 1
  b = 2
  c = 3
END SUBROUTINE other_kernel
    """.strip()

    module_wrap = ModuleWrapTransformation(module_suffix='_mod', replace_ignore_items=replace_ignore_items)
    dependency = DependencyTransformation(
        suffix='_test', module_suffix='_mod', replace_ignore_items=replace_ignore_items
    )
    transformations = (module_wrap, dependency)

    # Only in this configuration is ``kernel`` blocked and therefore left untouched
    kernel_unchanged = use_scheduler and not replace_ignore_items

    if use_scheduler:
        sources = (('kernel', kernel_fcode), ('other_kernel', other_kernel_fcode), ('driver', driver_fcode))
        for file_stem, fcode in sources:
            (tmp_path/f'{file_stem}.F90').write_text(fcode)

        scheduler_config = copy.deepcopy(config)
        if not replace_ignore_items:
            scheduler_config['default'].update({'block': ['kernel']})

        scheduler = Scheduler(
            paths=[tmp_path], config=SchedulerConfig.from_dict(scheduler_config), frontend=frontend, xmods=[tmp_path]
        )
        for transformation in transformations:
            scheduler.process(transformation)

        if not replace_ignore_items:
            # The blocked kernel's source cannot be retrieved under any of its names
            for item_name in ['#kernel', 'kernel_mod#kernel', 'kernel_test_mod#kernel_test']:
                with pytest.raises(AttributeError):
                    _ = scheduler[item_name].source
            kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        else:
            kernel = scheduler['kernel_test_mod#kernel_test'].source
        other_kernel = scheduler['other_kernel_test_mod#other_kernel_test'].source
        driver = scheduler['#driver'].source

    else:
        kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        other_kernel = Sourcefile.from_source(other_kernel_fcode, frontend=frontend, xmods=[tmp_path])
        driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path])

        kernel.apply(module_wrap, role='kernel')
        other_kernel.apply(module_wrap, role='kernel')
        driver['driver'].apply(module_wrap, role='driver', targets=('kernel', 'other_kernel'))
        kernel.apply(dependency, role='kernel')
        other_kernel.apply(dependency, role='kernel')
        driver['driver'].apply(dependency, role='driver', targets=('kernel_mod', 'kernel', 'other_kernel'))

    # Check that the kernels have been wrapped (unless blocked)
    if kernel_unchanged:
        assert len(kernel.subroutines) == 1
        assert kernel.subroutines[0].name == 'kernel'
        assert kernel['kernel'] == kernel.subroutines[0]
    else:
        assert len(kernel.subroutines) == 0
        assert len(kernel.all_subroutines) == 1
        assert kernel.all_subroutines[0].name == 'kernel_test'
        assert kernel['kernel_test'] == kernel.all_subroutines[0]
        assert len(kernel.modules) == 1
        assert kernel.modules[0].name == 'kernel_test_mod'
        assert kernel['kernel_test_mod'] == kernel.modules[0]

    assert len(other_kernel.subroutines) == 0
    assert len(other_kernel.all_subroutines) == 1
    assert other_kernel.all_subroutines[0].name == 'other_kernel_test'
    assert other_kernel['other_kernel_test'] == other_kernel.all_subroutines[0]
    assert len(other_kernel.modules) == 1
    assert other_kernel.modules[0].name == 'other_kernel_test_mod'
    assert other_kernel['other_kernel_test_mod'] == other_kernel.modules[0]

    # The driver itself is neither wrapped nor renamed
    assert len(driver.modules) == 0
    assert len(driver.subroutines) == 1
    assert driver.subroutines[0].name == 'driver'

    driver_routine = driver['driver']

    # Calls and imports are diverted to the re-generated routines
    calls = FindNodes(CallStatement).visit(driver_routine.body)
    assert len(calls) == 2
    assert len(FindNodes(Import).visit(driver_routine.ir)) == 3

    imported_symbols = driver_routine.imported_symbols
    imported_modules = [str(imp.module) for imp in driver_routine.imports]

    if kernel_unchanged:
        assert calls[0].name == 'kernel'
        assert 'kernel.intfb.h' in imported_modules
    else:
        assert calls[0].name == 'kernel_test'
        assert 'kernel_test' in imported_symbols
        assert 'kernel_test_mod' in imported_modules

    assert calls[1].name == 'other_kernel_test'
    assert 'kernel.func.h' in imported_modules
    assert 'other_kernel_test_mod' in imported_modules
    assert 'other_kernel_test' in imported_symbols


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_scheduler', [False, True])
@pytest.mark.parametrize('module_wrap', [True, False])
def test_dependency_transformation_replace_interface(frontend, use_scheduler, module_wrap, tmp_path, config):
    """
    Test injection of a suffixed kernel that the driver declares through an
    interface block, while the driver routine itself keeps its name.

    With ``module_wrap=True`` the kernel is first wrapped in a module by
    :any:`ModuleWrapTransformation`, and the driver is expected to import the
    suffixed kernel from that module via a ``USE`` statement placed before
    ``IMPLICIT NONE``. With ``module_wrap=False`` the kernel remains a free
    subroutine and the driver's interface block is expected to refer to the
    suffixed name instead.

    Both variants are exercised with and without the :any:`Scheduler`.
    """

    driver_fcode = """
SUBROUTINE driver(a, b, c)
  IMPLICIT NONE
  INTERFACE
    SUBROUTINE KERNEL(a, b, c)
      INTEGER, INTENT(INOUT) :: a, b, c
    END SUBROUTINE KERNEL
  END INTERFACE

  INTEGER, INTENT(INOUT) :: a, b, c

  CALL kernel(a, b ,c)
END SUBROUTINE driver
    """.strip()

    kernel_fcode = """
SUBROUTINE kernel(a, b, c)
  INTEGER, INTENT(INOUT) :: a, b, c

  a = 1
  b = 2
  c = 3
END SUBROUTINE kernel
    """.strip()

    # Apply injection transformation via C-style includes by giving `include_path`
    transformations = []
    if module_wrap:
        transformations += [ModuleWrapTransformation(module_suffix='_mod')]
    transformations += [DependencyTransformation(suffix='_test', include_path=tmp_path, module_suffix='_mod')]

    if use_scheduler:
        (tmp_path/'kernel.F90').write_text(kernel_fcode)
        (tmp_path/'driver.F90').write_text(driver_fcode)
        scheduler = Scheduler(
            paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
        )
        for transformation in transformations:
            scheduler.process(transformation)

        if module_wrap:
            kernel = scheduler['kernel_test_mod#kernel_test'].source
        else:
            kernel = scheduler['#kernel_test'].source
        driver = scheduler['#driver'].source

    else:
        kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path])

        targets = ('kernel',)
        for transformation in transformations:
            kernel.apply(transformation, role='kernel')
            driver.apply(transformation, role='driver', targets=targets)
            # The import becomes another target after the ModuleWrapTransformation
            targets += ('kernel_mod',)

    # Check that the kernel has been wrapped
    if module_wrap:
        assert len(kernel.subroutines) == 0
        assert len(kernel.all_subroutines) == 1
        assert len(kernel.modules) == 1
        assert kernel.modules[0].name == 'kernel_test_mod'
        assert kernel['kernel_test_mod'] == kernel.modules[0]
    else:
        assert len(kernel.subroutines) == 1
        assert len(kernel.modules) == 0
    assert kernel.all_subroutines[0].name == 'kernel_test'
    assert kernel['kernel_test'] == kernel.all_subroutines[0]

    # Check that the driver name has not changed
    assert len(driver.modules) == 0
    assert len(driver.subroutines) == 1
    assert driver.subroutines[0].name == 'driver'

    # Check that calls have been diverted to the re-generated routine
    calls = FindNodes(CallStatement).visit(driver['driver'].body)
    assert len(calls) == 1
    assert calls[0].name == 'kernel_test'

    if module_wrap:
        # Check that imports have been generated
        imports = FindNodes(Import).visit(driver['driver'].spec)
        assert len(imports) == 1
        assert imports[0].module.lower() == 'kernel_test_mod'
        assert 'kernel_test' in imports[0].symbols

        # Check that the newly generated USE statement appears before IMPLICIT NONE
        nodes = FindNodes((Intrinsic, Import)).visit(driver['driver'].spec)
        assert len(nodes) == 2
        assert isinstance(nodes[1], Intrinsic)
        assert nodes[1].text.lower() == 'implicit none'

    else:
        # Check that the interface has been updated
        intfs = FindNodes(Interface).visit(driver['driver'].spec)
        assert len(intfs) == 1
        assert intfs[0].symbols == ('kernel_test',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_transformation_inline_call(frontend):
    """
    Test injection of a suffixed, module-wrapped kernel function that the
    driver invokes through inline function calls.

    The function has no explicit ``RESULT`` clause, so after renaming the
    result variable is expected to keep the original function name.
    """

    driver = Sourcefile.from_source(source="""
SUBROUTINE driver(a, b, c)
  INTERFACE
    INTEGER FUNCTION kernel(a)
      INTEGER, INTENT(IN) :: a
    END FUNCTION kernel
  END INTERFACE

  INTEGER, INTENT(INOUT) :: a, b, c

  a = kernel(a)
  b = kernel(a)
  c = kernel(c)
END SUBROUTINE driver
""", frontend=frontend)

    kernel = Sourcefile.from_source(source="""
INTEGER FUNCTION kernel(a)
  INTEGER, INTENT(IN) :: a

  kernel = 2*a
END FUNCTION kernel
""", frontend=frontend)

    # Wrap the kernel in a module first, then apply the suffix. No `include_path`
    # is given here; the driver is expected to import the kernel with a USE statement.
    transformations = (
        ModuleWrapTransformation(module_suffix='_mod'),
        DependencyTransformation(suffix='_test', module_suffix='_mod')
    )
    targets = ('kernel',)
    for transformation in transformations:
        kernel.apply(transformation, role='kernel')
        driver.apply(transformation, role='driver', targets=targets)
        # The import becomes another target after the ModuleWrapTransformation
        targets += ('kernel_mod',)

    # Check that the kernel has been wrapped
    assert len(kernel.subroutines) == 0
    assert len(kernel.all_subroutines) == 1
    assert kernel.all_subroutines[0].name == 'kernel_test'
    assert kernel['kernel_test'] == kernel.all_subroutines[0]
    assert kernel['kernel_test'].is_function
    assert len(kernel.modules) == 1
    assert kernel.modules[0].name == 'kernel_test_mod'
    assert kernel['kernel_test_mod'] == kernel.modules[0]

    # Check that the return name hasn't changed
    assert 'kernel' in kernel['kernel_test'].symbol_attrs
    assert kernel['kernel_test'].result_name == 'kernel'

    # Check that the driver name has not changed
    assert len(driver.modules) == 0
    assert len(driver.subroutines) == 1
    assert driver.subroutines[0].name == 'driver'

    # Check that calls and imports have been diverted to the re-generated routine
    calls = tuple(FindInlineCalls().visit(driver['driver'].body))
    assert len(calls) == 2
    calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body))
    assert len(calls) == 3
    assert calls[0].name == 'kernel_test'
    imports = FindNodes(Import).visit(driver['driver'].spec)
    assert len(imports) == 1
    assert imports[0].module == 'kernel_test_mod'
    assert 'kernel_test' in [str(s) for s in imports[0].symbols]


@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_transformation_inline_call_result_var(frontend):
    """
    Test injection of a suffixed, module-wrapped kernel function that declares
    an explicit ``RESULT(ret)`` variable and is accessed through inline function
    calls.

    Unlike the variant without a ``RESULT`` clause, the explicit result
    variable name is expected to be kept after renaming.
    """

    driver = Sourcefile.from_source(source="""
SUBROUTINE driver(a, b, c)
  INTERFACE
    FUNCTION kernel(a) RESULT(ret)
      INTEGER, INTENT(IN) :: a
      INTEGER :: ret
    END FUNCTION kernel
  END INTERFACE

  INTEGER, INTENT(INOUT) :: a, b, c

  a = kernel(a)
  b = kernel(a)
  c = kernel(c)
END SUBROUTINE driver
""", frontend=frontend)

    kernel = Sourcefile.from_source(source="""
FUNCTION kernel(a) RESULT(ret)
  INTEGER, INTENT(IN) :: a
  INTEGER :: ret

  ret = 2*a
END FUNCTION kernel
""", frontend=frontend)

    # Wrap the kernel in a module first, then apply the suffix. No `include_path`
    # is given here; the driver is expected to import the kernel with a USE statement.
    transformations = (
        ModuleWrapTransformation(module_suffix='_mod'),
        DependencyTransformation(suffix='_test', module_suffix='_mod')
    )
    targets = ('kernel',)
    for transformation in transformations:
        kernel.apply(transformation, role='kernel')
        driver.apply(transformation, role='driver', targets=targets)
        # The import becomes another target after the ModuleWrapTransformation
        targets += ('kernel_mod',)

    # Check that the kernel has been wrapped
    assert len(kernel.subroutines) == 0
    assert len(kernel.all_subroutines) == 1
    assert kernel.all_subroutines[0].name == 'kernel_test'
    assert kernel['kernel_test'] == kernel.all_subroutines[0]
    assert kernel['kernel_test'].is_function
    assert len(kernel.modules) == 1
    assert kernel.modules[0].name == 'kernel_test_mod'
    assert kernel['kernel_test_mod'] == kernel.modules[0]

    # Check that the explicit RESULT variable is kept and not replaced
    # by the (renamed) function name
    assert kernel['kernel_test'].result_name == 'ret'

    # Check that the driver name has not changed
    assert len(driver.modules) == 0
    assert len(driver.subroutines) == 1
    assert driver.subroutines[0].name == 'driver'

    # Check that calls and imports have been diverted to the re-generated routine
    calls = tuple(FindInlineCalls().visit(driver['driver'].body))
    assert len(calls) == 2
    calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body))
    assert len(calls) == 3
    assert calls[0].name == 'kernel_test'
    imports = FindNodes(Import).visit(driver['driver'].spec)
    assert len(imports) == 1
    assert imports[0].module == 'kernel_test_mod'
    assert 'kernel_test' in [str(s) for s in imports[0].symbols]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_contained_member(frontend, use_scheduler, tmp_path, config):
    """
    The scheduler currently does not recognize or allow processing contained member routines as part
    of the scheduler graph traversal. This test ensures that the transformation class
    does not recurse into contained members: the kernel and its module receive the
    suffix, but the contained members ``set_a`` and ``get_b`` and the calls to them
    keep their original names, even when the members are explicitly passed as
    ``targets`` (non-scheduler variant).
    """

    kernel_fcode = """
MODULE kernel_mod
    IMPLICIT NONE
CONTAINS
    SUBROUTINE kernel(a, b, c)
    INTEGER, INTENT(INOUT) :: a, b, c

    call set_a(1)
    b = get_b()
    c = 3

    CONTAINS

        SUBROUTINE SET_A(VAL)
            INTEGER, INTENT(IN) :: VAL
            A = VAL
        END SUBROUTINE SET_A

        FUNCTION GET_B()
            INTEGER GET_B
            GET_B = 2
        END FUNCTION GET_B
  END SUBROUTINE kernel
END MODULE kernel_mod
    """.strip()

    driver_fcode = """
SUBROUTINE driver(a, b, c)
    USE kernel_mod, only: kernel
    IMPLICIT NONE
    INTEGER, INTENT(INOUT) :: a, b, c

    CALL kernel(a, b ,c)
END SUBROUTINE driver
    """.strip()

    transformation = DependencyTransformation(suffix='_test', module_suffix='_mod')

    if use_scheduler:
        (tmp_path/'kernel_mod.F90').write_text(kernel_fcode)
        (tmp_path/'driver.F90').write_text(driver_fcode)
        scheduler = Scheduler(
            paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
        )
        scheduler.process(transformation)

        kernel = scheduler['kernel_test_mod#kernel_test'].source
        driver = scheduler['#driver'].source
    else:
        kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
        driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path])

        # Deliberately pass the contained members as targets: they must still not be renamed
        kernel.apply(transformation, role='kernel', targets=('set_a', 'get_b'))
        driver['driver'].apply(transformation, role='driver', targets=('kernel', 'kernel_mod'))

    # Check that calls and matching import have been diverted to the re-generated routine
    calls = FindNodes(CallStatement).visit(driver['driver'].body)
    assert len(calls) == 1
    assert calls[0].name == 'kernel_test'
    imports = FindNodes(Import).visit(driver['driver'].spec)
    assert len(imports) == 1
    assert imports[0].module.lower() == 'kernel_test_mod'
    assert imports[0].symbols == ('kernel_test',)

    # Check that the kernel has been renamed
    assert kernel.modules[0].name.lower() == 'kernel_test_mod'
    assert kernel.modules[0].subroutines[0].name.lower() == 'kernel_test'

    # Check that the contained members have NOT been renamed
    assert kernel['kernel_test'].subroutines[0].name.lower() == 'set_a'
    assert kernel['kernel_test'].subroutines[1].name.lower() == 'get_b'

    # Check that the calls to contained members inside the kernel have NOT been renamed
    calls = FindNodes(CallStatement).visit(kernel['kernel_test'].body)
    assert len(calls) == 1
    assert calls[0].name == 'set_a'

    calls = FindInlineCalls(unique=False).visit(kernel['kernel_test'].body)
    assert len(calls) == 1
    assert calls[0].name == 'get_b'


@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_transformation_item_filter(frontend, tmp_path, config):
    """
    Test that injection is not applied to modules that have no procedures
    in the scheduler graph, even if they have other item members.

    ``header_mod`` only provides a module variable and must keep its name,
    while ``kernel_mod`` (which contains the called function) is renamed.
    """

    driver_fcode = """
SUBROUTINE driver(a, b, c)
  USE HEADER_MOD, ONLY: HEADER_VAR
  USE KERNEL_MOD, ONLY: KERNEL
  IMPLICIT NONE

  INTEGER, INTENT(INOUT) :: a, b, c

  a = kernel(a)
  b = kernel(a)
  c = kernel(c) + HEADER_VAR
END SUBROUTINE driver
    """.strip()

    kernel_fcode = """
MODULE kernel_mod
IMPLICIT NONE
CONTAINS
FUNCTION kernel(a) RESULT(ret)
  INTEGER, INTENT(IN) :: a
  INTEGER :: ret

  ret = 2*a
END FUNCTION kernel
END MODULE kernel_mod
    """.strip()

    header_fcode = """
MODULE header_mod
    IMPLICIT NONE
    INTEGER :: HEADER_VAR
END MODULE header_mod
    """.strip()

    (tmp_path/'kernel_mod.F90').write_text(kernel_fcode)
    (tmp_path/'header_mod.F90').write_text(header_fcode)
    (tmp_path/'driver.F90').write_text(driver_fcode)

    # Create the scheduler such that it chases imports
    config['default']['enable_imports'] = True
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
    )

    # Make sure the header module item exists
    assert 'header_mod' in scheduler.items

    transformations = (
        ModuleWrapTransformation(module_suffix='_mod'),
        DependencyTransformation(suffix='_test', module_suffix='_mod')
    )
    for transformation in transformations:
        scheduler.process(transformation)

    kernel = scheduler['kernel_test_mod#kernel_test'].source
    header = scheduler['header_mod'].source
    driver = scheduler['#driver'].source

    # Check that the kernel mod has been changed
    assert len(kernel.subroutines) == 0
    assert len(kernel.all_subroutines) == 1
    assert kernel.all_subroutines[0].name == 'kernel_test'
    assert kernel['kernel_test'] == kernel.all_subroutines[0]
    assert kernel['kernel_test'].is_function
    assert len(kernel.modules) == 1
    assert kernel.modules[0].name == 'kernel_test_mod'
    assert kernel['kernel_test_mod'] == kernel.modules[0]

    # Check that the header name has not been changed
    assert len(header.modules) == 1
    assert header.modules[0].name == 'header_mod'
    assert header.modules[0].variables == ('header_var',)

    # Check that the driver name has not changed
    assert len(driver.modules) == 0
    assert len(driver.subroutines) == 1
    assert driver.subroutines[0].name == 'driver'

    # Check that calls and imports have been diverted to the re-generated routine
    calls = tuple(FindInlineCalls().visit(driver['driver'].body))
    assert len(calls) == 2
    calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body))
    assert len(calls) == 3
    assert all(call.name == 'kernel_test' for call in calls)
    # Map each imported symbol name to its import; one entry per imported symbol
    imports = driver['driver'].import_map
    assert len(imports) == 2
    assert 'header_var' in imports and imports['header_var'].module.lower() == 'header_mod'
    assert 'kernel_test' in imports and imports['kernel_test'].module.lower() == 'kernel_test_mod'


@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_transformation_filter_items_file_graph(tmp_path, frontend, config):
    """
    Ensure that the ``items`` list given to a transformation in
    a file graph traversal is filtered to include only used items.

    All three modules and the driver live in a single file. After the
    :any:`DependencyTransformation`, the new sourcefile is expected to drop
    the unused module and the unused procedure, while the original
    sourcefile remains unchanged.
    """
    fcode = """
module test_dependency_transformation_filter_items1_mod
implicit none
contains
subroutine proc1(arg)
    integer, intent(inout) :: arg
    arg = arg + 1
end subroutine proc1

subroutine unused_proc(arg)
    integer, intent(inout) :: arg
    arg = arg - 1
end subroutine unused_proc
end module test_dependency_transformation_filter_items1_mod

module test_dependency_transformation_filter_items2_mod
implicit none
contains
subroutine proc2(arg)
    integer, intent(inout) :: arg
    arg = arg + 2
end subroutine proc2
end module test_dependency_transformation_filter_items2_mod

module test_dependency_transformation_filter_items3_mod
implicit none
integer, parameter :: param3 = 3
contains
subroutine proc3(arg)
    integer, intent(inout) :: arg
    arg = arg + 3
end subroutine proc3
end module test_dependency_transformation_filter_items3_mod

subroutine test_dependency_transformation_filter_items_driver
use test_dependency_transformation_filter_items1_mod, only: proc1
use test_dependency_transformation_filter_items3_mod, only: param3
implicit none
integer :: i
i = param3
call proc1(i)
end subroutine test_dependency_transformation_filter_items_driver
    """

    config['routines'] = {
        'test_dependency_transformation_filter_items_driver': {'role': 'driver'},
    }

    filepath = tmp_path/'test_dependency_transformation_filter_items.F90'
    filepath.write_text(fcode)

    scheduler = Scheduler(
        paths=[tmp_path], config=config,
        seed_routines=['test_dependency_transformation_filter_items_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # Only the driver, proc1 from mod1, and mod3 (imported for param3) are in
    # the SGraph; mod2 and unused_proc are not
    expected_dependencies = {
        '#test_dependency_transformation_filter_items_driver': {
            'test_dependency_transformation_filter_items1_mod#proc1',
            'test_dependency_transformation_filter_items3_mod'
        },
        'test_dependency_transformation_filter_items1_mod#proc1': set(),
        'test_dependency_transformation_filter_items3_mod': set()
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items() for b in deps
    }

    # The other module and procedure are in the item_factory's cache...
    assert 'test_dependency_transformation_filter_items2_mod' in scheduler.item_factory.item_cache
    assert 'test_dependency_transformation_filter_items1_mod#unused_proc' in scheduler.item_factory.item_cache

    # ...and share the same sourcefile object
    assert (
        scheduler.item_factory.item_cache['test_dependency_transformation_filter_items2_mod'].source is
        scheduler.item_factory.item_cache['test_dependency_transformation_filter_items1_mod'].source
    )

    # The filegraph consists of the single file
    filegraph = scheduler.file_graph
    assert filegraph.items == (str(filepath).lower(),)

    # Check that the DependencyTransformation changes only the active items
    # and discards unused routines
    scheduler.process(transformation=DependencyTransformation(suffix='_foo', module_suffix='_mod'))

    expected_dependencies = {
        '#test_dependency_transformation_filter_items_driver': {
            'test_dependency_transformation_filter_items1_foo_mod#proc1_foo',
            'test_dependency_transformation_filter_items3_mod'
        },
        'test_dependency_transformation_filter_items1_foo_mod#proc1_foo': set(),
        'test_dependency_transformation_filter_items3_mod': set()
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items() for b in deps
    }

    # The other module is still in the item_factory's cache...
    assert 'test_dependency_transformation_filter_items2_mod' in scheduler.item_factory.item_cache

    # ...and so are the original modules
    assert 'test_dependency_transformation_filter_items1_mod' in scheduler.item_factory.item_cache
    assert 'test_dependency_transformation_filter_items3_mod' in scheduler.item_factory.item_cache

    # ...but they don't share the same sourcefile object anymore
    original_source = scheduler.item_factory.item_cache['test_dependency_transformation_filter_items2_mod'].source
    new_src = scheduler.item_factory.item_cache['test_dependency_transformation_filter_items1_foo_mod'].source
    assert new_src is not original_source

    # The original source still holds all three modules, whereas the new
    # source drops the unused module (items2_mod)
    assert [m.name.lower() for m in original_source.modules] == [
        'test_dependency_transformation_filter_items1_mod',
        'test_dependency_transformation_filter_items2_mod',
        'test_dependency_transformation_filter_items3_mod'
    ]
    assert [m.name.lower() for m in new_src.modules] == [
        'test_dependency_transformation_filter_items1_foo_mod',
        'test_dependency_transformation_filter_items3_mod'
    ]
    # Note the idiosyncratic behaviour:
    # items3_mod is present under the same name in both the original and the new
    # source, because its name is not updated (none of its procedures are in the
    # scheduler graph) but it is still part of the scheduler graph.
    # We need to see whether this is what we want...

    # The new module does not contain the unused procedure
    original_mod1 = original_source['test_dependency_transformation_filter_items1_mod']
    new_mod1 = new_src['test_dependency_transformation_filter_items1_foo_mod']

    assert [r.name.lower() for r in original_mod1.subroutines] == ['proc1', 'unused_proc']
    assert [r.name.lower() for r in new_mod1.subroutines] == ['proc1_foo']
loki-ecmwf-0.3.6/loki/transformations/build_system/tests/test_file_write.py0000664000175000017500000002432715167130205027547 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Tests for build system interaction
"""

from pathlib import Path
import re
from subprocess import CalledProcessError

import pytest

from loki.batch import Scheduler, SchedulerConfig, ProcessingStrategy
from loki.frontend import available_frontends, OMNI
from loki.logging import log_levels
from loki.transformations.build_system import FileWriteTransformation


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('enable_imports', [False, True])
@pytest.mark.parametrize('import_level', ['module', 'subroutine'])
@pytest.mark.parametrize('qualified_imports', [False, True])
@pytest.mark.parametrize('use_rootpath', [False, True])
@pytest.mark.parametrize('suffix', [None, '.F90', '.Fstar'])
def test_file_write_module_imports(frontend, tmp_path, enable_imports, import_level,
                                   qualified_imports, use_rootpath, suffix):
    """
    Set up a four file mini-project with some edge cases around
    import behaviour (see in-source comments for details) and verify
    that the generated CMake plan matches the list of files we expect
    to transform, and that the FileWriteTransformation writes exactly these
    files
    """
    fcode_mod_a = """
module a_mod
    implicit none
    public
    integer :: global_a = 1
end module a_mod
"""

    fcode_mod_b = """
module b_mod
    implicit none
    public
    type type_b
        integer :: val
    end type type_b
end module b_mod
"""

    if qualified_imports:
        import_stmt = "use a_mod, only: global_a\n    use b_mod, only: type_b"
    else:
        import_stmt = "use a_mod\n    use b_mod"

    module_import_stmt = ""
    routine_import_stmt = ""
    if import_level == 'module':
        module_import_stmt = import_stmt
    elif import_level == 'subroutine':
        routine_import_stmt = import_stmt

    fcode_mod_c = f"""
module c_mod
    {module_import_stmt}
    implicit none
contains
    subroutine c(val)
        {routine_import_stmt}
        implicit none
        integer, intent(inout) :: val
        type(type_b) :: b
        b%val = global_a
        val = b%val
    end subroutine c
end module c_mod
"""

    fcode_mod_d = """
module d_mod
    implicit none
contains
    subroutine d
        use c_mod, only: c
        implicit none
        integer :: v
        call c(v)
    end subroutine d
end module d_mod
"""

    # Set-up paths and write sources
    src_path = tmp_path/'src'
    src_path.mkdir()
    out_path = tmp_path/'build'
    out_path.mkdir()

    (src_path/'a.F90').write_text(fcode_mod_a)
    (src_path/'b.F90').write_text(fcode_mod_b)
    (src_path/'c.F90').write_text(fcode_mod_c)
    (src_path/'d.F90').write_text(fcode_mod_d)

    # Expected items in the dependency graph
    expected_items = {'c_mod#c', 'd_mod#d'}

    if import_level == 'subroutine':
        if qualified_imports:
            # With qualified imports, we do not have a dependency
            # on 'b_mod' but directly on 'b_mod#type_b'
            expected_items |= {'a_mod', 'b_mod#type_b'}
        else:
            # Without qualified imports, we assume a dependency
            # for the subroutine on the imported module
            expected_items |= {'a_mod', 'b_mod'}

    elif import_level == 'module':
        if qualified_imports:
            # If we have a qualified import for the derived type
            # then we will recognize the dependency
            expected_items |= {'b_mod#type_b'}

    # Create the Scheduler
    config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': enable_imports,
            'mode': 'foobar'
        },
        'routines': {'d': {'role': 'driver'}}
    })
    try:
        scheduler = Scheduler(
            paths=[src_path], config=config, frontend=frontend,
            output_dir=out_path, xmods=[tmp_path]
        )
    except CalledProcessError as e:
        # Both header modules must be in the graph: 'a_mod' and 'b_mod'
        # (or its member 'b_mod#type_b'). Note: this is a set intersection;
        # a union would always be non-empty and hence always truthy.
        all_modules_expected = (
            'a_mod' in expected_items and
            bool(expected_items & {'b_mod', 'b_mod#type_b'})
        )
        if frontend == OMNI and not (enable_imports and all_modules_expected):
            # If not all header modules appear in the dependency graph, then these
            # will not be parsed by OMNI and therefore the required xmod files will
            # not be generated, thus making modules 'c' and 'd' fail at parsing
            pytest.xfail('Without parsing imports, OMNI does not have the xmod for imported modules')
        raise e

    # Check the dependency graph
    assert expected_items == {item.name for item in scheduler.items}

    # Set-up the file write
    transformation = FileWriteTransformation(
        suffix=suffix,
        include_module_var_imports=enable_imports
    )

    # Generate the CMake plan
    plan_file = tmp_path/'plan.cmake'
    root_path = tmp_path if use_rootpath else None
    scheduler.process(transformation, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=root_path)

    # Validate the plan file content
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)

    loki_plan = plan_file.read_text()
    plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)}
    plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()}

    if enable_imports:
        # We expect to write all files that correspond to items in the graph;
        # each file stem is the first letter of its module name ('a_mod' -> 'a.F90')
        expected_files = {item[0] for item in expected_items}

        if qualified_imports:
            # ...but we want to never write 'b' if we have fully qualified imports
            # because that only contains a type definition
            expected_files -= {'b'}
    else:
        # We expect to only write the subroutine files
        expected_files = {'c', 'd'}

    assert 'LOKI_SOURCES_TO_TRANSFORM' in plan_dict
    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == expected_files

    assert 'LOKI_SOURCES_TO_REMOVE' in plan_dict
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == expected_files

    assert 'LOKI_SOURCES_TO_APPEND' in plan_dict
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {
        f'{name}.foobar' for name in expected_files
    }

    # Write the outputs
    scheduler.process(transformation)

    # Validate the list of written files
    if suffix is None:
        suffix = '.F90'
    written_files = {f.name for f in out_path.glob('*')}
    assert written_files == {
        f'{name}.foobar{suffix}' for name in expected_files
    }


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('have_non_replicate_conflict', [False, True])
def test_file_write_replicate(tmp_path, caplog, frontend, have_non_replicate_conflict):
    """
    Verify the CMake plan and the written files when items are marked for
    replication (``replicate: True`` by default; ``b_mod``, ``not_c`` and the
    driver ``d`` opt out).

    All four files are expected in the transform and append lists, but only
    ``b`` and ``d`` in the remove list. If ``c.F90`` additionally contains the
    non-replicated routine ``not_c``, writing the plan is expected to emit
    exactly one warning naming the file and that item. ``c`` is still absent
    from the remove list in that case.
    """
    fcode_a = """
module a_mod
    implicit none
    integer :: a
end module a_mod
    """
    fcode_b = """
module b_mod
    implicit none
    integer :: b
end module b_mod
    """
    if have_non_replicate_conflict:
        other_routine = "subroutine not_c()\n    end subroutine not_c"
    else:
        other_routine = ""
    fcode_c = f"""
module c_mod
contains
    subroutine c(val)
        use a_mod, only: a
        use b_mod, only: b
        integer, intent(inout) :: val
        val = val + a + b
    end subroutine c
    {other_routine}
end module c_mod
    """
    fcode_d = """
subroutine d()
    use c_mod, only: c
    implicit none
    integer :: var
    call c(var)
end subroutine d
    """

    # Set-up paths and write sources
    src_path = tmp_path/'src'
    src_path.mkdir()
    out_path = tmp_path/'build'
    out_path.mkdir()

    (src_path/'a.F90').write_text(fcode_a)
    (src_path/'b.F90').write_text(fcode_b)
    (src_path/'c.F90').write_text(fcode_c)
    (src_path/'d.F90').write_text(fcode_d)

    # Expected items in the dependency graph
    expected_items = {'a_mod', 'b_mod', 'c_mod#c', '#d'}

    if have_non_replicate_conflict:
        expected_items |= {'c_mod#not_c'}

    # Create the Scheduler
    config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True,
            'mode': 'foobar',
            'replicate': True
        },
        'routines': {
            'b_mod': {'replicate': False},
            'not_c': {'replicate': False},
            'd': {'role': 'driver', 'replicate': False},
        }
    })

    scheduler = Scheduler(
        paths=[src_path], config=config, frontend=frontend,
        output_dir=out_path, xmods=[tmp_path]
    )

    # Check the dependency graph
    assert expected_items == {item.name for item in scheduler.items}

    # Set-up the file write
    transformation = FileWriteTransformation(include_module_var_imports=True)

    # Generate the CMake plan
    scheduler.process(transformation, proc_strategy=ProcessingStrategy.PLAN)
    plan_file = tmp_path/'plan.cmake'

    # A file mixing replicated and non-replicated items must be reported
    caplog.clear()
    with caplog.at_level(log_levels['WARNING']):
        scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)
        if have_non_replicate_conflict:
            assert len(caplog.records) == 1
            assert 'c.f90' in caplog.records[0].message
            assert 'c_mod#not_c' in caplog.records[0].message
        else:
            assert not caplog.records

    # Validate the plan file content
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)

    loki_plan = plan_file.read_text()
    plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)}
    plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()}

    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {'a', 'b', 'c', 'd'}
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'b', 'd'}
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {'a.foobar', 'b.foobar', 'c.foobar', 'd.foobar'}

    # Write the outputs
    scheduler.process(transformation)

    # Validate the list of written files
    written_files = {f.name for f in out_path.glob('*')}
    assert written_files == {'a.foobar.F90', 'b.foobar.F90', 'c.foobar.F90', 'd.foobar.F90'}
loki-ecmwf-0.3.6/loki/transformations/build_system/tests/test_plan.py0000664000175000017500000001114515167130205026342 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import re

import pytest

from loki.batch import Scheduler, SchedulerConfig, ProcessingStrategy
from loki.transformations.build_system import CMakePlanTransformation, FileWriteTransformation


@pytest.mark.parametrize('use_rootpath', [False, True])
@pytest.mark.parametrize('use_fullpath', [False, True])
def test_plan_relative_paths(tmp_path, monkeypatch, use_rootpath, use_fullpath):
    """
    Check that the CMake plan does not resolve source paths prematurely.

    Overlay file systems (emulated here with a symlink) can make resolved
    paths differ from the paths that a CMake target uses for its sources.
    If the :any:`CMakePlanTransformation` resolved them, the file names in
    its lists would not match the target's source list, and updating that
    source list would break.

    The test creates the following hierarchy and runs the Scheduler on the
    overlay path:

    - real_path/
        - module/
            - mymod.F90
        - src/
            - mysub.F90
    - overlay_path -> real_path
    - build/

    """
    real_dir = tmp_path/'real_path'
    build_dir = tmp_path/'build'

    for directory in (real_dir, real_dir/'src', real_dir/'module'):
        directory.mkdir()
    (tmp_path/'overlay_path').symlink_to('real_path')
    build_dir.mkdir()

    if use_rootpath:
        rootpath = tmp_path
    else:
        rootpath = None
    prefix = f'{tmp_path}/' if use_fullpath else ''

    fcode_mymod = """
module mymod
    implicit none
    contains
        subroutine mod_sub
        end subroutine mod_sub
end module mymod
    """
    fcode_mysub = """
subroutine mysub
    use mymod, only: mod_sub
    implicit none
    call mod_sub
end subroutine mysub
    """

    (real_dir/'src/mysub.F90').write_text(fcode_mysub)
    (real_dir/'module/mymod.F90').write_text(fcode_mymod)
    assert (tmp_path/'overlay_path/src/mysub.F90').exists()

    # Work from tmp_path so that relative paths can be handed to the Scheduler
    monkeypatch.chdir(tmp_path)

    scheduler_config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'mode': 'test',
        },
        'routines': {
            'mysub': {'role': 'driver'}
        }
    })
    scheduler = Scheduler(
        paths=[f'{prefix}overlay_path'], config=scheduler_config,
        full_parse=False, output_dir=build_dir
    )
    assert scheduler.items == ('#mysub', 'mymod#mod_sub')

    # The items keep the paths exactly as handed to the Scheduler
    overlay_dir = Path(f'{prefix}overlay_path')
    assert scheduler['#mysub'].source.path == overlay_dir/'src/mysub.F90'
    assert scheduler['mymod#mod_sub'].source.path == overlay_dir/'module/mymod.F90'

    # Planning: the file writer records the output paths, the CMake plan collects them
    plan_trafo = CMakePlanTransformation(rootpath=rootpath)
    for trafo in (FileWriteTransformation(), plan_trafo):
        scheduler.process(transformation=trafo, proc_strategy=ProcessingStrategy.PLAN)

    planfile = build_dir/'plan.cmake'
    plan_trafo.write_plan(planfile)

    set_regex = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan_dict = {
        var: [Path(entry) for entry in entries.split()]
        for var, entries in set_regex.findall(planfile.read_text())
    }

    # The newly generated files always have fully qualified paths
    to_append = [build_dir/'mysub.test.F90', build_dir/'mymod.test.F90']

    # The list of files to transform (currently unused by the CMake macros) keeps the
    # original relative paths - unless they are resolved relative to a given root
    # directory, which also resolves the symlink
    source_dir = Path('real_path') if rootpath else overlay_dir
    to_transform = [source_dir/'src/mysub.F90', source_dir/'module/mymod.F90']

    for planned, cmake_var, expected in (
        (plan_trafo.sources_to_append, 'LOKI_SOURCES_TO_APPEND', to_append),
        (plan_trafo.sources_to_transform, 'LOKI_SOURCES_TO_TRANSFORM', to_transform),
        (plan_trafo.sources_to_remove, 'LOKI_SOURCES_TO_REMOVE', to_transform),
    ):
        assert planned[None] == expected
        assert plan_dict[cmake_var] == expected
loki-ecmwf-0.3.6/loki/transformations/build_system/plan.py0000664000175000017500000001637415167130205024152 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Transformations to be used for exposing planned changes to
the build system
"""

from pathlib import Path

from loki.batch.transformation import Transformation
from loki.logging import debug

__all__ = ['CMakePlanTransformation']

class CMakePlanTransformation(Transformation):
    """
    Gather the planning information from the :any:`Item.trafo_data` of all
    items this transformation is applied to, and allow writing it to a CMake
    plan file

    This requires that :any:`FileWriteTransformation` has been applied in planning
    mode first.

    Applying this transformation to an :any:`Item` updates internal lists. Each
    of them is a dictionary that maps the item's library (``item.lib``, possibly
    ``None``) to a list of paths:

    * :attr:`sources_to_transform`: The path of all source files that contain
      objects that are transformed by a Loki transformation in the pipeline
    * :attr:`sources_to_append`: The path of any new source files that exist
      as a consequence of the Loki transformation pipeline, e.g., transformed
      source files that are written.
    * :attr:`sources_to_remove`: The path of any existing source files that
      are to be removed from the compilation target. This includes all items
      that don't have the :any:`Item.replicate` property.

    The ``item.path`` is used to determine the file path from which a
    Fortran sourcefile was read. New paths are provided in
    ``item.trafo_data['FileWriteTransformation']['path']``.

    The method :meth:`write_plan` allows writing the gathered information to
    a CMake file that can be included in the CMake scripts that build a library.
    The plan file is a CMake file defining three lists matching the above:

    * ``LOKI_SOURCES_TO_TRANSFORM``: The list of files that are
        processed in the dependency graph
    * ``LOKI_SOURCES_TO_APPEND``: The list of files that are created
        and have to be added to the build target as part of the processing
    * ``LOKI_SOURCES_TO_REMOVE``: The list of files that are no longer
        required (because they have been replaced by transformed files) and
        should be removed from the build target.

    In addition, the same three lists are written for every library ``<lib>``
    other than ``None`` as ``LOKI_SOURCES_TO_TRANSFORM_<lib>`` etc., where any
    ``.`` in the library name is replaced by ``_``.

    These lists are used by the Loki CMake wrappers (particularly
    ``loki_transform_target``) to schedule the source updates and update the
    source lists of the CMake target object accordingly.

    Parameters
    ----------
    rootpath : str (optional)
        If given, all paths will be resolved relative to this root directory
    """

    # This transformation is applied over the file graph
    traverse_file_graph = True

    item_filter = None

    def __init__(self, rootpath=None):
        self.rootpath = None if rootpath is None else Path(rootpath).resolve()
        self.sources_to_append = {}
        self.sources_to_remove = {}
        self.sources_to_transform = {}

    def plan_file(self, sourcefile, **kwargs):  # pylint: disable=unused-argument
        """
        Record the planned source list changes for the :any:`Item` in ``kwargs['item']``

        Items without planning data from :any:`FileWriteTransformation` are ignored,
        and each new source file is recorded only once per library.

        Raises
        ------
        ValueError
            If no ``item`` is given in ``kwargs``
        """
        item = kwargs.get('item')
        if not item:
            raise ValueError('No Item provided; required to determine CMake plan')

        if 'FileWriteTransformation' not in item.trafo_data:
            return

        # NB: we use the item path as-is (unless a rootpath is given) instead of
        #     resolving it, to stay compatible with what has been provided on the CLI.
        #     The build system likely uses this path internally to identify the
        #     original file, and if the paths don't match, the removal of source
        #     files from the target fails.
        sourcepath = item.path

        # This makes sure the sourcepath does in fact exist. Combined with
        # item duplication or other transformations we might end up adding
        # items on-the-fly that did not exist before, with fake paths.
        # There is possibly a better way of doing this, though.
        source_exists = Path(sourcepath).exists()

        # Only existing paths are ever recorded, so only those need to be made
        # relative; fake paths outside the rootpath would otherwise raise here
        if source_exists and self.rootpath is not None:
            sourcepath = Path(sourcepath).resolve().relative_to(self.rootpath)

        newsource = item.trafo_data['FileWriteTransformation']['path']

        debug(f'Planning:: {item.name} (role={item.role}, mode={item.mode})')

        key = item.lib
        if newsource in self.sources_to_append.get(key, ()):
            # This output file has already been planned for this library
            return

        if source_exists:
            self.sources_to_transform.setdefault(key, []).append(sourcepath)

        # The new source file is always added: next to the old one if the item is
        # replicated, otherwise as its replacement
        self.sources_to_append.setdefault(key, []).append(newsource)

        if not item.replicate and source_exists:
            # Remove the old source file to avoid ghosting
            self.sources_to_remove.setdefault(key, []).append(sourcepath)

    @staticmethod
    def _format_cmake_list(name, sources):
        """
        Return a CMake ``set`` command that assigns :data:`sources` to variable :data:`name`
        """
        entries = '\n'.join(f'    {s}' for s in sources)
        return f'set( {name} \n{entries}\n   )\n'

    def _write_plan(self, filepath):
        """
        Write the key/target/library independent part of CMake plan file to :data:`filepath`

        The lists of all libraries are concatenated; an existing file is overwritten.
        """
        plan_lists = (
            ('LOKI_SOURCES_TO_TRANSFORM', self.sources_to_transform),
            ('LOKI_SOURCES_TO_APPEND', self.sources_to_append),
            ('LOKI_SOURCES_TO_REMOVE', self.sources_to_remove),
        )
        with Path(filepath).open('w') as f:
            for name, sources_per_lib in plan_lists:
                sources = [s for lib_sources in sources_per_lib.values() for s in lib_sources]
                f.write(self._format_cmake_list(name, sources))

    def write_plan(self, filepath):
        """
        Write the CMake plan file to :data:`filepath`

        This writes the combined lists for all libraries, followed by one set
        of lists per (non-``None``) library.
        """
        # write plan disregarding the key/target/library
        self._write_plan(filepath)

        # Collect all libraries in order of first appearance
        all_targets = dict.fromkeys(
            (*self.sources_to_transform, *self.sources_to_append, *self.sources_to_remove)
        )

        # write plan for each key/target/library
        with Path(filepath).open('a') as f:
            for target in all_targets:
                if target is None:
                    continue
                # sanitize target name for use in a CMake variable name
                sanitized_target = target.replace('.', '_')
                f.write(self._format_cmake_list(
                    f'LOKI_SOURCES_TO_TRANSFORM_{sanitized_target}',
                    self.sources_to_transform.get(target, ())
                ))
                f.write(self._format_cmake_list(
                    f'LOKI_SOURCES_TO_APPEND_{sanitized_target}',
                    self.sources_to_append.get(target, ())
                ))
                f.write(self._format_cmake_list(
                    f'LOKI_SOURCES_TO_REMOVE_{sanitized_target}',
                    self.sources_to_remove.get(target, ())
                ))
loki-ecmwf-0.3.6/loki/transformations/build_system/file_write.py0000664000175000017500000000660615167130205025346 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Transformations to be used in build-system level tasks
"""

from pathlib import Path

from loki.backend import FortranStyle, IFSFortranStyle
from loki.batch import Transformation, ProcedureItem, ModuleItem


__all__ = ['FileWriteTransformation']


class FileWriteTransformation(Transformation):
    """
    Write out modified source files to a select build directory

    The output file name is derived from the item's path by inserting the
    item's mode (``'loki'`` if unset, with ``-`` replaced by ``_``) before the
    file suffix, e.g. ``foo.F90`` becomes ``foo.<mode>.F90``. If
    ``build_args['output_dir']`` is given, the file is placed in that directory.

    Parameters
    ----------
    suffix : str, optional
        File suffix to determine file type for all written file. If
        omitted, it will preserve the original file type.
    cuf : bool, optional
        Use CUF (CUDA Fortran) backend instead of Fortran backend.
    style : str, optional
        Output style for the Fortran backend, either ``'fortran'`` or
        ``'ifs'``. If omitted, the default :any:`FortranStyle` is used.
    include_module_var_imports : bool, optional
        Flag to force the :any:`Scheduler` traversal graph to recognise
        module variable imports and write the modified module files.
    """

    # This transformation is applied over the file graph
    traverse_file_graph = True

    _style_map = {
        None: FortranStyle(),
        'fortran': FortranStyle(),
        'ifs': IFSFortranStyle()
    }

    def __init__(
            self, suffix=None, cuf=False, style=None,
            include_module_var_imports=False,
    ):
        self.suffix = suffix
        self.cuf = cuf
        self.style = self._style_map[style]
        self.include_module_var_imports = include_module_var_imports

    @property
    def item_filter(self):
        """
        Override ``item_filter`` to configure whether module variable
        imports are honoured in the :any:`Scheduler` traversal.
        """
        if self.include_module_var_imports:
            return (ProcedureItem, ModuleItem)
        return ProcedureItem

    @staticmethod
    def _get_item(kwargs):
        """
        Return ``kwargs['item']`` or, if not given, the first entry of
        ``kwargs['items']`` (the falsy ``item`` value if neither is available)
        """
        item = kwargs.get('item')
        if not item and kwargs.get('items'):
            item = kwargs['items'][0]
        return item

    def _get_file_path(self, item, build_args):
        """
        Determine the output path for the file of :data:`item`

        Raises
        ------
        ValueError
            If no :data:`item` is given
        """
        if not item:
            raise ValueError('No Item provided; required to determine file write path')

        mode = (item.mode or 'loki').replace('-', '_')  # Sanitize mode string

        path = Path(item.path)
        suffix = self.suffix or path.suffix
        sourcepath = path.with_suffix(f'.{mode}{suffix}')
        if build_args and (output_dir := build_args.get('output_dir', None)) is not None:
            sourcepath = Path(output_dir)/sourcepath.name
        return sourcepath

    def transform_file(self, sourcefile, **kwargs):
        """
        Write :data:`sourcefile` to the output path derived from its item
        """
        item = self._get_item(kwargs)
        sourcepath = self._get_file_path(item, kwargs.get('build_args', {}))
        sourcefile.write(path=sourcepath, cuf=self.cuf, style=self.style)

    def plan_file(self, sourcefile, **kwargs):  # pylint: disable=unused-argument
        """
        Record the planned output path in ``item.trafo_data['FileWriteTransformation']['path']``
        without writing any file
        """
        item = self._get_item(kwargs)
        sourcepath = self._get_file_path(item, kwargs.get('build_args', {}))
        item.trafo_data['FileWriteTransformation'] = {'path': sourcepath}
loki-ecmwf-0.3.6/loki/transformations/build_system/dependency.py0000664000175000017500000004212515167130205025327 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path

from loki.backend import fgen
from loki.batch import Transformation
from loki.ir import (
    CallStatement, Import, Interface, FindNodes, FindInlineCalls, Transformer
)
from loki.logging import warning
from loki.module import Module
from loki.subroutine import Subroutine
from loki.types import ProcedureType, Scope
from loki.tools import as_tuple, OrderedSet


__all__ = ['DependencyTransformation']


class DependencyTransformation(Transformation):
    """
    Basic :any:`Transformation` class that facilitates dependency
    injection for transformed :any:`Module` and :any:`Subroutine`
    into complex source trees

    This transformation appends a provided ``suffix`` argument to
    transformed subroutine and module objects and changes the target
    names of :any:`Import` and :any:`CallStatement` nodes on the call-site
    accordingly.

    For subroutines declared via an interface block, these interfaces
    are updated accordingly. For subroutines that are not wrapped in a
    module, an updated interface block is also written as a header file
    to :data:`include_path`. Where interface blocks to renamed subroutines
    are included via C-style imports, the import name is updated accordingly.

    To ensure that every subroutine is wrapped in a module, the
    accompanying :any:`ModuleWrapTransformation` should be applied
    first. This restores the behaviour of the ``module`` mode in an earlier
    version of this transformation.

    When applying the transformation to a source object, one of two
    "roles" can be specified via the ``role`` keyword:

    * ``'driver'``: Only renames imports and calls to kernel routines
    * ``'kernel'``: Renames routine or enclosing modules, as well as
      renaming any further imports and calls.

    Note that ``routine.apply(transformation, role='driver')`` entails
    that the ``routine`` still mimics its original counterpart and
    can therefore be used as a drop-in replacement during compilation
    that then diverts the dependency tree to the modified sub-tree.

    Parameters
    ----------
    suffix : str
        The suffix to apply during renaming
    module_suffix : str, optional
        Special suffix to signal module names like ``_MOD``
    include_path : path, optional
        Directory for generating additional header files
    replace_ignore_items : bool
        Debug flag to toggle the replacement of calls to subroutines
        in the ``ignore`` list. Default is ``True``.
    remove_inactive_items : bool
        Debug flag to toggle the removal of items (modules, subroutines)
        in the sourcefile that are not part of the scheduler graph.
        Default is ``True``.
    """

    # item_filter = Item

    reverse_traversal = True

    # This transformation is applied over the file graph
    traverse_file_graph = True

    # This transformation recurses from the Sourcefile down
    recurse_to_modules = True
    recurse_to_procedures = True
    recurse_to_internal_procedures = False

    # This transformation changes the names of items and may create items if original modules
    # are retained (e.g., when global variable imports exist)
    renames_items = True
    creates_items = True

    def __init__(self, suffix, module_suffix=None, include_path=None, replace_ignore_items=True,
                 remove_inactive_items=True):
        self.suffix = suffix
        self.module_suffix = module_suffix
        self.replace_ignore_items = replace_ignore_items
        self.remove_inactive_items = remove_inactive_items
        self.include_path = None if include_path is None else Path(include_path)

    def transform_file(self, sourcefile, **kwargs):
        """
        Remove non-active scope nodes if :attr:`remove_inactive_items` is true
        """
        sourcefile.ir = sourcefile.ir.clone(
            body=self.remove_inactive_ir_nodes(
                sourcefile.ir.body, f'file {(sourcefile.path or "")!s}', **kwargs
            )
        )

    def transform_module(self, module, **kwargs):
        """
        Rename kernel modules and re-point module-level imports.

        For kernel modules, inactive procedures are also removed from the
        ``CONTAINS`` section, and names in the public/private access specs
        are renamed (and pruned) accordingly.
        """
        role = kwargs.get('role')
        item = kwargs.get('item')

        # remember/keep track of the module subroutines (even if some of those are removed)
        routines = tuple(routine.name.lower() for routine in module.subroutines)
        if role == 'kernel':
            # Change the name of kernel modules
            module.name = self.derive_module_name(module.name)

            if item and item.name != module.name.lower():
                item.name = module.name.lower()

            if module.contains:
                module.contains = module.contains.clone(
                    body=self.remove_inactive_ir_nodes(
                        module.contains.body, f'module {module.name}', **kwargs
                    ),
                )

        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets')))
        if self.replace_ignore_items and item:
            targets += tuple(str(i).lower() for i in item.ignore)
        self.rename_imports(module, imports=module.imports, targets=targets)

        active_nodes = None
        if self.remove_inactive_items and kwargs.get('items') is not None:
            active_nodes = [active_item.scope_ir.name.lower() for active_item in kwargs['items']]

        # rename target names in an access spec for both public and private access specs
        if module.public_access_spec:
            module.public_access_spec = self.rename_access_spec_names(
                module.public_access_spec, targets=targets, active_nodes=active_nodes,
                routines=routines
            )
        if module.private_access_spec:
            module.private_access_spec = self.rename_access_spec_names(
                module.private_access_spec, targets=targets, active_nodes=active_nodes,
                routines=routines
            )

    def transform_subroutine(self, routine, **kwargs):
        """
        Rename kernel subroutine and all imports and calls to target routines

        For subroutines that are not wrapped in a module, re-generate the interface
        block.
        """
        role = kwargs.get('role')
        item = kwargs.get('item')
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets')))
        if self.replace_ignore_items and item:
            targets += tuple(str(i).lower() for i in item.ignore)

        if role == 'kernel':
            if routine.name.endswith(self.suffix):
                # This is to ensure that the transformation is idempotent if
                # applied more than once to a routine
                return

            # Change the name of kernel routines
            routine.name += self.suffix
            if item:
                item.name += self.suffix.lower()

        self.rename_calls(routine, targets=targets, item=item)

        # Note, C-style imports can be in the body, so use whole IR
        imports = FindNodes(Import).visit(routine.ir)
        self.rename_imports(routine, imports=imports, targets=targets)

        # Interface blocks can only be in the spec
        intfs = FindNodes(Interface).visit(routine.spec)
        self.rename_interfaces(intfs, targets=targets)

        if role == 'kernel' and not routine.parent and self.include_path:
            # Re-generate C-style interface header
            self.generate_interfaces(routine)

    def remove_inactive_ir_nodes(self, body, transformed_scope_name, **kwargs):
        """
        Utility to filter :any:`Scope` nodes in :data:`body` to include only
        those given in ``kwargs['items']``.

        If :attr:`remove_inactive_items` is false, :data:`body` is returned unchanged.
        If no ``items`` are given, a warning is issued and :data:`body` is returned
        unchanged.
        """
        if self.remove_inactive_items:
            if kwargs.get('items') is None:
                msg = (
                    f'Cannot remove inactive items in {transformed_scope_name}. '
                    'No ``items`` given in kwargs.'
                )
                warning(msg)
            else:
                active_nodes = [item.scope_ir for item in kwargs['items']]
                body = tuple(
                    node for node in body
                    if not isinstance(node, Scope) or node in active_nodes
                )
        return body

    def derive_module_name(self, modname):
        """
        Utility to derive a new module name from :attr:`suffix` and :attr:`module_suffix`

        Parameters
        ----------
        modname : str
            Current module name
        """

        # First step through known suffix variants to determine canonical basename
        if self.module_suffix and modname.lower().endswith(self.module_suffix.lower()):
            # Remove the module_suffix, if present
            idx = modname.lower().rindex(self.module_suffix.lower())
            modname = modname[:idx]
        if modname.lower().endswith(self.suffix.lower()):
            # Remove the dependency injection suffix, if present
            idx = modname.lower().rindex(self.suffix.lower())
            modname = modname[:idx]

        # Suffix combination to canonical basename
        if self.module_suffix:
            return f'{modname}{self.suffix}{self.module_suffix}'
        return f'{modname}{self.suffix}'

    def rename_calls(self, routine, targets=None, item=None):
        """
        Update :any:`CallStatement` and :any:`InlineCall` to actively
        transformed procedures

        Calls to the routine's own member procedures are never renamed.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine in whose body calls are renamed
        targets : list of str
            Optional list of subroutine names for which to modify the corresponding
            calls. If not provided, all calls are updated
        item : :any:`Item`, optional
            If given, its ``ignore`` and ``block`` configuration is updated to
            the new names of renamed but ignored procedures
        """
        from loki.batch import SchedulerConfig  # pylint: disable=import-outside-toplevel,cyclic-import

        def _update_item(orig_name, new_name):
            # Update the ignore property if necessary
            if item and (matched_keys := SchedulerConfig.match_item_keys(orig_name, item.ignore)):
                # Add the renamed but ignored items to the block list because we won't be able to
                # find them as dependencies under their new name anymore
                item.config['block'] = as_tuple(item.block) + tuple(
                    new_name for name in item.ignore if name in matched_keys
                )
                item.config['ignore'] = tuple(
                    new_name if name in matched_keys else name
                    for name in item.ignore
                )

        members = [r.name for r in routine.subroutines]

        for call in FindNodes(CallStatement).visit(routine.body):
            if call.name in members:
                continue
            if targets is None or call.name in targets:
                orig_name = str(call.name)
                new_name = f'{orig_name}{self.suffix}'
                new_type = call.name.type.clone(dtype=ProcedureType(name=new_name))
                call._update(name=call.name.clone(name=new_name, type=new_type))
                _update_item(orig_name, str(call.name))

        for call in FindInlineCalls(unique=False).visit(routine.body):
            if call.function in members:
                continue
            if targets is None or call.function in targets:
                orig_name = str(call.name)
                new_name = f'{orig_name}{self.suffix}'
                new_type = call.function.type.clone(dtype=ProcedureType(name=new_name))
                call.function = call.function.clone(name=new_name, type=new_type)
                _update_item(orig_name, str(call.name))

    def rename_imports(self, source, imports, targets=None):
        """
        Update imports of actively transformed subroutines.

        Parameters
        ----------
        source : :any:`ProgramUnit`
            The IR object to transform
        imports : list of :any:`Import`
            The list of imports to update. This includes both, C-style header includes
            and Fortran import statements (``USE`` and ``IMPORT``)
        targets : list of str
            Optional list of subroutine names for which to modify imports
        """
        # We don't want to rename module variable imports, so we build
        # a list of calls to further filter the targets
        if isinstance(source, Module):
            calls = OrderedSet()
            for routine in source.subroutines:
                calls |= {str(c.name).lower() for c in FindNodes(CallStatement).visit(routine.body)}
                calls |= {str(c.name).lower() for c in FindInlineCalls().visit(routine.body)}
        else:
            calls = {str(c.name).lower() for c in FindNodes(CallStatement).visit(source.body)}
            calls |= {str(c.name).lower() for c in FindInlineCalls().visit(source.body)}

        # Import statements still point to unmodified call names. Calls may already
        # carry the appended suffix, so strip it - but only at the end of the name,
        # to avoid mangling names that merely contain the suffix somewhere else
        suffix = self.suffix.lower()
        calls = {
            call[:-len(suffix)] if suffix and call.endswith(suffix) else call
            for call in calls
        }
        target_names = as_tuple(targets)
        call_targets = {call for call in calls if call in target_names}

        # We go through the IR, as C-imports can be attributed to the body
        import_map = {}
        for im in imports:
            if im.c_import:
                target_symbol, *suffixes = im.module.lower().split('.', maxsplit=1)
                if targets and target_symbol.lower() in targets and 'func.h' not in suffixes:
                    # Modify the the basename of the C-style header import
                    s = '.'.join(im.module.split('.')[1:])
                    im._update(module=f'{target_symbol}{self.suffix}.{s}')

            else:
                # Modify module import if it imports any call targets
                if targets and im.symbols and any(s in call_targets for s in im.symbols):
                    new_module_name = self.derive_module_name(im.module)
                    if not all(s in call_targets for s in im.symbols):
                        # Mixed import: We need to split the import, retaining the original name for
                        # non-target imports and using the new name for target imports
                        import_map[im] = tuple(
                            im.clone(module=new_module_name, symbols=(s.clone(name=f'{s.name}{self.suffix}'),))
                            if s in call_targets else im.clone(symbols=(s,))
                            for s in im.symbols
                        )
                    else:
                        # Append suffix to all symbols and in-place update the import
                        symbols = tuple(
                            s.clone(name=f'{s.name}{self.suffix}')
                            if s in call_targets else s for s in im.symbols
                        )
                        im._update(module=new_module_name, symbols=symbols)

                # TODO: Deal with unqualified blanket imports

        if import_map:
            source.spec = Transformer(import_map).visit(source.spec)

    def rename_access_spec_names(self, access_spec, targets=None, active_nodes=None, routines=None):
        """
        Rename target names in an access spec

        For all names in the access spec that are module procedures and contained
        in :data:`targets` (or all module procedures, if no targets are given),
        rename them as ``{name}{self.suffix}``. If :data:`active_nodes` are given,
        then all module procedure names that are not in the list of active nodes
        are removed from the list. All name comparisons are case-insensitive.

        Parameters
        ----------
        access_spec : list of str
            List of names from an access spec
        targets : list of str
            Optional list of subroutine names for which to modify access specs
        active_nodes : list of str
            Optional list of active node names
        routines : list of str
            Optional list of the names of the module's procedures
        """
        module_routines = tuple(str(r).lower() for r in (routines or ()))
        target_names = tuple(str(t).lower() for t in as_tuple(targets))
        if active_nodes is not None:
            active_names = tuple(str(n).lower() for n in active_nodes)
            access_spec = tuple(
                elem for elem in access_spec
                if elem.lower() in active_names or elem.lower() not in module_routines
            )
        return tuple(
            f'{elem}{self.suffix}'
            if elem.lower() in module_routines and (not target_names or elem.lower() in target_names)
            else elem
            for elem in access_spec
        )

    def rename_interfaces(self, intfs, targets=None):
        """
        Update explicit interfaces to actively transformed subroutines.

        Parameters
        ----------
        intfs : list of :any:`Interface`
            The list of interfaces to update.
        targets : list of str
            Optional list of subroutine names for which to modify interfaces
        """
        for i in intfs:
            for routine in i.body:
                if isinstance(routine, Subroutine):
                    if targets and routine.name.lower() in targets:
                        routine.name = f'{routine.name}{self.suffix}'

    def generate_interfaces(self, routine):
        """
        Generate external header file with interface block for this subroutine.

        The header is written to ``{include_path}/{routine name, lower-case}.intfb.h``.
        """
        # No need to rename here, as this has already happened before
        intfb_path = self.include_path/f'{routine.name.lower()}.intfb.h'
        with intfb_path.open('w') as f:
            f.write(fgen(routine.interface))
loki-ecmwf-0.3.6/loki/transformations/__init__.py0000664000175000017500000000402715167130205022244 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Sub-package with supported source code transformation passes.

This sub-package includes general source code transformations and
bespoke :any:`Transformation` and :any:`Pipeline` classes for
IFS-specific source-to-source recipes that target GPUs.
"""

from loki.transformations.array_indexing import * # noqa
from loki.transformations.build_system import * # noqa
from loki.transformations.argument_shape import * # noqa
from loki.transformations.data_offload import * # noqa
from loki.transformations.drhook import * # noqa
from loki.transformations.extract import * # noqa
from loki.transformations.field_api import * # noqa
from loki.transformations.idempotence import * # noqa
from loki.transformations.inline import * # noqa
from loki.transformations.parametrise import * # noqa
from loki.transformations.remove_code import * # noqa
from loki.transformations.sanitise import * # noqa
from loki.transformations.single_column import * # noqa
from loki.transformations.transpile import * # noqa
from loki.transformations.transform_derived_types import * # noqa
from loki.transformations.transform_loop import * # noqa
from loki.transformations.transform_region import * # noqa
from loki.transformations.utilities import * # noqa
from loki.transformations.block_index_transformations import * # noqa
from loki.transformations.split_read_write import * # noqa
from loki.transformations.loop_blocking import * # noqa
from loki.transformations.routine_signatures import * # noqa
from loki.transformations.parallel import * # noqa
from loki.transformations.dependency import * # noqa
from loki.transformations.pragma_model import * # noqa
from loki.transformations.temporaries import * # noqa
loki-ecmwf-0.3.6/loki/transformations/transform_region.py0000664000175000017500000001313615167130205024064 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utility routines that provide transformations for code regions.

"""
from collections import defaultdict

from loki.ir import (
    Comment, Loop, Pragma, PragmaRegion, FindNodes, FindScopes,
    MaskedTransformer, NestedMaskedTransformer, is_loki_pragma,
    get_pragma_parameters, pragma_regions_attached
)
from loki.logging import info
from loki.tools import as_tuple, flatten

from loki.transformations.array_indexing import (
    promotion_dimensions_from_loop_nest, promote_nonmatching_variables
)


__all__ = ['region_hoist']


def region_hoist(routine):
    """
    Hoist one or multiple code regions annotated by pragma ranges and insert
    them at a specified target location.

    The pragma syntax for annotating the regions to hoist is
    ``!$loki region-hoist [group(group-name)] [collapse(n) [promote(var-name, var-name, ...)]]``
    and ``!$loki end region-hoist``.
    The location at which a group's regions are re-inserted is marked by a
    separate pragma ``!$loki region-hoist target [group(group-name)]``;
    every group that has regions to hoist must have exactly one such target.
    The optional ``group(group-name)`` can be provided when multiple regions
    are to be hoisted and inserted at different positions; regions and targets
    without it belong to the group ``default``. Multiple pragma
    ranges can be specified for the same group, all of which are then moved to
    the target location in the same order as the pragma ranges appear.
    The optional ``collapse(n)`` parameter specifies that ``n`` enclosing scopes
    (such as loops, conditionals, etc.) should be re-created at the target location.
    Optionally, this can be combined with variable promotion using ``promote(...)``,
    for which the promotion dimensions are derived from the enclosing loops
    (see ``promotion_dimensions_from_loop_nest``).

    Each hoisted region is wrapped in comments reproducing its pragmas, and the
    original end pragma is replaced by a ``- region hoisted`` comment.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine whose body is modified in place.

    Raises
    ------
    RuntimeError
        If a group has no target or more than one target, or if there are not
        enough enclosing scopes for the requested ``collapse(n)``.
    """
    hoist_targets = defaultdict(list)
    hoist_regions = defaultdict(list)

    # Find all region-hoist pragma regions
    with pragma_regions_attached(routine):
        for region in FindNodes(PragmaRegion).visit(routine.body):
            if is_loki_pragma(region.pragma, starts_with='region-hoist'):
                parameters = get_pragma_parameters(region.pragma, starts_with='region-hoist')
                group = parameters.get('group', 'default')
                hoist_regions[group] += [(region.pragma, region.pragma_post)]

    # Find all region-hoist targets
    for pragma in FindNodes(Pragma).visit(routine.body):
        if is_loki_pragma(pragma, starts_with='region-hoist'):
            parameters = get_pragma_parameters(pragma, starts_with='region-hoist')
            if 'target' in parameters:
                group = parameters.get('group', 'default')
                hoist_targets[group] += [pragma]

    if not hoist_regions:
        return

    # Group-by-group extract the regions and build the node replacement map
    hoist_map = {}
    promotion_vars_dims = {}  # Variables to promote with new dimension
    promotion_vars_index = {}  # Variable subscripts to promote with new indices
    starts, stops = [], []
    for group, regions in hoist_regions.items():
        # Use .get() so that checking a missing group does not insert it into the defaultdict
        targets = hoist_targets.get(group)
        if not targets:
            raise RuntimeError(f'No region-hoist target for group {group} defined.')
        if len(targets) > 1:
            raise RuntimeError(f'Multiple region-hoist targets given for group {group}')

        hoist_body = ()
        for start, stop in regions:
            parameters = get_pragma_parameters(start, starts_with='region-hoist')

            # Extract the region to hoist
            collapse = int(parameters.get('collapse', 0))
            if collapse > 0:
                scopes = FindScopes(start).visit(routine.body)[0]
                if len(scopes) <= collapse:
                    raise RuntimeError(f'Not enough enclosing scopes for collapse({collapse})')
                scopes = scopes[-(collapse+1):]
                region = NestedMaskedTransformer(start=start, stop=stop, mapper={start: None}).visit(scopes[0])

                # Promote variables given in promotion list. ``promote`` may be absent or
                # given without an argument list (value None); entries are stripped before
                # filtering so that e.g. ``promote(a, )`` does not yield an empty name.
                loops = [scope for scope in scopes if isinstance(scope, Loop)]
                promote_arg = parameters.get('promote') or ''
                promote_vars = [var.strip().lower() for var in promote_arg.split(',') if var.strip()]
                promotion_vars_dims, promotion_vars_index = promotion_dimensions_from_loop_nest(
                    promote_vars, loops, promotion_vars_dims, promotion_vars_index)
            else:
                region = MaskedTransformer(start=start, stop=stop, mapper={start: None}).visit(routine.body)

            # Append it to the group's body, wrapped in comments
            begin_comment = Comment(f'! Loki {start.content}')
            end_comment = Comment(f'! Loki {stop.content}')
            hoist_body += as_tuple(flatten([begin_comment, region, end_comment]))

            # Register nodes for the final transformer mask. The roles are deliberately
            # inverted: the transformer starts active, the region's opening pragma is a
            # ``stop`` node and its closing pragma a ``start`` node, so the region is
            # dropped from its original position while everything else is kept.
            starts += [stop]
            stops += [start]

            # Replace end pragma by comment
            hoist_map[stop] = Comment(f'! Loki {start.content} - region hoisted')

        # Insert target <-> hoisted regions into map
        hoist_map[targets[0]] = hoist_body

    routine.body = MaskedTransformer(active=True, start=starts, stop=stops, mapper=hoist_map).visit(routine.body)
    # hoist_map holds one entry per hoisted region (its end pragma) plus one per group target
    num_targets = sum(1 for pragma in hoist_map if 'target' in get_pragma_parameters(pragma))
    info('%s: hoisted %d region(s) in %d group(s)', routine.name, len(hoist_map) - num_targets, num_targets)
    promote_nonmatching_variables(routine, promotion_vars_dims, promotion_vars_index)
loki-ecmwf-0.3.6/loki/transformations/tests/0000775000175000017500000000000015167130205021272 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/tests/test_idempotence.py0000664000175000017500000000455115167130205025204 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import copy
import pytest

from loki import Subroutine, fgen
from loki.frontend import available_frontends

from loki.transformations import IdemTransformation


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_idempotence(frontend):
    """
    Test the do-nothing equivalence of :any:`IdemTransformation`.

    The transformation must rebuild the IR (new objects) while leaving it
    structurally equal and producing identical Fortran output.
    """

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nproma, nlev, nz, q, nb)
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    INTEGER, INTENT(IN)   :: nproma, nlev  ! Aliases of horizontal and vertical sizes
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    do b=1, nb
      call compute_column(start, end, nlon, nproma, nz, q(:,:,b))
    end do
  END SUBROUTINE column_driver
"""

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nproma, nlev, nz, q)
    INTEGER, INTENT(IN) :: start, end   ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz     ! Size of the horizontal and vertical
    INTEGER, INTENT(IN) :: nproma, nlev ! Aliases of horizontal and vertical sizes
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL :: t(nlon,nz)
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO
  END SUBROUTINE compute_column
"""
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    driver_before = copy.deepcopy(driver)
    kernel_before = copy.deepcopy(kernel)

    idempotence = IdemTransformation()
    idempotence.apply(driver, role='driver')
    idempotence.apply(kernel, role='kernel')

    # The IR objects must be new instances...
    assert driver_before.ir is not driver.ir
    assert kernel_before.ir is not kernel.ir
    # ...but structurally equal and rendering to the same code
    assert driver_before.ir == driver.ir
    assert kernel_before.ir == kernel.ir
    assert fgen(driver_before) == fgen(driver)
    assert fgen(kernel_before) == fgen(kernel)
loki-ecmwf-0.3.6/loki/transformations/tests/test_remove_code.py0000664000175000017500000005747315167130205025212 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import shutil
import pytest

from loki import Subroutine, Module, Sourcefile, gettempdir
from loki.batch import Scheduler, SchedulerConfig
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes

from loki.transformations.remove_code import (
    do_remove_dead_code, do_remove_marked_regions, do_remove_calls,
    RemoveCodeTransformation, do_remove_unused_vars
)


@pytest.fixture(scope='module', name='srcdir')
def fixture_srcdir():
    """
    Provide an empty ``test_remove_code`` directory below Loki's temp directory.

    Any leftover from an earlier run is wiped first, and the directory is
    removed again once the module's tests have finished.
    """
    path = gettempdir() / 'test_remove_code'
    if path.exists():
        # Start from a clean slate if a previous run left the directory behind
        shutil.rmtree(path)
    path.mkdir()

    yield path

    shutil.rmtree(path)


@pytest.fixture(scope='module', name='source')
def fixture_source(srcdir):
    """
    Populate ``srcdir`` with a driver and a kernel module that contain
    DR_HOOK calls, ABOR1 calls, WRITE intrinsics and a ``!$loki remove``
    region, and delete the files again after the module's tests.
    """
    fcode_driver = """
subroutine rick_astley
    use parkind1, only: jprb
    use yomhook, only : lhook, dr_hook
    use rick_rolled, only : never_gonna_give
    implicit none

    real(kind=jprb) :: zhook_handle
    if (lhook) call dr_hook('rick_astley',0,zhook_handle)
    call never_gonna_give()
    if (lhook) call dr_hook('rick_astley',1,zhook_handle)
end subroutine
    """.strip()

    fcode_kernel = """
module rick_rolled
  type a_type
    integer :: b
  end type
contains
subroutine never_gonna_give
    use parkind1, only: jprb
    use yomhook, only : lhook, dr_hook
    use abor2_mod, only: not_my_abor
    implicit none

    type(a_type) :: a
    real(kind=jprb) :: zhook_handle

    associate(b=>a%b)
    if (lhook) call dr_hook('never_gonna_give',0,zhook_handle)

    CALL ABOR1('[SUBROUTINE CALL]')

    print *, 'never gonna let you down'

    if (dave) call abor1('[INLINE CONDITIONAL]')

    call never_gonna_run_around()

    !$loki remove
    call never_gonna_run_around()
    !$loki end remove

    WRITE(NULOUT,*) "[WRITE INTRINSIC]"
    if (.not. dave) WRITE(NULOUT, *) "[WRITE INTRINSIC]"

    if (lhook) call dr_hook('never_gonna_give',1,zhook_handle)
    end associate
contains

subroutine never_gonna_run_around

    implicit none

    if (lhook) call dr_hook('never_gonna_run_around',0,zhook_handle)

    if (dave) call abor1('[INLINE CONDITIONAL]')
    WRITE(NULOUT,*) "[WRITE INTRINSIC]"
    if (.not. dave) WRITE(NULOUT, *) "[WRITE INTRINSIC]"

    if (lhook) call dr_hook('never_gonna_run_around',1,zhook_handle)

end subroutine never_gonna_run_around

end subroutine
subroutine i_hope_you_havent_let_me_down
    real(kind=jprb) :: zhook_handle
    if (lhook) call dr_hook('i_hope_you_havent_let_me_down',0,zhook_handle)

    if (lhook) call dr_hook('i_hope_you_havent_let_me_down',1,zhook_handle)
end subroutine i_hope_you_havent_let_me_down
end module rick_rolled
    """.strip()

    # File name -> contents; written in this order and removed again afterwards
    files = {
        'rick_astley.F90': fcode_driver,
        'never_gonna_give.F90': fcode_kernel,
    }
    for filename, content in files.items():
        (srcdir/filename).write_text(content)

    yield srcdir

    for filename in files:
        (srcdir/filename).unlink()


@pytest.fixture(scope='module', name='source_with_args')
def fixture_source_with_args(srcdir):
    """
    Populate ``srcdir`` with a types module, a driver and two kernels whose
    argument lists contain used and unused (derived-type) arguments, and
    delete the files again after the module's tests.
    """

    fcode_module = """
module types_mod
   type dims_type
       integer :: kst
       integer :: kend
       integer :: klon
   end type dims_type

   type some_unused_type
       real :: a
   end type some_unused_type
end module types_mod
"""

    fcode_driver = """
subroutine driver(dims, StrUct)
    use types_mod, only : dims_type, some_unused_type
    implicit none
    type(dims_type), intent(in) :: dims
    type(some_unused_type), intent(in) :: struct
    type(some_unused_type) :: structs(10)
    real, dimension(dims%klon) :: a, b, c, d


    call kernel(dims%kst, dims%kend, dIms, sTRucT, STRucts, a, b, c, d)

end subroutine driver
"""

    fcode_kernel = """
subroutine kernel(kst, kend, diMs, stRUCt, sTructs, a, b, c, d)
    use types_mod, only : dims_type, some_unused_type
    implicit none
    integer, intent(in) :: kst, kend
    type(dims_type), intent(in) :: dIms
    type(some_unused_type), intent(in) :: StrucT
    type(some_unused_type), intent(in) :: StrucTS(10)
    real, intent(out), dimension(dims%klon) :: a, b, c, d
    real, dimension(dims%klon) :: used_local, unused_local
    integer :: jrof, ji

    used_local(:) = 0.0

    do jrof = kst, kend
      a(jrof) = 0.
      b(jrof) = 0.
    enddo

    do ji=1,10
      strucTs(ji)%a = 1.0
    end do

    !$loki remove
    call an_unused_kernel(stRuCt)
    !$loki end remove

    call another_kernel(kst, kend, d=C, e=D)

end subroutine kernel
"""

    fcode_another_kernel = """
subroutine another_kernel(kst, kend, D, E)
    implicit none
    integer, intent(in) :: kst, kend
    real, intent(out) :: d(:), e(:)
    integer :: jrof

    do jrof = kst, kend
       d(jrof) = 0.
    enddo
end subroutine another_kernel
"""

    # File name -> contents; written in this order and removed again afterwards
    files = {
        'module.F90': fcode_module,
        'driver.F90': fcode_driver,
        'kernel.F90': fcode_kernel,
        'another_kernel.F90': fcode_another_kernel,
    }
    for filename, content in files.items():
        (srcdir/filename).write_text(content)

    yield srcdir

    for filename in files:
        (srcdir/filename).unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_dead_code_conditional(frontend):
    """
    Check that branches of IF constructs with constant conditions are pruned
    so that only the reachable branch bodies remain.
    """
    fcode = """
subroutine test_dead_code_conditional(a, b, flag)
  real(kind=8), intent(inout) :: a, b
  logical, intent(in) :: flag

  if (flag) then
    if (1 == 6) then
      a = a + b
    else
      b = b + 2.0
    end if

    if (2 == 2) then
      b = b + a
    else
      a = a + 3.0
    end if

    if (1 == 2) then
      b = b + a
    elseif (3 == 3) then
      a = a + b
    else
      a = a + 6.0
    end if

  end if
end subroutine test_dead_code_conditional
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    # An ELSEIF is represented as a nested conditional, so it counts twice
    assert len(FindNodes(ir.Conditional).visit(routine.body)) == 5
    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 7

    do_remove_dead_code(routine)

    # Only the data-dependent outer IF survives
    remaining = FindNodes(ir.Conditional).visit(routine.body)
    assert len(remaining) == 1
    assert remaining[0].condition == 'flag'

    expected = (('b', 'b + 2.0'), ('b', 'b + a'), ('a', 'a + b'))
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == len(expected)
    for assignment, (lhs, rhs) in zip(assignments, expected):
        assert assignment.lhs == lhs and assignment.rhs == rhs


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_dead_code_conditional_nested(frontend):
    """
    Test correct elimination of unreachable branches in nested conditionals,
    i.e. IF constructs with constant ELSEIF conditions.
    """
    fcode = """
subroutine test_dead_code_conditional(a, b, flag)
  real(kind=8), intent(inout) :: a, b
  logical, intent(in) :: flag

  if (1 == 2) then
    a = a + 5
  elseif (flag) then
    b = b + 4
  else
    b = a + 3
  end if

  if (a > 2.0) then
    a = a + 5.0
  elseif (2 == 3) then
    a = a + 3.0
  else
    a = a + 1.0
  endif

  if (a > 2.0) then
    a = a + 5.0
  elseif (2 == 3) then
    a = a + 3.0
  elseif (a > 1.0) then
    a = a + 2.0
  else
    a = a + 1.0
  endif
end subroutine test_dead_code_conditional
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    # An ELSEIF is represented as a nested conditional, so it counts twice
    assert len(FindNodes(ir.Conditional).visit(routine.body)) == 7
    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 10

    do_remove_dead_code(routine)

    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conditionals) == 4
    assert conditionals[0].condition == 'flag'
    assert not conditionals[0].has_elseif
    assert conditionals[1].condition == 'a > 2.0'
    assert not conditionals[1].has_elseif
    assert conditionals[2].condition == 'a > 2.0'
    if frontend != OMNI:  # OMNI does not get elseifs right
        assert conditionals[2].has_elseif
    assert conditionals[3].condition == 'a > 1.0'
    assert not conditionals[3].has_elseif

    expected = (
        ('b', 'b + 4'), ('b', 'a + 3'),
        ('a', 'a + 5.0'), ('a', 'a + 1.0'),
        ('a', 'a + 5.0'), ('a', 'a + 2.0'), ('a', 'a + 1.0'),
    )
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == len(expected)
    for assign, (lhs, rhs) in zip(assigns, expected):
        assert assign.lhs == lhs and assign.rhs == rhs


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_dead_code_multiconditional(frontend):
    """
    Check that SELECT CASE constructs on a constant selector collapse to the
    matching case, and that dead-code removal recurses into the bodies of
    data-dependent SELECT CASE constructs.
    """
    fcode = """
subroutine test_dead_code_multiconditional(a, b, i, flag)
  real(kind=8), intent(inout) :: a, b
  integer, intent(in) :: i
  logical, intent(in) :: flag

  if (flag) then
    select case (2)
    case (1)
      a = a + b
    case (5,2)
      b = b + 2.0
    case (3)
      b = b + a
    case default
      a = a + 3.0
    end select

    select case (i)
    case (1)
      ! Check recursion...
      if (2 == 2) then
        b = b + a
      else
        a = a + 3.0
      end if
    case (2)
      b = b + 4.0
    case (3)
      b = b + 5.0
    case default
      a = a + 6.0
    end select

  end if
end subroutine test_dead_code_multiconditional
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(ir.MultiConditional).visit(routine.body)) == 2
    assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2
    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 9

    do_remove_dead_code(routine)

    # The constant SELECT CASE and the constant IF nested in the second one are
    # gone; only the selector on ``i`` and the outer ``flag`` IF remain.
    selects = FindNodes(ir.MultiConditional).visit(routine.body)
    assert len(selects) == 1
    assert selects[0].expr == 'i'
    assert len(FindNodes(ir.Conditional).visit(routine.body)) == 1

    expected = (
        ('b', 'b + 2.0'), ('b', 'b + a'), ('b', 'b + 4.0'),
        ('b', 'b + 5.0'), ('a', 'a + 6.0'),
    )
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == len(expected)
    for assignment, (lhs, rhs) in zip(assignments, expected):
        assert assignment.lhs == lhs and assignment.rhs == rhs


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('mark_with_comment', [True, False])
@pytest.mark.parametrize('replace_with_abort', [True, False])
def test_transform_remove_code_pragma_region(frontend, mark_with_comment, replace_with_abort):
    """
    Test correct removal of pragma-marked code regions.

    Three ``!$loki remove`` regions are removed. When a replacement call is
    requested, it is inserted for every region except the one marked with
    ``no-replacement-call``, and the corresponding module import is added once.
    """
    fcode = """
subroutine test_remove_code(a, b, n, flag)
  real(kind=8), intent(inout) :: a, b(n)
  integer, intent(in) :: n
  logical, intent(in) :: flag
  integer :: i

  if (flag) then
    a = a + 1.0
  end if

  !$loki remove
  do i=1, n
    !$loki rick-roll
    a = a + 3.0
    !$loki end rick-roll
  end do
  !$loki end remove

  b(:) = 1.0

  !$loki remove no-replacement-call
  b(:) = 42.0
  !$loki end remove

  !$acc parallel
  do i=1, n
    b(i) = b(i) + a

    !$loki remove
    a = b(i) + 42.
    !$loki end remove
  end do
end subroutine test_remove_code
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    if replace_with_abort:
        do_remove_marked_regions(
            routine, mark_with_comment=mark_with_comment,
            replacement_call='ABOR1', replacement_module='ABOR1_MOD',
            replacement_msg='Unsupported code path in {}',
        )
    else:
        do_remove_marked_regions(routine, mark_with_comment=mark_with_comment)

    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 3
    assert assigns[0].lhs == 'a' and assigns[0].rhs == 'a + 1.0'
    assert assigns[1].lhs == 'b(:)' and assigns[1].rhs == '1.0'
    assert assigns[2].lhs == 'b(i)' and assigns[2].rhs == 'b(i) + a'

    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1
    assert assigns[2] in loops[0].body

    # One marker comment per removed region
    comments = [
        c for c in FindNodes(ir.Comment).visit(routine.body)
        if '[Loki] Removed content' in c.text
    ]
    assert len(comments) == (3 if mark_with_comment else 0)

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    imports = FindNodes(ir.Import).visit(routine.spec)
    if replace_with_abort:
        # The ``no-replacement-call`` region gets no call, hence two rather than three
        assert len(calls) == 2
        for c in calls:
            # Check that the replacement calls have been inserted
            assert c.name == 'ABOR1'
            assert len(c.arguments) == 1 and not c.kwarguments
            assert c.arguments[0] == 'Unsupported code path in test_remove_code'

        # Check that exactly one import was inserted, and that it is a
        # Fortran USE of the replacement module rather than a C import
        assert len(imports) == 1
        assert imports[0].module == 'ABOR1_MOD'
        assert imports[0].symbols == ('ABOR1',)
        assert not imports[0].c_import
    else:
        assert len(calls) == 0
        assert len(imports) == 0


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('remove_imports', [True, False])
def test_transform_remove_calls(frontend, remove_imports, tmp_path):
    """
    Test removal of utility calls and intrinsics with custom patterns,
    optionally stripping the imports of the removed routines.
    """

    fcode_yomhook = """
module yomhook
  logical lhook
contains
  subroutine dr_hook(name, id, handle)
    character(len=*), intent(in) :: name
    integer(kind=8), intent(in) :: id, handle
  end subroutine dr_hook
end module yomhook
    """

    fcode_abor1 = """
module abor1_mod
implicit none
contains
  subroutine abor1(msg)
    character(len=*), intent(in) :: msg
    write(*,*) msg
  end subroutine abor1
end module abor1_mod
    """

    fcode = """
subroutine never_gonna_give(dave)
    use yomhook, only : lhook, dr_hook
    use abor1_mod, only : abor1
    implicit none

    integer(kind=8), parameter :: NULOUT = 6
    integer, parameter :: jprb = 8
    logical, intent(in) :: dave
    real(kind=jprb) :: zhook_handle
    if (lhook) call dr_hook('never_gonna_give',0,zhook_handle)

    CALL ABOR1('[SUBROUTINE CALL]')

    print *, 'never gonna let you down'

    if (dave) call abor1('[INLINE CONDITIONAL]')

    call never_gonna_run_around()

    WRITE(NULOUT,*) "[WRITE INTRINSIC]"
    if (.not. dave) WRITE(NULOUT, *) "[WRITE INTRINSIC]"

    if (lhook) call dr_hook('never_gonna_give',1,zhook_handle)

end subroutine
    """

    # Parse utility module first, to get type info for OMNI
    Module.from_source(fcode_yomhook, frontend=frontend, xmods=[tmp_path])
    Module.from_source(fcode_abor1, frontend=frontend, xmods=[tmp_path])

    # Parse the main test function and remove calls
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Note that OMNI enforces keyword-arg passing for intrinsic
    # call to ``write``, so we match both conventions.
    do_remove_calls(
        routine, call_names=('ABOR1', 'DR_HOOK'),
        intrinsic_names=('WRITE(NULOUT', 'write(unit=nulout'),
        remove_imports=remove_imports
    )

    # Check that all but one specific call have been removed
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    assert calls[0].name == 'never_gonna_run_around'

    # OMNI resolves inline-conditionals and expands the keyword-args,
    # so neither the inline-conditional removal, nor the intrinsic
    # matching works with it.
    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conditionals) == (4 if frontend == OMNI else 0)

    # Check that all intrinsic calls to WRITE have been removed
    intrinsics = FindNodes(ir.Intrinsic).visit(routine.body)
    assert len(intrinsics) == 1
    assert 'never gonna let you down' in intrinsics[0].text

    # Check that the respective imports have also been stripped.
    # The conditional must be parenthesised: without it the statement parses
    # as ``assert (len(imports) == 1) if remove_imports else 2`` and checks
    # nothing when ``remove_imports`` is False.
    imports = FindNodes(ir.Import).visit(routine.spec)
    assert len(imports) == (1 if remove_imports else 2)
    assert imports[0].module == 'yomhook'
    if remove_imports:
        assert imports[0].symbols == ('lhook',)
    else:
        assert imports[0].symbols == ('lhook', 'dr_hook')
        assert imports[1].module == 'abor1_mod'
        assert imports[1].symbols == ('abor1',)


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'Incomplete source tree impossible with OMNI')]
))
@pytest.mark.parametrize('include_intrinsics', (True, False))
@pytest.mark.parametrize('kernel_only', (True, False))
@pytest.mark.parametrize('remove_marked', (True, False))
def test_remove_code_transformation(
        frontend, source, include_intrinsics, kernel_only, remove_marked, tmp_path
):
    """
    Test the use of code removal utilities, in particular the call
    removal, via the scheduler.
    """

    config = {
        'default': {
            'role': 'kernel', 'expand': True, 'strict': False,
            'disable': ['dr_hook', 'abor1']
        },
        'routines': {
            'rick_astley': {'role': 'driver'},
        }
    }
    scheduler_config = SchedulerConfig.from_dict(config)
    scheduler = Scheduler(paths=source, config=scheduler_config, frontend=frontend, xmods=[tmp_path])

    # Apply the transformation to the call tree
    transformation = RemoveCodeTransformation(
        call_names=('ABOR1', 'DR_HOOK'),
        intrinsic_names=('WRITE(NULOUT',) if include_intrinsics else (),
        kernel_only=kernel_only,
        remove_marked_regions=remove_marked,
        replacement_call='ABOR2' if remove_marked else None,
        replacement_msg='!!!Unsupported!!!' if remove_marked else None,
        replacement_module='abor2_mod' if remove_marked else None,
    )
    scheduler.process(transformation=transformation)

    routine = scheduler['rick_rolled#never_gonna_give'].ir
    transformed = routine.to_fortran()

    assert '[SUBROUTINE CALL]' not in transformed
    assert '[INLINE CONDITIONAL]' not in transformed
    assert ('dave' not in transformed) == include_intrinsics
    assert ('[WRITE INTRINSIC]' not in transformed) == include_intrinsics

    if remove_marked:
        # Check that `!$loki remove` added the replacement call, and that the
        # symbol was added to the existing import instead of duplicating it
        assert "CALL ABOR2('!!!Unsupported!!!')" in transformed
        assert transformed.count('USE abor2_mod, ONLY: not_my_abor, ABOR2') == 1
    else:
        # Without marked-region removal no replacement call may appear
        assert 'CALL ABOR2(' not in transformed

    for r in routine.members:
        member_code = r.to_fortran()
        assert '[SUBROUTINE CALL]' not in member_code
        assert '[INLINE CONDITIONAL]' not in member_code
        assert ('dave' not in member_code) == include_intrinsics

    # A routine outside the scheduler's call tree must be left untouched
    routine = Sourcefile.from_file(
        source/'never_gonna_give.F90', frontend=frontend
    )['i_hope_you_havent_let_me_down']
    assert 'zhook_handle' in routine.variables
    assert len([call for call in FindNodes(ir.CallStatement).visit(routine.body) if call.name == 'dr_hook']) == 2

    driver = scheduler['#rick_astley'].ir
    drhook_calls = [call for call in FindNodes(ir.CallStatement).visit(driver.body) if call.name == 'dr_hook']
    assert len(drhook_calls) == (2 if kernel_only else 0)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('kernel_override', [True, False])
def test_remove_code_unused_args(frontend, source_with_args, kernel_override, tmp_path):
    """
    Test the removal of unused arguments in a call tree, optionally with
    a per-routine override (``remove_unused_args: False``) for ``another_kernel``,
    followed by the removal of unused local variables.
    """

    config = {
        'default': {
            'role': 'kernel', 'expand': True, 'strict': False,
            'enable_imports': True, 'block': ['an_unused_kernel']
        },
        'routines': {
            'driver': {'role': 'driver'},
        }
    }

    if kernel_override:
        config['routines'].update(
            {'another_kernel': {'role': 'kernel', 'remove_unused_args': False}}
        )

    scheduler_config = SchedulerConfig.from_dict(config)
    scheduler = Scheduler(paths=source_with_args, config=scheduler_config, frontend=frontend, xmods=[tmp_path])

    # Apply the code removal transformation
    transformation = RemoveCodeTransformation(remove_unused_args=True)
    scheduler.process(transformation=transformation)

    # check the kernel was transformed correctly
    kernel = scheduler['#kernel'].ir
    driver = scheduler['#driver'].ir

    kernel_calls = FindNodes(ir.CallStatement).visit(kernel.body)
    driver_calls = FindNodes(ir.CallStatement).visit(driver.body)

    assert len(kernel_calls) == 1
    assert kernel_calls[0].name.name.lower() == 'another_kernel'
    assert len(driver_calls) == 1
    assert driver_calls[0].name.name.lower() == 'kernel'

    kernel_vars = [v.clone(dimensions=None) for v in kernel.variables]

    assert 'structs' in kernel_vars
    assert 'structs' in driver_calls[0].arguments
    if kernel_override:
        assert 'struct' not in kernel_vars
        assert 'struct' not in driver_calls[0].arguments

        assert 'd' in kernel_vars
        assert 'd' in driver_calls[0].arguments
    else:
        assert not any(v in kernel_vars for v in ['d', 'struct'])
        assert not any(v in driver_calls[0].arguments for v in ['d', 'struct'])

    assert 'used_local' in kernel_vars
    assert 'unused_local' in kernel_vars

    transformation = RemoveCodeTransformation(remove_unused_vars=True)
    scheduler.process(transformation=transformation)

    kernel_vars = [v.clone(dimensions=None) for v in kernel.variables]
    assert 'used_local' in kernel_vars
    assert 'unused_local' not in kernel_vars


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('remove_only_arrays', (True, False))
def test_remove_code_unused_vars(frontend, remove_only_arrays, tmp_path):
    """
    Test :any:`do_remove_unused_vars` on a routine with unused locals.

    Dummy arguments must always be kept. Locals referenced in the body
    (``test1``, ``test2``) must be kept. Unused locals are removed: only
    the array ones (``some_vars``, ``unused2``) when ``remove_only_arrays``
    is set, and all of them otherwise.
    """
    fcode_some_type = """
module some_type_mod
  type some_type
    integer :: a
  end type some_type
end module some_type_mod
    """

    fcode = """
subroutine test_remove_unused_vars(a, b, c, len, flag)

  use some_type_mod, only: some_type
  implicit none

  real(kind=8), intent(inout) :: a(len, len), b(len), c
  integer, intent(in) :: len
  logical, intent(in) :: flag

  type(some_type) :: some_var, some_vars(len)
  real(kind=8) :: test1, test2, unused1, unused2(len, len)

  test1 = 2
  test2 = 2

end subroutine test_remove_unused_vars
"""
    module = Module.from_source(fcode_some_type, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module, xmods=[tmp_path])
    do_remove_unused_vars(routine, remove_only_arrays=remove_only_arrays)

    # Dummy arguments are never removed, even when unused in the body
    expected_args = ('a', 'b', 'c', 'len', 'flag')
    routine_args = [arg.name.lower() for arg in routine.arguments]
    for arg in expected_args:
        assert arg in routine_args

    if remove_only_arrays:
        # Unused scalars survive; only the unused arrays are dropped
        expected_locals = ('some_var', 'test1', 'test2', 'unused1')
        removed_locals = ('some_vars', 'unused2')
    else:
        expected_locals = ('test1', 'test2')
        removed_locals = ('some_var', 'some_vars', 'unused1', 'unused2')
    routine_locals = [var.clone(dimensions=None) for var in routine.variables]
    for var in expected_locals:
        assert var in routine_locals
    # Checking only the survivors would let a no-op implementation pass,
    # so also verify that the unused declarations are actually gone
    for var in removed_locals:
        assert var not in routine_locals


@pytest.mark.parametrize('frontend', available_frontends())
def test_remove_code_nested_regions(frontend):
    """
    Test that ``!$loki remove`` regions that each wrap a single pragma
    (here the opening and closing ``!$acc kernels``) are removed, while
    the code between the two regions is preserved.
    """
    fcode = """
subroutine nested_regions(arg)
implicit none
real, intent(inout) :: arg

!$loki remove
!$acc kernels
!$loki end remove
arg = 5
!$loki remove
!$acc end kernels
!$loki end remove
end subroutine nested_regions
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(ir.Pragma).visit(routine.body)) == 6
    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1

    transformation = RemoveCodeTransformation(remove_marked_regions=True)
    transformation.apply(routine)

    # All marker and wrapped pragmas are gone...
    assert not FindNodes(ir.Pragma).visit(routine.body)
    # ...but the statement outside the marked regions must not be removed
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 1
    assert assignments[0].lhs == 'arg'
loki-ecmwf-0.3.6/loki/transformations/tests/__init__.py0000664000175000017500000000057015167130205023405 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/tests/test_pragma_model.py0000664000175000017500000002016615167130205025337 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Subroutine, as_tuple
from loki.frontend import available_frontends

from loki.transformations import PragmaModelTransformation
from loki.ir import FindNodes, Pragma

def check_pragma(pragma, keyword, content, check_for_equality=True):
    """
    Assert that ``pragma`` carries ``keyword`` and the expected ``content``.

    With ``check_for_equality`` set (the default), the pragma content must
    equal ``content`` exactly. Otherwise ``content`` may be a string or a
    tuple of strings, and each of them must appear as a substring of the
    pragma content. This allows clause-order-independent checks.
    """
    assert pragma.keyword == keyword
    if not check_for_equality:
        for fragment in as_tuple(content):
            assert fragment in pragma.content
        return
    assert pragma.content == content

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('directive', [False, 'openacc', 'omp-gpu'])
@pytest.mark.parametrize('keep_loki_pragmas', [True, False])
def test_transform_pragma_model(tmp_path, frontend, directive, keep_loki_pragmas):
    """
    Test Pragma model trafo for different directives/flavors.

    Each generic ``!$loki`` pragma in a module and a routine is mapped to
    its OpenACC or OpenMP-offload counterpart (or kept as-is without a
    directive). The resulting pragmas are compared one-to-one, in order,
    against the expected sequence. Entries with keyword ``loki`` are
    unmapped pragmas, expected only when ``keep_loki_pragmas`` is set.
    """
    fcode_mod = """
    module some_mod
      integer :: a, b
      !$loki create device(a, b)
    end module some_mod
    """.strip()

    fcode = """
subroutine some_func(ret)
  implicit none
  integer, intent(out) :: ret
  integer :: tmp1, tmp2, tmp3, tmp4, jk

  !$loki create device(tmp1, tmp2)
  !$loki update device(tmp1) host(tmp2)
  !$loki unstructured-data in(tmp1, tmp2) create(tmp3, tmp4) attach(tmp1)
  !$loki exit-unstructured-data out(tmp2, tmp3, tmp4) detach(tmp1) delete(tmp1) finalize
  !$loki structured-data in(tmp1) out(tmp2) inout(tmp3) create(tmp4)
  !$loki end structured-data in(tmp1) out(tmp2) inout(tmp3) create(tmp4)
  !$loki loop gang private(tmp1) vlength(128)
  !$loki end loop gang
  !$loki loop vector private(tmp2)
  !$loki end loop vector
  !$loki loop seq
  !$loki end loop seq
  !$loki routine vector
  !$loki routine seq
  !$loki device-present vars(tmp1, tmp2)
  !$loki end device-present vars(tmp1, tmp2)
  !$loki device-ptr vars(tmp1, tmp2)
  !$loki end device-ptr vars(tmp1, tmp2)
  !$loki unmapped-directive whatever(tmp1) foo(tmp2)
  ! misspelled by purpose
  !$loki create drvice(tmp1)
  !$loki structured-data present(tmp3, tmp4)
  !$loki end structured-data
  !$loki structured-data in(tmp1) present(tmp3, tmp4)
  !$loki end structured-data

end subroutine some_func
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    module = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])

    pragma_model_trafo = PragmaModelTransformation(directive=directive,
            keep_loki_pragmas=keep_loki_pragmas)
    pragma_model_trafo.transform_subroutine(routine)
    pragma_model_trafo.transform_module(module)

    # CHECK MODULE
    pragmas = FindNodes(Pragma).visit(module.spec)
    if directive == 'openacc':
        check_pragma(pragmas[0], 'acc', 'declare create(a, b)')
    if directive == 'omp-gpu':
        check_pragma(pragmas[0], 'omp', 'declare target(a, b)')
    if directive is False and keep_loki_pragmas:
        check_pragma(pragmas[0], 'loki', 'create device(a, b)')

    # CHECK ROUTINE
    # Each entry: (keyword, content[, check_for_equality]) for check_pragma
    pragmas = FindNodes(Pragma).visit(routine.ir)
    if directive == 'openacc':
        args = (('acc', 'declare create(tmp1, tmp2)'),
                ('acc', ('update', 'device(tmp1)', 'self(tmp2)'), False),
                ('acc', ('enter data', 'copyin(tmp1, tmp2)', 'create(tmp3, tmp4)', 'attach(tmp1)'), False),
                ('acc', ('exit data', 'copyout(tmp2, tmp3, tmp4)', 'detach(tmp1)', 'delete(tmp1)', 'finalize'), False),
                ('acc', ('data', 'copyin(tmp1)', 'copy(tmp3)', 'copyout(tmp2)', 'create(tmp4)'), False),
                ('acc', 'end data'),
                ('acc', ('parallel loop gang', 'private(tmp1)', 'vector_length(128)'), False),
                ('acc', 'end parallel loop'),
                ('acc', ('loop vector', 'private(tmp2)'), False),
                ('loki', 'end loop vector'),
                ('acc', 'loop seq'),
                ('loki', 'end loop seq'),
                ('acc', 'routine vector'),
                ('acc', 'routine seq'),
                ('acc', 'data present(tmp1, tmp2)'),
                ('acc', 'end data'),
                ('acc', 'data deviceptr(tmp1, tmp2)'),
                ('acc', 'end data'),
                ('loki', 'unmapped-directive whatever(tmp1) foo(tmp2)'),
                ('loki', 'create drvice(tmp1)'),
                ('acc', 'data present(tmp3, tmp4)'),
                ('acc', 'end data'),
                ('acc', ('data', 'copyin(tmp1)', 'present(tmp3, tmp4)'), False),
                ('acc', 'end data'))
    if directive == 'omp-gpu':
        args = (('omp', 'declare target(tmp1, tmp2)'),
                ('omp', ('target update', 'to(tmp1)', 'from(tmp2)'), False),
                ('omp', ('target enter data', 'map(to: tmp1, tmp2)', 'map(alloc: tmp3, tmp4)'), False),
                ('omp', ('target exit data', 'map(from: tmp2, tmp3, tmp4)', 'map(delete: tmp1)'), False),
                ('omp', ('target data', 'map(to: tmp1)', 'map(tofrom: tmp3)',
                    'map(from: tmp2)', 'map(alloc: tmp4)'), False),
                ('omp', 'end target data'),
                ('omp', ('target teams distribute', 'thread_limit(128)'), False),
                ('omp', 'end target teams distribute'),
                ('omp', 'parallel do'),
                ('omp', 'end parallel do'),
                ('loki', 'loop seq'),
                ('loki', 'end loop seq'),
                ('loki', 'routine vector'),
                ('omp', 'declare target'),
                ('loki', ('device-present', 'vars(tmp1, tmp2)'), False),
                ('loki', ('end device-present', 'vars(tmp1, tmp2)'), False),
                ('loki', ('device-ptr', 'vars(tmp1, tmp2)'), False),
                ('loki', ('end device-ptr', 'vars(tmp1, tmp2)'), False),
                ('loki', 'unmapped-directive whatever(tmp1) foo(tmp2)'),
                ('loki', 'create drvice(tmp1)'),
                ('omp', 'target data map(to: tmp3, tmp4)'),
                ('omp', 'end target data'),
                ('omp', ('target data', 'map(to: tmp1, tmp3, tmp4)'), False),
                ('omp', 'end target data'))
    if directive is False:
        args = (('loki', 'create device(tmp1, tmp2)'),
                ('loki', ('update', 'device(tmp1)', 'host(tmp2)'), False),
                ('loki', ('unstructured-data', 'in(tmp1, tmp2)', 'create(tmp3, tmp4)', 'attach(tmp1)', ), False),
                ('loki', ('exit-unstructured-data', 'out(tmp2, tmp3, tmp4)', 'detach(tmp1)', 'delete(tmp1)',
                          'finalize'), False),
                ('loki', ('structured-data', 'in(tmp1)', 'out(tmp2)', 'inout(tmp3)', 'create(tmp4)'), False),
                ('loki', ('end structured-data', 'in(tmp1)', 'out(tmp2)', 'inout(tmp3)', 'create(tmp4)'), False),
                ('loki', ('loop gang', 'private(tmp1)', 'vlength(128)'), False),
                ('loki', 'end loop gang'),
                ('loki', ('loop vector', 'private(tmp2)'), False),
                ('loki', 'end loop vector'),
                ('loki', 'loop seq'),
                ('loki', 'end loop seq'),
                ('loki', 'routine vector'),
                ('loki', 'routine seq'),
                ('loki', ('device-present', 'vars(tmp1, tmp2)'), False),
                ('loki', ('end device-present', 'vars(tmp1, tmp2)'), False),
                ('loki', ('device-ptr', 'vars(tmp1, tmp2)'), False),
                ('loki', ('end device-ptr', 'vars(tmp1, tmp2)'), False),
                ('loki', 'unmapped-directive whatever(tmp1) foo(tmp2)'),
                ('loki', 'create drvice(tmp1)'),
                ('loki', 'structured-data present(tmp3, tmp4)'),
                ('loki', 'end structured-data'),
                ('loki', ('structured-data', 'in(tmp1)', 'present(tmp3, tmp4)'), False),
                ('loki', 'end structured-data'))

    if not keep_loki_pragmas:
        args = tuple(arg for arg in args if arg[0] != 'loki')

    # zip() silently stops at the shorter sequence: without this check,
    # missing or surplus pragmas would go unnoticed. It matters most when
    # ``args`` is empty, where otherwise nothing would be checked at all.
    assert len(pragmas) == len(args)
    for pragma, _args in zip(pragmas, args):
        check_pragma(pragma, *_args)
loki-ecmwf-0.3.6/loki/transformations/tests/test_dependency.py0000664000175000017500000007044515167130205025033 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
A selection of tests for (proof-of-concept) transformations changing
dependencies through e.g., duplicating or removing kernels (and calls).
"""
import re
from pathlib import Path
import pytest

from loki.batch import Pipeline, ProcedureItem, ModuleItem
from loki import (
    Scheduler, SchedulerConfig, ProcessingStrategy
)
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
from loki.tools import as_tuple
from loki.transformations.dependency import (
        DuplicateKernel, RemoveKernel
)
from loki.transformations.build_system import FileWriteTransformation


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory containing this test module (module-scoped)."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """The ``tests`` directory two levels above this module's directory."""
    grandparent = here.parent.parent
    return grandparent / 'tests'


@pytest.fixture(name='config')
def fixture_config():
    """
    Default configuration dict with basic options.

    Every routine defaults to an expanded, non-strict ``idem``-mode kernel.
    Only ``driver`` is configured as the driver. A fresh dict is returned
    per test, so tests may mutate it freely.
    """
    default_options = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': False,
    }
    driver_options = {
        'role': 'driver',
        'expand': True,
    }
    return {'default': default_options, 'routines': {'driver': driver_options}}


@pytest.fixture(name='fcode_as_module')
def fixture_fcode_as_module(tmp_path):
    """
    Write ``driver.F90`` and ``kernel_mod.F90`` into ``tmp_path``.

    The driver imports ``kernel`` from ``kernel_mod`` and calls it inside
    a loop over ``nb``.
    """
    fortran_files = {
        'driver.F90': """
subroutine driver(NLON, NB, FIELD1)
    use kernel_mod, only: kernel
    implicit none
    INTEGER, INTENT(IN) :: NLON, NB
    integer :: b
    integer, intent(inout) :: field1(nlon, nb)
    integer :: local_nlon
    local_nlon = nlon
    do b=1,nb
        call kernel(local_nlon, field1(:,b))
    end do
end subroutine driver
    """.strip(),
        'kernel_mod.F90': """
module kernel_mod
    implicit none
contains
    subroutine kernel(klon, field1)
        implicit none
        integer, intent(in) :: klon
        integer, intent(inout) :: field1(klon)
        integer :: tmp1(klon)
        integer :: jl

        do jl=1,klon
            tmp1(jl) = 0
            field1(jl) = tmp1(jl)
        end do

    end subroutine kernel
end module kernel_mod
    """.strip(),
    }
    for file_name, source in fortran_files.items():
        (tmp_path/file_name).write_text(source)


@pytest.fixture(name='fcode_no_module')
def fixture_fcode_no_module(tmp_path):
    """
    Write ``driver.F90`` and ``kernel.F90`` into ``tmp_path``.

    Here ``kernel`` is a free subroutine (not wrapped in a module), and the
    driver calls it without a ``use`` statement.
    """
    fortran_files = {
        'driver.F90': """
subroutine driver(NLON, NB, FIELD1)
    implicit none
    INTEGER, INTENT(IN) :: NLON, NB
    integer :: b
    integer, intent(inout) :: field1(nlon, nb)
    integer :: local_nlon
    local_nlon = nlon
    do b=1,nb
        call kernel(local_nlon, field1(:,b))
    end do
end subroutine driver
    """.strip(),
        'kernel.F90': """
subroutine kernel(klon, field1)
    implicit none
    integer, intent(in) :: klon
    integer, intent(inout) :: field1(klon)
    integer :: tmp1(klon)
    integer :: jl

    do jl=1,klon
        tmp1(jl) = 0
        field1(jl) = tmp1(jl)
    end do

end subroutine kernel
    """.strip(),
    }
    for file_name, source in fortran_files.items():
        (tmp_path/file_name).write_text(source)


@pytest.mark.usefixtures('fcode_as_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('suffix,module_suffix', (
    ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2')
))
@pytest.mark.parametrize('full_parse', (True, False))
def test_dependency_duplicate_plan(tmp_path, frontend, suffix, module_suffix, config, full_parse):
    """
    Plan-mode duplication of a module procedure.

    Checks that the duplicated kernel gets its own procedure item inside a
    new, suffixed module item, and that the CMake plan appends the new file.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path], full_parse=full_parse
    )
    pipeline = Pipeline(
        classes=(DuplicateKernel, FileWriteTransformation),
        duplicate_kernels=('kernel',), duplicate_suffix=suffix,
        duplicate_module_suffix=module_suffix
    )

    plan_file = tmp_path/'plan.cmake'
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    # Without an explicit module suffix, the routine suffix is reused
    mod_suffix = module_suffix or suffix
    new_module = f'kernel_mod{mod_suffix}'
    new_routine = f'{new_module}#kernel{suffix}'
    item_cache = scheduler.item_factory.item_cache

    # The new module item is cached but not part of the scheduler graph
    assert new_module in item_cache
    module_item = item_cache[new_module]
    assert isinstance(module_item, ModuleItem)
    assert module_item.ir.name == module_item.local_name
    assert new_module not in scheduler

    # The new procedure item is cached and part of the scheduler graph
    assert new_routine in item_cache
    assert new_routine in scheduler
    routine_item = scheduler[new_routine]
    assert isinstance(routine_item, ProcedureItem)
    assert routine_item.ir.name == routine_item.local_name

    # Parse every `set(NAME files...)` entry of the plan, keeping only file stems
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan_dict = {
        key: {Path(entry).stem for entry in value.split()}
        for key, value in plan_pattern.findall(plan_file.read_text())
    }
    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {'kernel_mod', 'driver'}
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'kernel_mod', 'driver'}
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'kernel_mod{mod_suffix}.idem', 'kernel_mod.idem', 'driver.idem'}


@pytest.mark.usefixtures('fcode_as_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('suffix,module_suffix', (
    ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2')
))
def test_dependency_duplicate_trafo(tmp_path, frontend, suffix, module_suffix, config):
    """
    Apply :any:`DuplicateKernel` to a module procedure.

    Checks that the scheduler graph gains the duplicated routine and that
    the driver now calls the original kernel followed by its duplicate.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )
    pipeline = Pipeline(
        classes=(DuplicateKernel, FileWriteTransformation),
        duplicate_kernels=('kernel',), duplicate_suffix=suffix,
        duplicate_module_suffix=module_suffix
    )
    scheduler.process(pipeline)

    # Without an explicit module suffix, the routine suffix is reused
    new_module = f'kernel_mod{module_suffix or suffix}'
    new_routine = f'{new_module}#kernel{suffix}'
    item_cache = scheduler.item_factory.item_cache

    # The new module item is cached but not part of the scheduler graph
    assert new_module in item_cache
    module_item = item_cache[new_module]
    assert isinstance(module_item, ModuleItem)
    assert module_item.ir.name == module_item.local_name
    assert new_module not in scheduler

    # The new procedure item is cached and part of the scheduler graph
    assert new_routine in item_cache
    assert new_routine in scheduler
    routine_item = scheduler[new_routine]
    assert isinstance(routine_item, ProcedureItem)
    assert routine_item.ir.name == routine_item.local_name

    driver_ir = scheduler['#driver'].ir
    kernel_ir = scheduler['kernel_mod#kernel'].ir
    duplicate_ir = scheduler[new_routine].ir

    # The driver calls the original kernel first, then its duplicate
    driver_calls = FindNodes(ir.CallStatement).visit(driver_ir.body)
    assert len(driver_calls) == 2
    assert duplicate_ir is not kernel_ir
    assert driver_calls[0].routine == kernel_ir
    assert driver_calls[1].routine == duplicate_ir


@pytest.mark.usefixtures('fcode_as_module')
@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_remove(tmp_path, frontend, config):
    """
    Remove a module procedure with :any:`RemoveKernel`.

    The plan must only rewrite the driver. After the real transformation,
    the kernel item is gone and the driver contains no calls.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )
    pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation),
                        remove_kernels=('kernel',))

    # First pass: plan mode only, then write the CMake plan
    plan_file = tmp_path/'plan.cmake'
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    # Parse every `set(NAME files...)` entry of the plan, keeping only file stems
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan_dict = {
        key: {Path(entry).stem for entry in value.split()}
        for key, value in plan_pattern.findall(plan_file.read_text())
    }
    for key in ('LOKI_SOURCES_TO_TRANSFORM', 'LOKI_SOURCES_TO_REMOVE'):
        assert plan_dict[key] == {'driver'}
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {'driver.idem'}

    # Second pass: actually apply the transformation
    scheduler.process(pipeline)
    driver_ir = scheduler['#driver'].ir
    assert 'kernel_mod#kernel' not in scheduler
    assert not FindNodes(ir.CallStatement).visit(driver_ir.body)


@pytest.mark.usefixtures('fcode_no_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('suffix, module_suffix', (
    ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2')
))
@pytest.mark.parametrize('full_parse', (True, False))
def test_dependency_duplicate_plan_no_module(tmp_path, frontend, suffix, module_suffix, config, full_parse):
    """
    Plan-mode duplication of a free (module-less) procedure.

    Checks that a ``#kernel<suffix>`` procedure item appears and that the
    plan appends its source. ``module_suffix`` is forwarded to the pipeline
    but does not enter any expected name in this test.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path], full_parse=full_parse
    )
    pipeline = Pipeline(
        classes=(DuplicateKernel, FileWriteTransformation),
        duplicate_kernels=('kernel',), duplicate_suffix=suffix,
        duplicate_module_suffix=module_suffix
    )

    plan_file = tmp_path/'plan.cmake'
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    new_name = f'#kernel{suffix}'

    # The duplicated procedure item is cached and part of the scheduler graph
    assert new_name in scheduler.item_factory.item_cache
    assert new_name in scheduler
    assert isinstance(scheduler[new_name], ProcedureItem)
    assert scheduler[new_name].ir.name == f'kernel{suffix}'

    # The duplicate owns a distinct IR object
    kernel_ir = scheduler['#kernel'].ir
    duplicate_ir = scheduler[new_name].ir
    assert duplicate_ir is not kernel_ir

    # Parse every `set(NAME files...)` entry of the plan, keeping only file stems
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan_dict = {
        key: {Path(entry).stem for entry in value.split()}
        for key, value in plan_pattern.findall(plan_file.read_text())
    }
    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {'kernel', 'driver'}
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'kernel', 'driver'}
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'kernel{suffix}.idem', 'kernel.idem', 'driver.idem'}


@pytest.mark.usefixtures('fcode_no_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('suffix, module_suffix', (
    ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2')
))
def test_dependency_duplicate_trafo_no_module(tmp_path, frontend, suffix, module_suffix, config):
    """
    Apply :any:`DuplicateKernel` to a free (module-less) procedure.

    Checks that the driver then calls the original kernel followed by the
    newly created ``kernel<suffix>``.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )
    pipeline = Pipeline(
        classes=(DuplicateKernel, FileWriteTransformation),
        duplicate_kernels=('kernel',), duplicate_suffix=suffix,
        duplicate_module_suffix=module_suffix
    )
    scheduler.process(pipeline)

    new_name = f'#kernel{suffix}'

    # The duplicated procedure item is cached and part of the scheduler graph
    assert new_name in scheduler.item_factory.item_cache
    assert new_name in scheduler
    assert isinstance(scheduler[new_name], ProcedureItem)
    assert scheduler[new_name].ir.name == f'kernel{suffix}'

    driver_ir = scheduler['#driver'].ir
    kernel_ir = scheduler['#kernel'].ir
    duplicate_ir = scheduler[new_name].ir

    # The driver calls the original kernel first, then its duplicate
    driver_calls = FindNodes(ir.CallStatement).visit(driver_ir.body)
    assert len(driver_calls) == 2
    assert duplicate_ir is not kernel_ir
    assert driver_calls[0].routine == kernel_ir
    assert driver_calls[1].routine == duplicate_ir


@pytest.mark.usefixtures('fcode_no_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('full_parse', (True, False))
def test_dependency_remove_plan_no_module(tmp_path, frontend, config, full_parse):
    """
    Plan-mode removal of a free (module-less) procedure.

    Only the driver appears in the CMake plan, and the ``#kernel`` item is
    dropped from the scheduler.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path], full_parse=full_parse
    )
    pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation),
                        remove_kernels=('kernel',))

    plan_file = tmp_path/'plan.cmake'
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    # Parse every `set(NAME files...)` entry of the plan, keeping only file stems
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan_dict = {
        key: {Path(entry).stem for entry in value.split()}
        for key, value in plan_pattern.findall(plan_file.read_text())
    }
    for key in ('LOKI_SOURCES_TO_TRANSFORM', 'LOKI_SOURCES_TO_REMOVE'):
        assert plan_dict[key] == {'driver'}
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {'driver.idem'}

    assert '#kernel' not in scheduler


@pytest.mark.usefixtures('fcode_no_module')
@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_remove_trafo_no_module(tmp_path, frontend, config):
    """
    Apply :any:`RemoveKernel` to a free (module-less) procedure.

    The ``#kernel`` item is dropped and the driver is left without calls.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )
    pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation),
                        remove_kernels=('kernel',))
    scheduler.process(pipeline)

    driver_ir = scheduler['#driver'].ir
    assert '#kernel' not in scheduler
    remaining_calls = FindNodes(ir.CallStatement).visit(driver_ir.body)
    assert not remaining_calls


@pytest.mark.usefixtures('fcode_as_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('duplicate_kernels,remove_kernels', (
    ('kernel', 'kernel'), ('kernel', 'kernel_new'), ('kernel', None), (None, 'kernel')
))
@pytest.mark.parametrize('full_parse', (True, False))
def test_dependency_duplicate_remove_plan(tmp_path, frontend, duplicate_kernels, remove_kernels,
                                          config, full_parse):
    """
    Chain :any:`DuplicateKernel` and :any:`RemoveKernel` in plan mode.

    The expected scheduler items and CMake plan entries are derived from
    the ``duplicate_kernels`` and ``remove_kernels`` parameters.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path], full_parse=full_parse
    )

    expected_items = {'kernel_mod#kernel', '#driver'}
    assert {item.name for item in scheduler.items} == expected_items

    pipeline = Pipeline(
        classes=(DuplicateKernel, RemoveKernel, FileWriteTransformation),
        duplicate_kernels=duplicate_kernels, duplicate_suffix='_new',
        remove_kernels=remove_kernels
    )

    plan_file = tmp_path/'plan.cmake'
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    # Every duplicated kernel "<scope>#<kernel>" adds "<scope>_new#<kernel>_new"
    for kernel in as_tuple(duplicate_kernels):
        scopes = [scope for scope, local in (name.split('#') for name in expected_items) if local == kernel]
        expected_items.update(f'{scope}_new#{kernel}_new' for scope in scopes)

    # Every removed kernel drops all items with a matching local name
    for kernel in as_tuple(remove_kernels):
        doomed = [f'{scope}#{local}' for scope, local in (name.split('#') for name in expected_items)
                  if local == kernel]
        expected_items.difference_update(doomed)

    assert {item.name for item in scheduler.items} == expected_items

    def file_stem(item_name):
        # '<scope>#<routine>' maps to its scope, '#<routine>' to the routine name
        return item_name.split('#')[0] or item_name[1:]

    # Parse every `set(NAME files...)` entry of the plan, keeping only file stems
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan_dict = {
        key: {Path(entry).stem for entry in value.split()}
        for key, value in plan_pattern.findall(plan_file.read_text())
    }

    transformed_items = {file_stem(name) for name in expected_items if not name.endswith('_new')}
    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == transformed_items
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == transformed_items
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'{file_stem(name)}.idem' for name in expected_items}


@pytest.mark.usefixtures('fcode_no_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('duplicate_kernels,remove_kernels', (
    ('kernel', 'kernel'), ('kernel', 'kernel_new'), ('kernel', None), (None, 'kernel')
))
@pytest.mark.parametrize('full_parse', (True, False))
def test_dependency_duplicate_remove_plan_no_module(tmp_path, frontend, duplicate_kernels, remove_kernels,
                                                    config, full_parse):
    """
    Chain :any:`DuplicateKernel` and :any:`RemoveKernel` in plan mode on
    free (module-less) procedures.

    The expected scheduler items and CMake plan entries are derived from
    the ``duplicate_kernels`` and ``remove_kernels`` parameters. Each may
    be ``None``, a single name, or a tuple of names.
    """
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path], full_parse=full_parse
    )

    expected_items = {'#kernel', '#driver'}
    assert {item.name for item in scheduler.items} == expected_items

    pipeline = Pipeline(classes=(DuplicateKernel, RemoveKernel, FileWriteTransformation),
                        duplicate_kernels=duplicate_kernels, duplicate_suffix='_new',
                        remove_kernels=remove_kernels)

    plan_file = tmp_path/'plan.cmake'
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    # Normalise through as_tuple() instead of interpolating the raw parameter:
    # a tuple of kernel names would otherwise produce a bogus item name
    for kernel in as_tuple(duplicate_kernels):
        expected_items.add(f'#{kernel}_new')

    for kernel in as_tuple(remove_kernels):
        expected_items.remove(f'#{kernel}')

    # Validate Scheduler graph
    assert {item.name for item in scheduler.items} == expected_items

    # Validate the plan file content (file stems only)
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    loki_plan = plan_file.read_text()
    plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)}
    plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()}

    # Free procedures are named '#<routine>', so the file stem is the name without '#'
    transformed_items = {name[1:] for name in expected_items if not name.endswith('_new')}
    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == transformed_items
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == transformed_items
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'{name[1:]}.idem' for name in expected_items}

@pytest.fixture(name='fcode_as_module_extended')
def fixture_fcode_as_module_extended(tmp_path):
    """
    Write a deeper call tree of module procedures into ``tmp_path``.

    ``driver`` calls ``kernel_mod::kernel``. That routine calls
    ``kernel_nested_vector`` and ``kernel_nested_seq`` (``kernel_nested_mod``),
    ``compute_2`` and ``compute_2_1`` (``compute_2_mod``), and ``compute_3``
    (``compute_3_mod``). ``kernel_nested_seq`` in turn calls ``compute_1``
    (``compute_1_mod``). Each module is written to ``<module name>.F90``.
    """
    fortran_files = {
        'driver.F90': """
subroutine driver(NLON, NB, FIELD1)
    use kernel_mod, only: kernel
    implicit none
    INTEGER, INTENT(IN) :: NLON, NB
    integer :: b
    integer, intent(inout) :: field1(nlon, nb)
    integer :: local_nlon
    local_nlon = nlon
    do b=1,nb
        call kernel(local_nlon, field1(:,b))
    end do
end subroutine driver
    """.strip(),
        'kernel_mod.F90': """
module kernel_mod
    implicit none
contains
    subroutine kernel(klon, field1)
        use iso_fortran_env, only: real64
        use kernel_nested_mod, only: kernel_nested_vector, kernel_nested_seq
        use compute_2_mod, only: compute_2, compute_2_1
        use compute_3_mod, only: compute_3
        implicit none
        integer, intent(in) :: klon
        integer, intent(inout) :: field1(klon)
        integer :: tmp1(klon)
        integer :: jl

        call kernel_nested_vector(klon, field1)

        do jl=1,klon
            call kernel_nested_seq(field1(jl))
            call compute_2(field1(jl))
            call compute_2_1(field1(jl))
            tmp1(jl) = 0
            field1(jl) = tmp1(jl)
            call compute_3(field1(jl))
        end do

    end subroutine kernel
end module kernel_mod
    """.strip(),
        'kernel_nested_mod.F90': """
module kernel_nested_mod
    implicit none
contains
    subroutine kernel_nested_vector(klon, field1)
        implicit none
        integer, intent(in) :: klon
        integer, intent(inout) :: field1(klon)
        integer :: tmp1(klon)
        integer :: jl

        do jl=1,klon
            tmp1(jl) = 0
            field1(jl) = tmp1(jl)
        end do

    end subroutine kernel_nested_vector
    subroutine kernel_nested_seq(val)
        use compute_1_mod, only: compute_1
        implicit none
        integer, intent(inout) :: val

        val = 0
        call compute_1(val)

    end subroutine kernel_nested_seq
end module kernel_nested_mod
    """.strip(),
        'compute_1_mod.F90': """
module compute_1_mod
    implicit none
contains
    subroutine compute_1(val)
        implicit none
        integer, intent(inout) :: val

        val = 0

    end subroutine compute_1
end module compute_1_mod
    """.strip(),
        'compute_2_mod.F90': """
module compute_2_mod
    implicit none
contains
    subroutine compute_2(val)
        implicit none
        integer, intent(inout) :: val

        val = 0

    end subroutine compute_2
    subroutine compute_2_1(val)
        implicit none
        integer, intent(inout) :: val

        val = 0

    end subroutine compute_2_1
end module compute_2_mod
    """.strip(),
        'compute_3_mod.F90': """
module compute_3_mod
    implicit none
contains
    subroutine compute_3(val)
        implicit none
        integer, intent(inout) :: val

        val = 0

    end subroutine compute_3
end module compute_3_mod
    """.strip(),
    }
    for file_name, source in fortran_files.items():
        (tmp_path/file_name).write_text(source)


@pytest.mark.usefixtures('fcode_as_module_extended')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('suffix,module_suffix', (
    ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2')
))
@pytest.mark.parametrize('full_parse', (True, False))
@pytest.mark.parametrize('duplicate_subgraph', (True, False))
def test_dependency_duplicate_subgraph(tmp_path, frontend, suffix, module_suffix, config,
                                       full_parse, duplicate_subgraph):
    """
    Test :any:`DuplicateKernel` on ``kernel``, with and without duplicating its dependency subgraph.

    Validates the scheduler graph and the written CMake plan and, when the
    sources are fully parsed, the imports and calls of the original and the
    duplicated routines. ``compute_2_1`` and ``compute_3`` are configured as
    ignored for ``kernel``, so they are never duplicated.
    """

    config['routines']['kernel'] = {'role': 'kernel', 'ignore': ['compute_2_1', 'compute_3']}
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path], full_parse=full_parse
    )

    pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation),
                        duplicate_kernels=('kernel',), duplicate_suffix=suffix,
                        duplicate_module_suffix=module_suffix,
                        duplicate_subgraph=duplicate_subgraph)

    # Without an explicit module suffix, the routine suffix is used for modules too
    module_suffix = module_suffix or suffix

    plan_file = tmp_path/'plan.cmake'

    # dry-run for planning
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    expected_items = {'#driver', 'kernel_mod#kernel', 'kernel_nested_mod#kernel_nested_vector',
            'kernel_nested_mod#kernel_nested_seq', 'compute_1_mod#compute_1', 'compute_2_mod#compute_2',
            'compute_2_mod#compute_2_1', 'compute_3_mod#compute_3'
    }
    standard_items = {'iso_fortran_env'}
    expected_items |= {f'kernel_mod{module_suffix}#kernel{suffix}'}
    if duplicate_subgraph:
        expected_items |= {f'kernel_nested_mod{module_suffix}#kernel_nested_vector{suffix}',
                f'kernel_nested_mod{module_suffix}#kernel_nested_seq{suffix}',
                f'compute_1_mod{module_suffix}#compute_1{suffix}',
                f'compute_2_mod{module_suffix}#compute_2{suffix}'
        }
    # Validate Scheduler graph
    assert {item.name for item in scheduler.items} == expected_items | standard_items

    # Validate the plan file content
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    loki_plan = plan_file.read_text()
    plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)}
    plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()}

    # Only the original (non-duplicated) sources are transformed and replaced
    transformed_items = {name.split('#')[0] if name.split('#')[0] else name[1:]
                         for name in expected_items if not name.endswith(suffix)}
    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == transformed_items
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == transformed_items
    appended_items = {name.split('#')[0] if name.split('#')[0] else name[1:] for name in expected_items}
    appended_items = {f'{name}.idem' for name in appended_items}
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == appended_items

    # actual transformation(s) if fully parsed
    if full_parse:
        scheduler.process(pipeline)

        if duplicate_subgraph:
            dupl_kernel_imports = [('iso_fortran_env', ['real64']), (f'kernel_nested_mod{module_suffix}',
                                    [f'kernel_nested_vector{suffix}', f'kernel_nested_seq{suffix}']),
                                   (f'compute_2_mod{module_suffix}', [f'compute_2{suffix}']),
                                   ('compute_2_mod', ['compute_2_1']),
                                   ('compute_3_mod', ['compute_3'])]
            dupl_kernel_calls = [f'kernel_nested_vector{suffix}', f'kernel_nested_seq{suffix}', f'compute_2{suffix}',
                    'compute_2_1', 'compute_3']
        else:
            dupl_kernel_imports = [('iso_fortran_env', ['real64']),
                                   ('kernel_nested_mod', ['kernel_nested_vector', 'kernel_nested_seq']),
                                   ('compute_2_mod', ['compute_2', 'compute_2_1']),
                                   ('compute_3_mod', ['compute_3'])]
            dupl_kernel_calls = ['kernel_nested_vector', 'kernel_nested_seq', 'compute_2', 'compute_2_1', 'compute_3']

        expected_imports = {
                '#driver': [(f'kernel_mod{module_suffix}', [f'kernel{suffix}']), ('kernel_mod', ['kernel'])],
                'kernel_mod#kernel': [('iso_fortran_env', ['real64']),
                                      ('kernel_nested_mod', ['kernel_nested_vector', 'kernel_nested_seq']),
                                      ('compute_2_mod', ['compute_2', 'compute_2_1']),
                                      ('compute_3_mod', ['compute_3'])],
                'kernel_nested_mod#kernel_nested_vector': [],
                'kernel_nested_mod#kernel_nested_seq': [('compute_1_mod', ['compute_1'])],
                'compute_1_mod#compute_1': [],
                'compute_2_mod#compute_2': [],
                f'kernel_mod{module_suffix}#kernel{suffix}': dupl_kernel_imports
        }
        expected_calls = {
                '#driver': ['kernel', f'kernel{suffix}'],
                'kernel_mod#kernel': ['kernel_nested_vector', 'kernel_nested_seq', 'compute_2',
                    'compute_2_1', 'compute_3'],
                'kernel_nested_mod#kernel_nested_vector': [],
                'kernel_nested_mod#kernel_nested_seq': ['compute_1'],
                'compute_1_mod#compute_1': [],
                'compute_2_mod#compute_2': [],
                f'kernel_mod{module_suffix}#kernel{suffix}': dupl_kernel_calls
        }

        if duplicate_subgraph:
            expected_imports |= {
                    f'kernel_nested_mod{module_suffix}#kernel_nested_vector{suffix}': [],
                    f'kernel_nested_mod{module_suffix}#kernel_nested_seq{suffix}': [(f'compute_1_mod{module_suffix}',
                                                                                     [f'compute_1{suffix}'])],
                    f'compute_1_mod{module_suffix}#compute_1{suffix}': [],
                    f'compute_2_mod{module_suffix}#compute_2{suffix}': []
            }
            expected_calls |= {
                    f'kernel_nested_mod{module_suffix}#kernel_nested_vector{suffix}': [],
                    f'kernel_nested_mod{module_suffix}#kernel_nested_seq{suffix}':  [f'compute_1{suffix}'],
                    f'compute_1_mod{module_suffix}#compute_1{suffix}': [],
                    f'compute_2_mod{module_suffix}#compute_2{suffix}': []
            }

        # Compare the actual calls/imports of every routine against the expectations
        for item_name, ref_calls in expected_calls.items():
            routine = scheduler[item_name].ir
            calls = [str(call.name).lower() for call in FindNodes(ir.CallStatement).visit(routine.body)]
            imports = [(imp.module.lower(), [symb.name.lower() for symb in imp.symbols]) for imp in routine.imports]
            assert calls == ref_calls
            assert imports == expected_imports[item_name]
loki-ecmwf-0.3.6/loki/transformations/tests/test_cloudsc.py0000664000175000017500000001054415167130205024343 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import os
import io
import resource
from subprocess import CalledProcessError
from pathlib import Path
import pandas as pd
import pytest

from loki.tools import (
    execute, write_env_launch_script, local_loki_setup, local_loki_cleanup
)
from loki.frontend import available_frontends, OMNI, HAVE_FP
from loki.logging import warning

# Skip every test in this module unless the CLOUDSC checkout location is provided
pytestmark = pytest.mark.skipif(
    'CLOUDSC_DIR' not in os.environ, reason='CLOUDSC_DIR not set'
)


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Path of the CLOUDSC checkout, as given by the ``CLOUDSC_DIR`` environment variable"""
    cloudsc_dir = os.environ['CLOUDSC_DIR']
    return Path(cloudsc_dir)


@pytest.fixture(scope='module', name='local_loki_bundle')
def fixture_local_loki_bundle(here):
    """
    Inject the local Loki into the CLOUDSC bundle for the duration of the module.

    Yields the Loki directory returned by ``local_loki_setup`` and hands the
    remaining setup results to ``local_loki_cleanup`` on teardown.
    """
    loki_dir, setup_target, setup_backup = local_loki_setup(here)
    yield loki_dir
    local_loki_cleanup(setup_target, setup_backup)


@pytest.fixture(scope='module', name='bundle_create')
def fixture_bundle_create(here, local_loki_bundle):
    """
    Run ``./cloudsc-bundle create`` in the CLOUDSC checkout (ecbundle fetches the dependencies),
    pointing ``CLOUDSC_BUNDLE_LOKI_DIR`` at the injected local Loki.
    """
    bundle_env = {**os.environ, 'CLOUDSC_BUNDLE_LOKI_DIR': local_loki_bundle}
    create_cmd = ['./cloudsc-bundle', 'create']
    execute(create_cmd, cwd=here, silent=False, env=bundle_env)


@pytest.mark.usefixtures('bundle_create')
@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI needs FParser for parsing headers')] if not HAVE_FP else None
))
def test_cloudsc(here, frontend):
    """
    Build the CLOUDSC bundle with Loki for ``frontend`` and validate the generated binaries.

    Each binary is run and its reported error norms are inspected. Binaries with
    a non-zero ``AbsMaxErr`` are flagged. Double-precision binaries whose
    ``MaxRelErr-%`` is not below ``1e-12`` fail the test, as does any binary
    whose run raises :class:`CalledProcessError`. All other flagged results are
    only reported as warnings.
    """
    build_cmd = [
        './cloudsc-bundle', 'build', '--retry-verbose', '--clean',
        '--with-loki=ON', '--loki-frontend=' + str(frontend), '--without-loki-install',
        '--with-double-precision=ON', '--with-single-precision=ON'
    ]

    if 'CLOUDSC_ARCH' in os.environ:
        build_cmd += [f"--arch={os.environ['CLOUDSC_ARCH']}"]

    execute(build_cmd, cwd=here, silent=False)

    # Raise stack limit
    resource.setrlimit(resource.RLIMIT_STACK, (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    env = os.environ.copy()
    env.update({'OMP_STACKSIZE': '2G', 'NVCOMPILER_ACC_CUDA_HEAPSIZE': '2G'})

    # The build does not seem to create the 'data' dir symlink, so create it here.
    # Skip this if the link already exists (e.g. left over from a previous run),
    # since os.symlink would otherwise raise FileExistsError.
    data_link = here/'build/data'
    if not (data_link.exists() or data_link.is_symlink()):
        os.symlink(here/'data', data_link)

    # Run the produced binaries
    binaries = [('dwarf-cloudsc-loki-c-dp', '2', '16000', '32')]
    for prec in ('dp', 'sp'):
        binaries += [
            (f'dwarf-cloudsc-loki-idem-{prec}', '2', '16000', '32'),
            (f'dwarf-cloudsc-loki-idem-stack-{prec}', '2', '16000', '32'),
            (f'dwarf-cloudsc-loki-scc-{prec}', '1', '16000', '32'),
            (f'dwarf-cloudsc-loki-scc-hoist-{prec}', '1', '16000', '32'),
            (f'dwarf-cloudsc-loki-scc-stack-{prec}', '1', '16000', '32'),
        ]

    failures, warnings = {}, {}
    for binary, *args in binaries:
        # Write a script to source env.sh and launch the binary
        script = write_env_launch_script(here, binary, args)

        # Run the script and verify error norms
        try:
            output = execute([str(script)], cwd=here/'build', capture_output=True, silent=False, env=env)
            results = pd.read_fwf(io.StringIO(output.stdout.decode()), index_col='Variable')
            no_errors = results['AbsMaxErr'].astype('float') == 0
            if not no_errors.all(axis=None):
                only_small_errors = results['MaxRelErr-%'].astype('float') < 1e-12
                # We report only validation failures for double-precision as the single-precision
                # result validation is known to fail due to a lack of suitable reference data
                if binary.endswith('-dp') and not only_small_errors.all(axis=None):
                    failures[binary] = results
                else:
                    warnings[binary] = results
        except CalledProcessError as err:
            failures[binary] = err.stderr.decode()

    if warnings:
        msg = '\n'.join([f'{binary}:\n{results}' for binary, results in warnings.items()])
        warning(msg)

    if failures:
        msg = '\n'.join([f'{binary}:\n{results}' for binary, results in failures.items()])
        pytest.fail(msg)
loki-ecmwf-0.3.6/loki/transformations/tests/test_drhook.py0000664000175000017500000001766715167130205024212 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import shutil
import pytest

from loki.batch import (
    Scheduler, SFilter, ProcedureItem, SchedulerConfig
)
from loki.frontend import available_frontends, OMNI
from loki.ir import FindNodes, CallStatement, Import
from loki.tools import gettempdir
from loki.transformations import DrHookTransformation


@pytest.fixture(scope='module', name='config')
def fixture_config():
    """
    Return the default scheduler configuration as a dict.

    Routines default to the ``kernel`` role, with ``expand`` enabled and
    ``strict`` disabled. ``dr_hook`` and ``abor1`` are listed under
    ``disable``, and ``rick_astley`` is configured as the ``driver``.
    """
    default_config = {
        'default': {
            'role': 'kernel', 'expand': True, 'strict': False, 'disable': ['dr_hook', 'abor1']
        },
        'routines': {
            'rick_astley': {'role': 'driver'},
        }
    }
    return default_config


@pytest.fixture(scope='module', name='srcdir')
def fixture_srcdir():
    """
    Provide an empty ``test_dr_hook`` directory under the temp directory.

    Any stale copy from an earlier run is removed first. The directory is
    deleted again once the module's tests are done.
    """
    src_path = gettempdir()/'test_dr_hook'
    # Start from a clean slate in case a previous run left the directory behind
    if src_path.exists():
        shutil.rmtree(src_path)
    src_path.mkdir()

    yield src_path

    shutil.rmtree(src_path)


@pytest.fixture(scope='module', name='source')
def fixture_source(srcdir):
    """
    Populate ``srcdir`` with the Fortran driver and kernel module used by the DrHook tests.

    Yields ``srcdir`` and removes the two written files afterwards. The kernel
    source deliberately references undeclared symbols (``dave``, ``NULOUT``),
    so the source tree is incomplete.
    """
    driver_src = """
subroutine rick_astley
    use parkind1, only: jprb
    use yomhook, only : lhook, dr_hook
    use rick_rolled, only : never_gonna_give
    implicit none

    real(kind=jprb) :: zhook_handle
    if (lhook) call dr_hook('rick_astley',0,zhook_handle)
    call never_gonna_give()
    if (lhook) call dr_hook('rick_astley',1,zhook_handle)
end subroutine
    """.strip()

    kernel_src = """
module rick_rolled
contains
subroutine never_gonna_give
    use parkind1, only: jprb
    use yomhook, only : lhook, dr_hook
    implicit none

    real(kind=jprb) :: zhook_handle
    if (lhook) call dr_hook('never_gonna_give',0,zhook_handle)

    CALL ABOR1('[SUBROUTINE CALL]')

    print *, 'never gonna let you down'

    if (dave) call abor1('[INLINE CONDITIONAL]')

    call never_gonna_run_around()

    WRITE(NULOUT,*) "[WRITE INTRINSIC]"
    if (.not. dave) WRITE(NULOUT, *) "[WRITE INTRINSIC]"

    if (lhook) call dr_hook('never_gonna_give',1,zhook_handle)

contains

subroutine never_gonna_run_around

    implicit none

    if (lhook) call dr_hook('never_gonna_run_around',0,zhook_handle)

    if (dave) call abor1('[INLINE CONDITIONAL]')
    WRITE(NULOUT,*) "[WRITE INTRINSIC]"
    if (.not. dave) WRITE(NULOUT, *) "[WRITE INTRINSIC]"

    if (lhook) call dr_hook('never_gonna_run_around',1,zhook_handle)

end subroutine never_gonna_run_around

end subroutine
subroutine i_hope_you_havent_let_me_down
    real(kind=jprb) :: zhook_handle
    if (lhook) call dr_hook('i_hope_you_havent_let_me_down',0,zhook_handle)

    if (lhook) call dr_hook('i_hope_you_havent_let_me_down',1,zhook_handle)
end subroutine i_hope_you_havent_let_me_down
end module rick_rolled
    """.strip()

    files = {
        'rick_astley.F90': driver_src,
        'never_gonna_give.F90': kernel_src,
    }
    for filename, fcode in files.items():
        (srcdir/filename).write_text(fcode)

    yield srcdir

    for filename in files:
        (srcdir/filename).unlink()


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'Incomplete source tree impossible with OMNI')]
))
def test_dr_hook_transformation(frontend, config, source, tmp_path):
    """
    Test DrHook transformation for a renamed Subroutine.

    Every procedure must keep exactly two ``dr_hook`` calls, one ``yomhook``
    import and its ``zhook_handle`` variable. Driver labels stay the plain
    routine name, while kernel labels gain the ``_you_up`` suffix.
    """
    scheduler = Scheduler(
        paths=source, config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )
    scheduler.process(transformation=DrHookTransformation(suffix='you_up'))

    for item in SFilter(scheduler.sgraph, item_filter=ProcedureItem):
        routine_nodes = item.ir.ir

        hook_calls = [c for c in FindNodes(CallStatement).visit(routine_nodes) if c.name == 'dr_hook']
        assert len(hook_calls) == 2

        hook_imports = [i for i in FindNodes(Import).visit(routine_nodes) if i.module == 'yomhook']
        assert len(hook_imports) == 1

        assert 'zhook_handle' in item.ir.variables

        # Only drivers and kernels have a prescribed DrHook label
        if item.role not in ('driver', 'kernel'):
            continue
        label = item.local_name.lower()
        if item.role == 'kernel':
            label = f'{label}_you_up'
        assert all(str(c.arguments[0]).lower().strip("'") == label for c in hook_calls)


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'Incomplete source tree impossible with OMNI')]
))
def test_dr_hook_transformation_remove(frontend, config, source, tmp_path):
    """
    Test DrHook transformation in remove mode.

    The driver must keep its two ``dr_hook`` calls, its ``yomhook`` import and
    ``zhook_handle``. Kernels must have no ``dr_hook`` calls or ``yomhook``
    imports left, neither in the routine itself nor in any of its member
    routines, and no ``zhook_handle`` variable.
    """
    scheduler_config = SchedulerConfig.from_dict(config)
    scheduler = Scheduler(paths=source, config=scheduler_config, frontend=frontend, xmods=[tmp_path])
    scheduler.process(transformation=DrHookTransformation(suffix='you_up', remove=True))

    for item in SFilter(scheduler.sgraph, item_filter=ProcedureItem):
        drhook_calls = [
            call for call in FindNodes(CallStatement).visit(item.ir.ir)
            if call.name == 'dr_hook'
        ]
        drhook_imports = [
            imp for imp in FindNodes(Import).visit(item.ir.ir)
            if imp.module == 'yomhook'
        ]
        # Also collect calls and imports from member (contained) routines;
        # inspect the member's own IR rather than the parent's
        for r in item.ir.members:
            drhook_calls += [
                call for call in FindNodes(CallStatement).visit(r.ir)
                if call.name == 'dr_hook'
            ]
            drhook_imports += [
                imp for imp in FindNodes(Import).visit(r.ir)
                if imp.module == 'yomhook'
            ]
        if item.role == 'driver':
            assert len(drhook_calls) == 2
            assert len(drhook_imports) == 1
            assert 'zhook_handle' in item.ir.variables
            assert all(
                str(call.arguments[0]).lower().strip("'") == item.local_name.lower()
                for call in drhook_calls
            )
        elif item.role == 'kernel':
            assert not drhook_calls
            assert not drhook_imports
            assert 'zhook_handle' not in item.ir.variables


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'Incomplete source tree impossible with OMNI')]
))
def test_dr_hook_transformation_rename(frontend, config, source):
    """
    Test DrHook transformation in rename mode.

    With ``kernel_only=False``, the DrHook labels of routines listed in
    ``rename`` are replaced. This applies to the driver ``rick_astley`` as well
    as the member routine ``never_gonna_run_around``. ``never_gonna_give``,
    which is not listed, keeps its original label.
    """
    scheduler_config = SchedulerConfig.from_dict(config)
    scheduler = Scheduler(paths=source, config=scheduler_config, frontend=frontend)
    scheduler.process(
        transformation=DrHookTransformation(rename={
            'rick_astley': 'my_man_dave',
            'never_gonna_run_around': 'see_ya_later',
        }, kernel_only=False)
    )

    for item in SFilter(scheduler.sgraph, item_filter=ProcedureItem):
        drhook_calls = [
            call for call in FindNodes(CallStatement).visit(item.ir.ir)
            if call.name == 'dr_hook'
        ]
        if item.local_name == 'rick_astley':
            assert drhook_calls[0].arguments[0] == 'my_man_dave'
            assert drhook_calls[1].arguments[0] == 'my_man_dave'

        if item.local_name == 'never_gonna_give':
            assert drhook_calls[0].arguments[0] == 'never_gonna_give'
            assert drhook_calls[1].arguments[0] == 'never_gonna_give'

            # The contained routine is renamed independently of its parent
            assert len(item.ir.members) == 1
            inner = item.ir.members[0]
            inner_calls = [
                call for call in FindNodes(CallStatement).visit(inner.ir)
                if call.name == 'dr_hook'
            ]
            assert inner_calls[0].arguments[0] == 'see_ya_later'
            assert inner_calls[1].arguments[0] == 'see_ya_later'
loki-ecmwf-0.3.6/loki/transformations/tests/test_ecwam.py0000664000175000017500000001057415167130205024006 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import os
import resource
from subprocess import CalledProcessError
from pathlib import Path
import pytest

from loki.tools import (
    execute, write_env_launch_script, local_loki_setup, local_loki_cleanup
)
from loki.frontend import HAVE_FP

# Skip every test in this module unless the ECWAM checkout location is provided
pytestmark = pytest.mark.skipif(
    'ECWAM_DIR' not in os.environ, reason='ECWAM_DIR not set'
)

@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Path of the ECWAM checkout, as given by the ``ECWAM_DIR`` environment variable"""
    ecwam_dir = os.environ['ECWAM_DIR']
    return Path(ecwam_dir)


@pytest.fixture(scope='module', name='local_loki_bundle')
def fixture_local_loki_bundle(here):
    """
    Inject the local Loki into the ECWAM bundle for the duration of the module.

    Yields the Loki directory returned by ``local_loki_setup`` and hands the
    remaining setup results to ``local_loki_cleanup`` on teardown.
    """
    loki_dir, setup_target, setup_backup = local_loki_setup(here)
    yield loki_dir
    local_loki_cleanup(setup_target, setup_backup)


@pytest.fixture(scope='module', name='bundle_create')
def fixture_bundle_create(here, local_loki_bundle):
    """
    Run ``ecwam-bundle create`` in the ECWAM checkout (ecbundle fetches the dependencies),
    pointing ``ECWAM_BUNDLE_LOKI_DIR`` at the injected local Loki.
    """
    bundle_env = {**os.environ, 'ECWAM_BUNDLE_LOKI_DIR': local_loki_bundle}
    create_cmd = [
        './package/bundle/ecwam-bundle', 'create',
        '--bundle', 'package/bundle/bundle.yml',
    ]
    execute(create_cmd, cwd=here, silent=False, env=bundle_env)


@pytest.mark.usefixtures('bundle_create')
@pytest.mark.skipif(not HAVE_FP, reason="FP needed for ECWAM parsing")
@pytest.mark.parametrize('mode', ['idem', 'idem-stack', 'scc', 'scc-stack', 'scc-hoist'])
def test_ecwam(here, mode, tmp_path):
    """
    Build the ECWAM bundle with Loki in ``mode``, then run the O48 pre-processing and the model.

    The test fails if any pre-processing step or the model exits with an error.
    It also fails if the last line of the model's ``stdout.log`` is missing or
    does not contain ``Validation PASSED``.
    """
    build_dir = tmp_path/'build'
    build_cmd = [
        './package/bundle/ecwam-bundle', 'build', '--clean',
        '--with-loki', '--without-loki-install', '--loki-mode', mode,
        '--build-dir', str(build_dir)
    ]

    if 'ECWAM_ARCH' in os.environ:
        build_cmd += [f"--arch={os.environ['ECWAM_ARCH']}"]
    else:
        # Build without OpenACC support as this makes problems
        # with older versions of GNU
        build_cmd += ['--cmake=ENABLE_ACC=OFF']

    execute(build_cmd, cwd=here, silent=False)

    # Raise stack limit
    resource.setrlimit(resource.RLIMIT_STACK, (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    env = os.environ.copy()
    env.update({'OMP_STACKSIZE': '2G', 'NVCOMPILER_ACC_CUDA_HEAPSIZE': '2G', 'DEV_ALLOC_SIZE': '2147483648'})

    # create rundir
    rundir = build_dir/'wamrun_48'
    os.mkdir(rundir)

    # Run pre-processing steps
    preprocs = [
        ('ecwam-run-preproc', '--run-dir=wamrun_48', f'--config={here}/source/ecwam/tests/etopo1_oper_an_fc_O48.yml'),
        ('ecwam-run-preset', '--run-dir=wamrun_48')
    ]

    failures = {}
    for preproc, *args in preprocs:
        script = write_env_launch_script(tmp_path, preproc, args)

        # Run the pre-processing step and record its return code on failure
        try:
            execute([str(script)], cwd=build_dir, silent=False, env=env)
        except CalledProcessError as err:
            failures[preproc] = err.returncode

    if failures:
        msg = '\n'.join([f'Non-zero return code {rcode} in {p}' for p, rcode in failures.items()])
        pytest.fail(msg)

    # Run the produced binary
    binary = 'ecwam-run-model'
    args = ('--run-dir=wamrun_48',)

    # Write a script to source env.sh and launch the binary
    script = write_env_launch_script(tmp_path, binary, args)

    def get_logs():
        """Concatenate all ``*.log`` files below the run directory for failure reports"""
        logs = rundir.glob('**/*.log')
        return '\n'.join(
            (
                f'-------------------------------------------------------\n{log}:\n\n'
                + log.read_text()
            )
            for log in logs
        )

    # Run the model binary
    try:
        execute([str(script)], cwd=build_dir, silent=False, env=env)
    except CalledProcessError as err:
        pytest.fail(f'{binary}: Failed with error code: {err.returncode}\n{get_logs()}')

    # The validation result is expected on the last line of the model's stdout log;
    # a missing or empty log is reported like a validation that never ran
    stdout_log = rundir/'logs/model/stdout.log'
    lines = stdout_log.read_text().splitlines() if stdout_log.exists() else []
    last_line = lines[-1] if lines else ''

    failure = None
    if 'Validation FAILED' in last_line:
        failure = 'Validation failed'
    elif 'Validation PASSED' not in last_line:
        failure = 'Validation check never run'

    if failure:
        pytest.fail(f'{binary}: {failure}\n{get_logs()}')
loki-ecmwf-0.3.6/loki/transformations/tests/test_scc_cuf.py0000664000175000017500000004052215167130205024313 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki import Scheduler, Subroutine, Dimension, Module
from loki.expression import symbols as sym
from loki.frontend import available_frontends
from loki.ir import (
    FindNodes, FindVariables, Loop, Assignment, CallStatement,
    Allocation, Deallocation, VariableDeclaration, Import, Pragma
)

from loki.transformations.pragma_model import PragmaModelTransformation
from loki.transformations.parametrise import ParametriseTransformation
from loki.transformations.temporaries.hoist_variables import HoistTemporaryArraysAnalysis
from loki.transformations.single_column import (
    HoistTemporaryArraysDeviceAllocatableTransformation,
    HoistTemporaryArraysPragmaOffloadTransformation, SCCLowLevelCuf
)


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """Horizontal dimension: size ``nlon``, index ``jl``, bounds ``start``/``iend``"""
    horizontal_bounds = ('start', 'iend')
    return Dimension(name='horizontal', size='nlon', index='jl', bounds=horizontal_bounds)


@pytest.fixture(scope='module', name='vertical')
def fixture_vertical():
    """Vertical dimension: size ``nz``, index ``jk``"""
    return Dimension(
        name='vertical',
        size='nz',
        index='jk',
    )


@pytest.fixture(scope='module', name='blocking')
def fixture_blocking():
    """Blocking dimension: size ``nb``, index ``b``"""
    return Dimension(
        name='blocking',
        size='nb',
        index='b',
    )


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory containing this test module"""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(name='config')
def fixture_config():
    """
    Default scheduler configuration dict.

    Every routine defaults to an expanded, non-strict ``kernel`` in ``idem``
    mode, except ``driver``, which is given the ``driver`` role.
    """
    defaults = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': False,  # cudafor import (presumably not resolvable in the test sources)
    }
    routines = {
        'driver': {'role': 'driver'},
    }
    return {'default': defaults, 'routines': routines}


def check_subroutine_driver(routine, blocking, disable=()):
    """
    Check a driver routine after SCC-CUF conversion.

    Verifies that the driver:

    * imports ``cudafor``;
    * declares every array, with each device array (host name plus a ``_d``
      suffix) being ``allocatable`` and ``device``;
    * allocates and deallocates every device array;
    * assigns the result of ``cudaDeviceSynchronize()``;
    * copies data between each device array and its host counterpart
      according to the host array's intent;
    * has ``GRIDDIM`` and ``BLOCKDIM`` variables;
    * launches every call not listed in ``disable`` with a chevron whose first
      two entries are ``GRIDDIM`` and ``BLOCKDIM``, passing the blocking size
      as an argument.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The transformed driver routine
    blocking : :any:`Dimension`
        The blocking dimension whose ``size`` must be passed to each kernel launch
    disable : tuple of str, optional
        Names of calls exempt from the kernel-launch checks
    """
    # use of "use cudafor"
    imports = [_import.module.lower() for _import in FindNodes(Import).visit(routine.spec)]
    assert "cudafor" in imports

    # device arrays: named like their host counterpart plus a "_d" suffix
    arrays = [var for var in routine.variables if isinstance(var, sym.Array)]
    device_arrays = [array for array in arrays if array.name.endswith("_d")]
    # Strip exactly the trailing "_d"; str.replace would also drop any inner "_d"
    host_by_name = {array.name: array for array in arrays}
    array_map = {
        device_array: host_by_name[device_array.name[:-2]]
        for device_array in device_arrays if device_array.name[:-2] in host_by_name
    }
    assert arrays
    assert device_arrays

    # device arrays: declaration
    declarations = [
        symbol for decl in FindNodes(VariableDeclaration).visit(routine.spec) for symbol in decl.symbols
    ]
    for array in arrays:
        assert array in declarations
        if array.name.endswith("_d"):
            assert array.type.allocatable
            assert array.type.device

    # device arrays: allocation and deallocation
    allocated = [var.name for alloc in FindNodes(Allocation).visit(routine.body) for var in alloc.variables]
    deallocated = [var.name for dealloc in FindNodes(Deallocation).visit(routine.body) for var in dealloc.variables]
    for device_array in device_arrays:
        assert device_array.name in allocated
        assert device_array.name in deallocated

    # device arrays: copy host to device (intent in/inout) and device to host (intent out/inout)
    assignments = FindNodes(Assignment).visit(routine.body)
    cuda_device_synchronize = sym.InlineCall(
        function=sym.ProcedureSymbol(name="cudaDeviceSynchronize", scope=routine),
        parameters=())
    assert cuda_device_synchronize in [assignment.rhs for assignment in assignments]
    for device_array in device_arrays:
        assert device_array in array_map, f'No host array found for device array {device_array.name}'
        host_array = array_map[device_array]
        device_var = device_array.clone(dimensions=None)
        host_var = host_array.clone(dimensions=None)
        intent = host_array.type.intent
        if intent in ("inout", "in"):
            assert Assignment(lhs=device_var, rhs=host_var) in assignments
        if intent in ("inout", "out"):
            assert Assignment(lhs=host_var, rhs=device_var) in assignments

    # definition of block and griddim
    assert "GRIDDIM" in routine.variables
    assert "BLOCKDIM" in routine.variables

    # kernel launch configuration
    calls = [call for call in FindNodes(CallStatement).visit(routine.body) if str(call.name) not in disable]
    for call in calls:
        assert call.chevron[0] == "GRIDDIM"
        assert call.chevron[1] == "BLOCKDIM"
        assert blocking.size in call.arguments


def _check_subroutine_kernel(routine, horizontal, vertical, blocking):
    """
    Checks shared by CUDA Fortran kernel and device routines.

    The routine must import ``cudafor`` and take the blocking size as an
    argument. It must contain no loop over the horizontal index but at least
    one loop over the vertical index. Every array argument must have the
    blocking index or size among the variables of its dimensions.
    """
    module_names = [imp.module for imp in FindNodes(Import).visit(routine.spec)]
    assert "cudafor" in module_names

    assert blocking.size in routine.arguments

    loop_vars = [loop.variable for loop in FindNodes(Loop).visit(routine.body)]
    assert horizontal.index not in loop_vars
    assert vertical.index in loop_vars

    for arg in routine.arguments:
        if not isinstance(arg, sym.Array):
            continue
        dim_vars = FindVariables().visit(arg.dimensions)
        assert blocking.index in dim_vars or blocking.size in dim_vars

def check_subroutine_kernel(routine, horizontal, vertical, blocking):
    """
    Check a routine transformed into a CUF kernel (``ATTRIBUTES(GLOBAL)``).

    In addition to the common kernel checks, the body must contain assignments
    whose right-hand sides are ``THREADIDX%X`` and ``BLOCKIDX%Z``.
    """
    _check_subroutine_kernel(routine=routine, horizontal=horizontal, vertical=vertical, blocking=blocking)
    assert "ATTRIBUTES(GLOBAL)" in routine.prefix

    # Collect every assignment right-hand side once, then look for the CUDA indices
    rhs_values = [assign.rhs for assign in FindNodes(Assignment).visit(routine.body)]
    assert "THREADIDX%X" in rhs_values
    assert "BLOCKIDX%Z" in rhs_values


def check_subroutine_device(routine, horizontal, vertical, blocking):
    """
    Check a routine transformed into a CUF device routine (``ATTRIBUTES(DEVICE)``).

    In addition to the common kernel checks, the horizontal index and the
    blocking index must both be dummy arguments of the routine.
    """
    _check_subroutine_kernel(routine=routine, horizontal=horizontal, vertical=vertical, blocking=blocking)
    assert "ATTRIBUTES(DEVICE)" in routine.prefix
    for index in (horizontal.index, blocking.index):
        assert index in routine.arguments


def check_subroutine_elemental_device(routine):
    """
    Check that a routine (presumably ``ELEMENTAL`` in the original source) is
    now a device routine and no longer carries the ``ELEMENTAL`` prefix.
    """
    prefix = routine.prefix
    assert "ATTRIBUTES(DEVICE)" in prefix
    assert "ELEMENTAL" not in prefix


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_cuf_simple(frontend, horizontal, vertical, blocking, tmp_path):
    """
    Test the SCC-CUF transformation (no explicit ``transformation_type``) on an
    inline driver/kernel pair and check the resulting driver and CUF kernel.
    """

    # NOTE(review): the driver below slices q/t/z with the loop variable ``b``
    # (which steps by nlon) rather than the block index ``ibl``, and passes a
    # slice of ``z(nlon,nz+1,nb)`` to a kernel dummy declared ``z(nlon,nz)``.
    # Presumably harmless because only the transformed IR is inspected, never
    # compiled or run here; confirm before reusing this source elsewhere.
    fcode_driver = """
  SUBROUTINE driver(nlon, nz, nb, tot, q, t, z)
  use kernel_mod, only: kernel
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    INTEGER, INTENT(IN)   :: tot
    REAL, INTENT(INOUT)   :: t(nlon,nz,nb)
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    REAL, INTENT(INOUT)   :: z(nlon,nz+1,nb)
    INTEGER :: b, start, iend, ibl, icend

    start = 1
    iend = tot
    do b=1,iend,nlon
      ibl = (b-1)/nlon+1
      icend = MIN(nlon,tot-b+1)
      call kernel(start, icend, nlon, nz, q(:,:,b), t(:,:,b), z(:,:,b))
    end do
  END SUBROUTINE driver
"""

    fcode_kernel = """
module kernel_mod
implicit none
contains
  SUBROUTINE kernel(start, iend, nlon, nz, q, t, z)
    implicit none
    INTEGER, INTENT(IN) :: start, iend  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t(nlon,nz)
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL, INTENT(INOUT) :: z(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, iend
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    DO jk = 2, nz
      DO jl = start, iend
        z(jl, jk) = 0.0
      END DO
    END DO

  END SUBROUTINE kernel
end module kernel_mod
"""
    # Parse the kernel module first so the driver can resolve ``kernel`` via ``definitions``
    kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path])
    kernel = kernel_mod['kernel']

    # The vertical size is passed through as an extra dimension variable (as keyword arguments)
    cuf_transform = SCCLowLevelCuf(
        horizontal=horizontal, vertical=vertical, block_dim=blocking,
        dim_vars=(vertical.size,), as_kwarguments=True, remove_vector_section=True
    )

    # Transform both sides of the call; the driver only targets ``kernel``
    cuf_transform.apply(driver, role='driver', targets=['kernel'])
    cuf_transform.apply(kernel, role='kernel')

    check_subroutine_driver(routine=driver, blocking=blocking)
    check_subroutine_kernel(routine=kernel, horizontal=horizontal, vertical=vertical, blocking=blocking)


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_cuf_parametrise(here, frontend, config, horizontal, vertical, blocking, tmp_path):
    """
    Test SCC-CUF transformation type 0 (``'parametrise'``), i.e. followed by a
    :any:`ParametriseTransformation` that turns array dimension(s) into parameters.
    """
    scheduler = Scheduler(
        paths=[here / 'sources/projSccCuf/module'], config=config,
        seed_routines=['driver'], frontend=frontend, xmods=[tmp_path]
    )

    scheduler.process(transformation=SCCLowLevelCuf(
        horizontal=horizontal, vertical=vertical, block_dim=blocking,
        transformation_type='parametrise',
        dim_vars=(vertical.size,), as_kwarguments=True, remove_vector_section=True
    ))

    dic2p = {'nz': 137}
    scheduler.process(transformation=ParametriseTransformation(dic2p=dic2p))
    scheduler.process(PragmaModelTransformation(directive='openacc'))

    driver = scheduler["driver_mod#driver"].ir
    kernel = scheduler["kernel_mod#kernel"].ir
    device = scheduler["kernel_mod#device"].ir

    # Structural checks of the CUF-transformed routines
    check_subroutine_driver(routine=driver, blocking=blocking)
    check_subroutine_kernel(routine=kernel, horizontal=horizontal, vertical=vertical, blocking=blocking)
    check_subroutine_device(routine=device, horizontal=horizontal, vertical=vertical, blocking=blocking)
    check_subroutine_elemental_device(routine=scheduler["kernel_mod#elemental_device"].ir)

    # Exactly the names in dic2p must have become parameters in each routine
    expected_parameters = list(dic2p.keys())
    for routine in (driver, kernel, device):
        parameters = [var for var in routine.variables if var.type.parameter]
        assert parameters == expected_parameters

    # Local (non-argument) arrays of kernel and device routines carry the device attribute
    for routine in (kernel, device):
        argument_arrays = [arg for arg in routine.arguments if isinstance(arg, sym.Array)]
        for var in routine.variables:
            if isinstance(var, sym.Array) and var not in argument_arrays:
                assert var.type.device


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_synthesis', (
    HoistTemporaryArraysDeviceAllocatableTransformation(as_kwarguments=True),
    HoistTemporaryArraysPragmaOffloadTransformation(as_kwarguments=True))
)
def test_scc_cuf_hoist(here, frontend, config, horizontal, vertical, blocking, hoist_synthesis, tmp_path):
    """
    Test SCC-CUF transformation type 1 (``'hoist'``), i.e. with the temporary
    arrays of kernel and device routine hoisted to the host-side driver.
    """
    scheduler = Scheduler(
        paths=[here / 'sources/projSccCuf/module'], config=config,
        seed_routines=['driver'], frontend=frontend, xmods=[tmp_path]
    )

    scheduler.process(transformation=SCCLowLevelCuf(
        horizontal=horizontal, vertical=vertical, block_dim=blocking,
        transformation_type='hoist',
        dim_vars=(vertical.size,), as_kwarguments=True, remove_vector_section=True
    ))
    # Hoisting runs as an analysis pass followed by the parametrised synthesis pass
    scheduler.process(transformation=HoistTemporaryArraysAnalysis())
    scheduler.process(transformation=hoist_synthesis)
    scheduler.process(PragmaModelTransformation(directive='openacc'))

    driver = scheduler["driver_mod#driver"].ir
    kernel = scheduler["kernel_mod#kernel"].ir
    device = scheduler["kernel_mod#device"].ir

    check_subroutine_driver(routine=driver, blocking=blocking)
    check_subroutine_kernel(routine=kernel, horizontal=horizontal, vertical=vertical, blocking=blocking)
    check_subroutine_device(routine=device, horizontal=horizontal, vertical=vertical, blocking=blocking)
    check_subroutine_elemental_device(routine=scheduler["kernel_mod#elemental_device"].ir)

    # Driver: the hoisted temporaries are declared with the full (nlon, nz, nb) shape
    hoisted = ('kernel_local_z', 'device_local_x')
    for name in hoisted:
        assert name in driver.variable_map

    if isinstance(hoist_synthesis, HoistTemporaryArraysDeviceAllocatableTransformation):
        # Declared with the device attribute
        for name in hoisted:
            assert driver.variable_map[name].type.device
        for name in hoisted:
            assert driver.variable_map[name].shape == ('nlon', 'nz', 'nb')
    elif isinstance(hoist_synthesis, HoistTemporaryArraysPragmaOffloadTransformation):
        # No device attribute; instead the first two pragmas are acc enter/exit data clauses
        for name in hoisted:
            assert driver.variable_map[name].type.device is None
        for name in hoisted:
            assert driver.variable_map[name].shape == ('nlon', 'nz', 'nb')
        pragmas = FindNodes(Pragma).visit(driver.body)
        for position, clause in enumerate(('enter data create', 'exit data delete')):
            pragma = pragmas[position]
            assert pragma.keyword == 'acc'
            content = pragma.content.lower()
            assert clause in content
            for name in hoisted:
                assert name in content
    else:
        raise ValueError

    # Every driver call passes both hoisted temporaries (positionally or by keyword)
    for call in FindNodes(CallStatement).visit(driver.body):
        passed = [arg.name.lower() for arg in call.arguments]
        passed += [kwarg[1] for kwarg in call.kwarguments]
        for name in hoisted:
            assert name in passed

    # Kernel: receives its own temporary and the device routine's one
    kernel_args = [arg.name.lower() for arg in kernel.arguments]
    assert 'local_z' in kernel_args
    assert 'device_local_x' in kernel_args
    for call in FindNodes(CallStatement).visit(kernel.body):
        if str(call.name) == "DEVICE":
            assert 'DEVICE_local_x' in [kwarg[1] for kwarg in call.kwarguments]

    # Device: receives its temporary as an argument
    assert 'local_x' in [arg.name for arg in device.arguments]

    # Former local arrays are now inout dummies carrying all three dimensions
    for routine, local_name in ((kernel, 'local_z'), (device, 'local_x')):
        local_array = routine.variable_map[local_name]
        assert local_array.type.intent == 'inout'
        dims = FindVariables().visit(local_array.dimensions)
        for size in (horizontal.size, vertical.size, blocking.size):
            assert size in dims
loki-ecmwf-0.3.6/loki/transformations/tests/test_split_read_write.py0000664000175000017500000001346315167130205026252 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Dimension, Subroutine
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, pragma_regions_attached,
    is_loki_pragma
)
from loki.frontend import available_frontends
from loki.transformations import SplitReadWriteTransformation


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """Horizontal dimension: size ``nlon``, index ``jl``, bounds ``start``/``end``, alias ``nproma``."""
    horizontal = Dimension(
        name='horizontal', size='nlon', index='jl',
        bounds=('start', 'end'), aliases=('nproma',)
    )
    return horizontal

@pytest.fixture(scope='module', name='vertical')
def fixture_vertical():
    """Vertical dimension: size ``nz``, index ``jk``, alias ``nlev``."""
    vertical = Dimension(
        name='vertical', size='nz', index='jk',
        aliases=('nlev',)
    )
    return vertical


@pytest.mark.parametrize('frontend', available_frontends())
def test_split_read_write(frontend, horizontal, vertical):
    """
    Test pragma-assisted splitting of reads and writes.

    Each ``!$loki split-read-write`` region is expected to be duplicated: the
    first copy writes into ``loki_temp_*`` temporaries, and the second copy
    writes those temporaries back to the original arrays.
    """

    fcode = """
subroutine kernel(nlon, nz, start, end, n1, n2, n3, var0, var1, var2, nfre)
  implicit none

  integer, intent(in) :: nlon, nz, n1, n2, n3, start, end, nfre
  real, intent(inout) :: var0(nlon,nfre,6), var1(nlon, nz, 6), var2(nlon,nz)
  integer :: jl, jk, m

  !$loki split-read-write
  do jk = 1,nz
    do jl = start,end
       var1(jl, jk, n1) = var1(jl, jk, n1) + 1.
       var1(jl, jk, n1) = var1(jl, jk, n1) * 2.
       var1(jl, jk, n2) = var1(jl, jk, n2) + var1(jl, jk, n1)
       var2(jl, jk    ) = 0.
    end do
  end do
  print *, "a leaf node that shouldn't be copied"
  !$loki end split-read-write

  !.....should be transformed to........
  !!$loki split-read-write
  !  do jk=1,nz
  !    do jl=start,end
  !      loki_temp_0(jl, jk) = var1(jl, jk, n1) + 1.
  !      loki_temp_0(jl, jk) = loki_temp_0(jl, jk)*2.
  !      loki_temp_1(jl, jk) = var1(jl, jk, n2) + loki_temp_0(jl, jk)
  !      var2(jl, jk) = 0.
  !    end do
  !  end do
  !  print *, 'a leaf node that shouldn''t be copied'
  !  do jk=1,nz
  !    do jl=start,end
  !      var1(jl, jk, n1) = loki_temp_0(jl, jk)
  !      var1(jl, jk, n2) = loki_temp_1(jl, jk)
  !    end do
  !  end do
  !!$loki end split-read-write

  do m = 1,nfre
  !$loki split-read-write
     if( m < nfre/2 )then
        do jl = start,end
           var0(jl, m, n3) = var0(jl, m, n3) + 1.
        end do
     endif
  !$loki end split-read-write
  !.....should be transformed to........
  !!$loki split-read-write
  !  if (m < nfre / 2) then
  !   do jl=start,end
  !     loki_temp_2(jl) = var0(jl, m, n3) + 1.
  !   end do
  !  end if
  !  if (m < nfre / 2) then
  !    do jl=start,end
  !      var0(jl, m, n3) = loki_temp_2(jl)
  !    end do
  !  end if
  !!$loki end split-read-write
  end do

end subroutine kernel
"""

    # Split reads and writes in every pragma region, with temporaries
    # dimensioned over the horizontal and vertical dimensions
    routine = Subroutine.from_source(fcode, frontend=frontend)
    SplitReadWriteTransformation(dimensions=(horizontal, vertical)).apply(routine)

    # Attach pragma regions so each split region can be inspected as one node
    with pragma_regions_attached(routine):

        pragma_regions = FindNodes(ir.PragmaRegion).visit(routine.body)
        assert len(pragma_regions) == 2

        #=========== check first pragma region ==============#
        region = pragma_regions[0]
        assert is_loki_pragma(region.pragma, starts_with='split-read-write')

        # check that temporaries were declared
        assert 'loki_temp_0(nlon,nz)' in routine.variables
        assert 'loki_temp_1(nlon,nz)' in routine.variables

        # check correctly nested loops
        outer_loops = [l for l in FindNodes(ir.Loop).visit(region.body) if l.variable == 'jk']
        assert len(outer_loops) == 2
        for loop in outer_loops:
            _loops = FindNodes(ir.Loop).visit(loop.body)
            assert len(_loops) == 1
            assert _loops[0].variable == 'jl'

        # check simple assignment is only in first copy of region
        assert 'var2(jl,jk)' in FindVariables().visit(outer_loops[0])
        assert not 'var2(jl,jk)' in FindVariables().visit(outer_loops[1])

        # check print statement is only present in first copy of region
        assert len(FindNodes(ir.Intrinsic).visit(region)) == 1

        # check correctness of split reads
        # (the first copy keeps all four assignments, now targeting temporaries)
        assigns = FindNodes(ir.Assignment).visit(outer_loops[0].body)
        assert len(assigns) == 4
        assert assigns[0].lhs == assigns[1].lhs
        assert assigns[1].rhs == f'{assigns[0].lhs}*2.'
        assert assigns[2].lhs != assigns[0].lhs
        assert assigns[2].lhs.dimensions == assigns[0].lhs.dimensions
        assert f'{assigns[0].lhs}' in assigns[2].rhs

        # check correctness of split writes
        # (the second copy only writes the temporaries back to var1)
        _assigns = FindNodes(ir.Assignment).visit(outer_loops[1].body)
        assert len(_assigns) == 2
        assert _assigns[0].lhs == 'var1(jl, jk, n1)'
        assert _assigns[1].lhs == 'var1(jl, jk, n2)'
        assert _assigns[0].rhs == assigns[0].lhs
        assert _assigns[1].rhs == assigns[2].lhs


        #=========== check second pragma region ==============#
        region = pragma_regions[1]
        assert is_loki_pragma(region.pragma, starts_with='split-read-write')

        # the conditional is duplicated: one copy for reads, one for writes
        conds = FindNodes(ir.Conditional).visit(region.body)
        assert len(conds) == 2

        # check that temporaries were declared
        assert 'loki_temp_2(nlon)' in routine.variables

        # check correctness of split reads
        assigns = FindNodes(ir.Assignment).visit(conds[0])
        assert len(assigns) == 1
        assert assigns[0].lhs == 'loki_temp_2(jl)'
        assert 'var0(jl, m, n3)' in assigns[0].rhs

        # check correctness of split writes
        assigns = FindNodes(ir.Assignment).visit(conds[1])
        assert len(assigns) == 1
        assert assigns[0].lhs == 'var0(jl, m, n3)'
        assert assigns[0].rhs == 'loki_temp_2(jl)'
loki-ecmwf-0.3.6/loki/transformations/tests/test_transform_region.py0000664000175000017500000002234315167130205026265 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Subroutine, FindNodes, Loop
from loki.jit_build import jit_compile
from loki.expression import symbols as sym
from loki.ir import Assignment
from loki.frontend import available_frontends

from loki.transformations.transform_region import region_hoist


def loop_variables(node):
    """Return the induction variables of all loops found in ``node``, in visit order."""
    loops = FindNodes(Loop).visit(node)
    return [lp.variable for lp in loops]


def loop_symbols(node):
    """Return ``(variable, start, stop, step)`` for every loop found in ``node``, in visit order."""
    result = []
    for lp in FindNodes(Loop).visit(node):
        bounds = lp.bounds
        result.append((lp.variable, bounds.start, bounds.stop, bounds.step))
    return result


def assignment_symbols(node):
    """Return ``(lhs, rhs)`` pairs for every assignment found in ``node``, in visit order."""
    assignments = FindNodes(Assignment).visit(node)
    return [(assign.lhs, assign.rhs) for assign in assignments]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_region_hoist(frontend):
    """
    A very simple hoisting example
    """
    fcode = """
subroutine transform_region_hoist(a, b, c)
  integer, intent(out) :: a, b, c

  a = 5

!$loki region-hoist target

  a = 1

!$loki region-hoist
  b = a
!$loki end region-hoist

  c = a + b
end subroutine transform_region_hoist
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    # Before hoisting, assignments appear in source order
    assert assignment_symbols(routine.body) == [('a', '5'), ('a', '1'), ('b', 'a'), ('c', 'a + b')]

    # The ``b = a`` region is moved to the ``region-hoist target`` position, i.e. before ``a = 1``
    region_hoist(routine)
    assert assignment_symbols(routine.body) == [('a', '5'), ('b', 'a'), ('a', '1'), ('c', 'a + b')]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_region_hoist_inlined_pragma(tmp_path, frontend):
    """
    Hoisting when pragmas are potentially inlined into other nodes.

    The routine is JIT-compiled and run both before and after hoisting, and
    both versions must produce the same reference arrays.
    """
    fcode = """
subroutine transform_region_hoist_inlined_pragma(a, b, klon, klev)
  integer, intent(inout) :: a(klon, klev), b(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl

!$loki region-hoist target

  do jl=1,klon
    a(jl, 1) = jl
  end do

  do jk=2,klev
    do jl=1,klon
      a(jl, jk) = a(jl, jk-1)
    end do
  end do

!$loki region-hoist
  do jk=1,klev
    b(1, jk) = jk
  end do
!$loki end region-hoist

  do jk=1,klev
    do jl=2,klon
      b(jl, jk) = b(jl-1, jk)
    end do
  end do
end subroutine transform_region_hoist_inlined_pragma
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    # Expected results (1-based Fortran indices): a(jl, jk) == jl and b(jl, jk) == jk
    klon, klev = 32, 100
    ref_a = np.array([[jl + 1] * klev for jl in range(klon)], order='F')
    ref_b = np.array([[jk + 1 for jk in range(klev)] for _ in range(klon)], order='F')

    # Test the reference solution
    a = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    b = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    function(a=a, b=b, klon=klon, klev=klev)
    assert np.all(a == ref_a)
    assert np.all(b == ref_b)

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 6
    assert loop_variables(routine.body) == ['jl', 'jk', 'jl', 'jk', 'jk', 'jl']

    # Apply transformation
    region_hoist(routine)

    # The ``jk`` loop initialising b(1, jk) now comes first, ahead of the loops over ``a``
    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 6
    assert loop_variables(routine.body) == ['jk', 'jl', 'jk', 'jl', 'jk', 'jl']

    hoisted_filepath = tmp_path/(f'{routine.name}_hoisted_{frontend}.f90')
    hoisted_function = jit_compile(routine, filepath=hoisted_filepath, objname=routine.name)

    # Test transformation: the hoisted routine must reproduce the reference results
    klon, klev = 32, 100
    a = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    b = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    hoisted_function(a=a, b=b, klon=klon, klev=klev)
    assert np.all(a == ref_a)
    assert np.all(b == ref_b)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_region_hoist_multiple(frontend):
    """
    Test hoisting with multiple groups and multiple regions per group
    """
    fcode = """
subroutine transform_region_hoist_multiple(a, b, c)
  integer, intent(out) :: a, b, c

  a = 1

!$loki region-hoist target
!$loki region-hoist target group(some-group)

  a = a + 1
  a = a + 1
!$loki region-hoist group(some-group)
  a = a + 1
!$loki end region-hoist
  a = a + 1

!$loki region-hoist
  b = a
!$loki end region-hoist

!$loki region-hoist group(some-group)
  c = a + b
!$loki end region-hoist
end subroutine transform_region_hoist_multiple
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert assignment_symbols(routine.body) == [
        ('a', '1'), ('a', 'a + 1'), ('a', 'a + 1'), ('a', 'a + 1'),
        ('a', 'a + 1'), ('b', 'a'), ('c', 'a + b')
    ]

    # The ungrouped region (``b = a``) moves to the ungrouped target, and both
    # ``some-group`` regions (one ``a = a + 1`` and ``c = a + b``) move, in
    # order, to the ``some-group`` target that directly follows it
    region_hoist(routine)
    assert assignment_symbols(routine.body) == [
        ('a', '1'), ('b', 'a'), ('a', 'a + 1'), ('c', 'a + b'),
        ('a', 'a + 1'), ('a', 'a + 1'), ('a', 'a + 1')
    ]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_region_hoist_collapse(frontend):
    """
    Use collapse with region-hoist.

    ``collapse(1)`` hoists the region together with its enclosing ``jk`` loop.
    """
    fcode = """
subroutine transform_region_hoist_collapse(a, b, klon, klev)
  integer, intent(inout) :: a(klon, klev), b(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl

!$loki region-hoist target

  do jl=1,klon
    a(jl, 1) = jl
  end do

  do jk=2,klev
    do jl=1,klon
      a(jl, jk) = a(jl, jk-1)
    end do
  end do

  do jk=1,klev
!$loki region-hoist collapse(1)
    b(1, jk) = jk
!$loki end region-hoist
  end do

  do jk=1,klev
    do jl=2,klon
      b(jl, jk) = b(jl-1, jk)
    end do
  end do
end subroutine transform_region_hoist_collapse
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 6
    assert loop_variables(routine.body) == ['jl', 'jk', 'jl', 'jk', 'jk', 'jl']
    assert assignment_symbols(routine.body) == [
        ('a(jl, 1)', 'jl'), ('a(jl, jk)', 'a(jl, jk - 1)'),
        ('b(1, jk)', 'jk'), ('b(jl, jk)', 'b(jl - 1, jk)')
    ]

    region_hoist(routine)

    # A copy of the enclosing ``jk`` loop is created at the target, so the loop
    # count grows by one; the loop around the original region site remains
    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 7
    assert loop_variables(routine.body) == ['jk', 'jl', 'jk', 'jl', 'jk', 'jk', 'jl']
    assert loop_symbols(routine.body) == [
        ('jk', 1, 'klev', None), ('jl', 1, 'klon', None),
        ('jk', 2, 'klev', None), ('jl', 1, 'klon', None),
        ('jk', 1, 'klev', None), ('jk', 1, 'klev', None),
        ('jl', 2, 'klon', None)
    ]
    assert assignment_symbols(routine.body) == [
        ('b(1, jk)', 'jk'), ('a(jl, 1)', 'jl'), ('a(jl, jk)', 'a(jl, jk - 1)'), ('b(jl, jk)', 'b(jl - 1, jk)')
    ]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_region_hoist_promote(tmp_path, frontend):
    """
    Use ``collapse`` together with ``promote`` in region-hoist: the scalar
    ``b_tmp`` written inside the hoisted region must be promoted to an array
    of extent ``klev``, and the routine must still produce the same results.
    """
    fcode = """
subroutine transform_region_hoist_promote(a, b, klon, klev)
  integer, intent(inout) :: a(klon, klev), b(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl, b_tmp

!$loki region-hoist target

  do jl=1,klon
    a(jl, 1) = jl
  end do

  do jk=2,klev
    do jl=1,klon
      a(jl, jk) = a(jl, jk-1)
    end do
  end do

  do jk=1,4
    b(1, jk) = jk
  end do

  do jk=5,klev
!$loki region-hoist collapse(1) promote(b_tmp)
    b_tmp = jk + 1
!$loki end region-hoist
    b(1, jk) = b_tmp - 1
  end do

  do jk=1,klev
    do jl=2,klon
      b(jl, jk) = b(jl-1, jk)
    end do
  end do
end subroutine transform_region_hoist_promote
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    # Expected results (1-based Fortran indices): a(jl, jk) == jl and b(jl, jk) == jk
    klon, klev = 32, 100
    ref_a = np.array([[jl + 1] * klev for jl in range(klon)], order='F')
    ref_b = np.array([[jk + 1 for jk in range(klev)] for _ in range(klon)], order='F')

    # Test the reference solution
    a = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    b = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    function(a=a, b=b, klon=klon, klev=klev)
    assert np.all(a == ref_a)
    assert np.all(b == ref_b)

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 7
    assert loop_variables(routine.body) == ['jl', 'jk', 'jl', 'jk', 'jk', 'jk', 'jl']

    assert isinstance(routine.variable_map['b_tmp'], sym.Scalar)

    # Apply transformation
    region_hoist(routine)

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 8
    assert loop_variables(routine.body) == ['jk', 'jl', 'jk', 'jl', 'jk', 'jk', 'jk', 'jl']

    # ``promote(b_tmp)`` turned the scalar into a rank-1 array of extent klev
    b_tmp = routine.variable_map['b_tmp']
    assert isinstance(b_tmp, sym.Array) and len(b_tmp.type.shape) == 1
    assert b_tmp.type.shape[0] == 'klev'

    hoisted_filepath = tmp_path/(f'{routine.name}_hoisted_{frontend}.f90')
    hoisted_function = jit_compile(routine, filepath=hoisted_filepath, objname=routine.name)

    # Test transformation: the hoisted routine must reproduce the reference results
    klon, klev = 32, 100
    a = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    b = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
    hoisted_function(a=a, b=b, klon=klon, klev=klev)
    assert np.all(a == ref_a)
    assert np.all(b == ref_b)
loki-ecmwf-0.3.6/loki/transformations/tests/test_parametrise.py0000664000175000017500000004454515167130205025233 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
A selection of tests for the parametrisation functionality.
"""
from pathlib import Path
import pytest
import numpy as np

from loki import Scheduler, fgen, Subroutine
from loki.jit_build import jit_compile
from loki.expression import symbols as sym, parse_expr
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes

from loki.transformations.parametrise import ParametriseTransformation, declare_fixed_value_scalars_as_constants


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory containing this test module."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """The ``tests`` directory located two levels above this module's directory."""
    grandparent = here.parent.parent
    return grandparent / 'tests'


@pytest.fixture(name='config')
def fixture_config():
    """
    Default scheduler configuration with basic options.

    Every routine defaults to an expanded, strict ``kernel`` in mode ``idem``;
    ``driver`` and ``another_driver`` are overridden to the ``driver`` role.
    """
    defaults = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': True,
    }
    # A fresh dict per driver, so the entries can be modified independently
    routines = {
        name: {'role': 'driver', 'expand': True}
        for name in ('driver', 'another_driver')
    }
    return {'default': defaults, 'routines': routines}


def compile_and_test(scheduler, tmp_path, a=5, b=1):
    """
    Compile the source code and call the driver function in order to test the results for correctness.

    Every scheduler item whose name contains ``driver`` is run. Each source
    file is JIT-compiled only once, even when it holds several drivers.
    ``driver`` must fill ``c`` with 11 and ``d`` with 42, and ``another_driver``
    must fill ``c`` with 11.
    """
    # Driver item -> stem of the file that defines it
    drivers = {}
    for item in scheduler.items:
        if 'driver' in item.name:
            drivers[item] = item.source.path.stem

    # Stem -> source object, so each file is compiled exactly once
    sources = {item.source.path.stem: item.source for item in drivers}
    modules = {}
    for stem, source in sources.items():
        modules[stem] = jit_compile(source, filepath=tmp_path/f'{stem}.F90', objname=stem)

    # Run and validate each driver on freshly zeroed Fortran-ordered arrays
    for item, stem in drivers.items():
        c = np.zeros((a, b), dtype=np.int32, order='F')
        d = np.zeros((b,), dtype=np.int32, order='F')
        name = item.local_name
        if name == 'driver':
            modules[stem].driver(a, b, c, d)
            assert (c == 11).all()
            assert (d == 42).all()
        elif name == 'another_driver':
            modules[stem].another_driver(a, b, c)
            assert (c == 11).all()
        else:
            assert False, f'Unknown driver name {item.local_name}'


def check_arguments_and_parameter(scheduler, subroutine_arguments, call_arguments, parameter_variables):
    """
    Check the parameters, subroutine and call arguments of each subroutine.

    For every routine of the ``parametrise`` project this asserts that its
    ``PARAMETER`` variables equal ``parameter_variables[name]`` and its dummy
    argument names equal ``subroutine_arguments[name]``. For selected callers,
    the arguments of calls to selected callees are compared against
    ``call_arguments[callee]``.
    """
    # Routine -> callees whose call-site arguments are checked. For each call,
    # only the first callee whose name occurs in the call name is used.
    routines_and_callees = (
        ('driver', ('kernel1', 'kernel2')),
        ('another_driver', ('kernel1',)),
        ('kernel1', ()),
        ('kernel2', ('device1', 'device2')),
        ('device1', ('device2',)),
        ('device2', ()),
    )
    for name, callees in routines_and_callees:
        routine = scheduler[f'parametrise#{name}'].ir
        parameters = [var for var in routine.variables if var.type.parameter]
        assert parameters == parameter_variables[name]
        assert [arg.name for arg in routine.arguments] == subroutine_arguments[name]
        if not callees:
            continue
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            matched = next((callee for callee in callees if callee in call.name), None)
            if matched is not None:
                assert call.arguments == call_arguments[matched]


@pytest.mark.parametrize('frontend', available_frontends())
def test_parametrise_source(tmp_path, testdir, frontend, config):
    """
    Test the untransformed ``projParametrise`` sources: check argument lists
    and call signatures, then compile and run the drivers.
    """
    proj = testdir/'sources/projParametrise'

    # Values of a and b passed to the compiled drivers
    dic2p = {'a': 12, 'b': 11}
    a, b = dic2p['a'], dic2p['b']

    scheduler = Scheduler(
        paths=[proj], config=config, seed_routines=['driver', 'another_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # Expected dummy argument names of each routine in the original source
    subroutine_arguments = {
        "driver": ['a', 'b', 'c', 'd'],
        "another_driver": ['a', 'b', 'c'],
        "kernel1": ['a', 'b', 'c'],
        "kernel2": ['a_new', 'b', 'd'],
        "device1": ['a', 'b', 'd', 'x', 'y'],
        "device2": ['a', 'b', 'd', 'x'],
    }

    # Expected actual arguments at each call site, keyed by callee
    call_arguments = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b', 'd'),
        "device1": ('a_new', 'b', 'd', 'x', 'k2_tmp'),
        "device2": ('a', 'b', 'd', 'x')
    }

    # Without any transformation, no routine declares parameters
    parameter_variables = {name: [] for name in subroutine_arguments}

    check_arguments_and_parameter(scheduler=scheduler, subroutine_arguments=subroutine_arguments,
                                  call_arguments=call_arguments, parameter_variables=parameter_variables)

    compile_and_test(scheduler=scheduler, tmp_path=tmp_path, a=a, b=b)


@pytest.mark.parametrize('frontend', available_frontends())
def test_parametrise_simple(tmp_path, testdir, frontend, config):
    """
    Check basic parametrisation of ``a`` and ``b``.

    After the transformation, the drivers take ``parametrised_a`` and
    ``parametrised_b``. The callees drop ``a``/``b`` from their argument lists
    and instead declare them as parameters.
    """
    dic2p = {'a': 12, 'b': 11}
    a, b = dic2p['a'], dic2p['b']

    scheduler = Scheduler(
        paths=[testdir/'sources/projParametrise'], config=config,
        seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
    )
    scheduler.process(transformation=ParametriseTransformation(dic2p=dic2p))

    # Expected dummy arguments of every routine after parametrisation
    expected_routine_args = {
        "driver": ['parametrised_a', 'parametrised_b', 'c', 'd'],
        "another_driver": ['parametrised_a', 'parametrised_b', 'c'],
        "kernel1": ['c'],
        "kernel2": ['d'],
        "device1": ['d', 'x', 'y'],
        "device2": ['d', 'x'],
    }
    # Expected actual arguments of the calls to each callee
    expected_call_args = {
        "kernel1": ('c',),
        "kernel2": ('d',),
        "device1": ('d', 'x', 'k2_tmp'),
        "device2": ('d', 'x'),
    }
    # Parameters declared in every routine
    expected_params = {
        "driver": ['a', 'b'],
        "another_driver": ['a', 'b'],
        "kernel1": ['a', 'b'],
        "kernel2": ['a_new', 'b'],
        "device1": ['a', 'b'],
        "device2": ['a', 'b'],
    }

    check_arguments_and_parameter(
        scheduler=scheduler, subroutine_arguments=expected_routine_args,
        call_arguments=expected_call_args, parameter_variables=expected_params
    )
    compile_and_test(scheduler=scheduler, tmp_path=tmp_path, a=a, b=b)


@pytest.mark.parametrize('frontend', available_frontends())
def test_parametrise_simple_replace_by_value(tmp_path, testdir, frontend, config):
    """
    Check parametrisation with ``replace_by_value=True``.

    No parameters are declared. Instead, ``a`` and ``b`` are substituted by
    their literal values, which shows up in the array shapes of every routine
    spec.
    """
    dic2p = {'a': 12, 'b': 11}
    a, b = dic2p['a'], dic2p['b']

    scheduler = Scheduler(
        paths=[testdir/'sources/projParametrise'], config=config,
        seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
    )
    scheduler.process(transformation=ParametriseTransformation(dic2p=dic2p, replace_by_value=True))

    # Expected dummy arguments of every routine after parametrisation
    expected_routine_args = {
        "driver": ['parametrised_a', 'parametrised_b', 'c', 'd'],
        "another_driver": ['parametrised_a', 'parametrised_b', 'c'],
        "kernel1": ['c'],
        "kernel2": ['d'],
        "device1": ['d', 'x', 'y'],
        "device2": ['d', 'x'],
    }
    # Expected actual arguments of the calls to each callee
    expected_call_args = {
        "kernel1": ('c',),
        "kernel2": ('d',),
        "device1": ('d', 'x', 'k2_tmp'),
        "device2": ('d', 'x'),
    }
    # Values are inlined, so no routine declares parameters
    expected_params = {routine_name: [] for routine_name in expected_routine_args}

    check_arguments_and_parameter(
        scheduler=scheduler, subroutine_arguments=expected_routine_args,
        call_arguments=expected_call_args, parameter_variables=expected_params
    )

    # Declarations that must appear in the generated spec of each item
    expected_spec_snippets = {
        'parametrise#driver': (f'c({a}, {b})', f'd({b})'),
        'parametrise#another_driver': (f'c({a}, {b})', f'x({a})'),
        'parametrise#kernel1': (f'c({a}, {b})', f'x({a})', f'y({a}, {b})', f'k1_tmp({a}, {b})'),
        'parametrise#kernel2': (f'd({b})', f'x({a})', f'k2_tmp({a}, {a})'),
        'parametrise#device1': (f'd({b})', f'x({a})', f'y({a}, {a})'),
        'parametrise#device2': (f'd({b})', f'x({a})', f'z({b})', f'd2_tmp({b})'),
    }
    for item_name, snippets in expected_spec_snippets.items():
        spec_str = fgen(scheduler[item_name].ir.spec)
        for snippet in snippets:
            assert snippet in spec_str

    compile_and_test(scheduler=scheduler, tmp_path=tmp_path, a=a, b=b)


@pytest.mark.parametrize('frontend', available_frontends())
def test_parametrise_modified_callback(tmp_path, testdir, frontend, config):
    """
    Check parametrisation with custom ``abort_callback`` functions.

    These functions provide the IR nodes to execute when a sanity check fails.
    Each callback gets its own fresh scheduler run. The expected arguments and
    parameters are those of a plain parametrisation of ``a`` and ``b``.
    """
    dic2p = {'a': 12, 'b': 11}
    a, b = dic2p['a'], dic2p['b']

    # Expected dummy arguments of every routine after parametrisation
    expected_routine_args = {
        "driver": ['parametrised_a', 'parametrised_b', 'c', 'd'],
        "another_driver": ['parametrised_a', 'parametrised_b', 'c'],
        "kernel1": ['c'],
        "kernel2": ['d'],
        "device1": ['d', 'x', 'y'],
        "device2": ['d', 'x'],
    }
    # Expected actual arguments of the calls to each callee
    expected_call_args = {
        "kernel1": ('c',),
        "kernel2": ('d',),
        "device1": ('d', 'x', 'k2_tmp'),
        "device2": ('d', 'x'),
    }
    # Parameters declared in every routine
    expected_params = {
        "driver": ['a', 'b'],
        "another_driver": ['a', 'b'],
        "kernel1": ['a', 'b'],
        "kernel2": ['a_new', 'b'],
        "device1": ['a', 'b'],
        "device2": ['a', 'b'],
    }

    def error_stop(**kwargs):
        msg = kwargs.get("msg")
        # Abort via a Fortran ``error stop`` statement carrying the message
        return (ir.Intrinsic(text=f'error stop "{msg}"'),)

    def stop_execution(**kwargs):
        msg = kwargs.get("msg")
        # Abort by calling a ``stop_execution`` routine with the message as string argument
        stop_call = ir.CallStatement(
            name=sym.Variable(name="stop_execution"), arguments=(sym.StringLiteral(f'{msg}'),)
        )
        return (stop_call,)

    for callback in (error_stop, stop_execution):
        scheduler = Scheduler(
            paths=[proj_dir] if (proj_dir := testdir/'sources/projParametrise') else [], config=config,
            seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
        ) if False else Scheduler(
            paths=[testdir/'sources/projParametrise'], config=config,
            seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
        )
        scheduler.process(transformation=ParametriseTransformation(dic2p=dic2p, abort_callback=callback))

        check_arguments_and_parameter(
            scheduler=scheduler, subroutine_arguments=expected_routine_args,
            call_arguments=expected_call_args, parameter_variables=expected_params
        )
        compile_and_test(scheduler=scheduler, tmp_path=tmp_path, a=a, b=b)


@pytest.mark.parametrize('frontend', available_frontends())
def test_parametrise_modified_callback_wrong_input(tmp_path, testdir, frontend, config):
    """
    Check parametrisation with an ``abort_callback`` that only prints a warning.

    The project is compiled and run with values (``a=5``, ``b=1``) that differ
    from the parametrised ones (``a=12``, ``b=11``). This exercises a failed
    sanity check that does not abort.
    """
    dic2p = {'a': 12, 'b': 11}

    def only_warn(**kwargs):
        msg = kwargs.get("msg")
        # Emit a print statement instead of aborting
        return (ir.Intrinsic(text=f'print *, "This is just a warning: {msg}"'),)

    scheduler = Scheduler(
        paths=[testdir/'sources/projParametrise'], config=config,
        seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
    )
    scheduler.process(transformation=ParametriseTransformation(dic2p=dic2p, abort_callback=only_warn))

    # Expected dummy arguments of every routine after parametrisation
    expected_routine_args = {
        "driver": ['parametrised_a', 'parametrised_b', 'c', 'd'],
        "another_driver": ['parametrised_a', 'parametrised_b', 'c'],
        "kernel1": ['c'],
        "kernel2": ['d'],
        "device1": ['d', 'x', 'y'],
        "device2": ['d', 'x'],
    }
    # Expected actual arguments of the calls to each callee
    expected_call_args = {
        "kernel1": ('c',),
        "kernel2": ('d',),
        "device1": ('d', 'x', 'k2_tmp'),
        "device2": ('d', 'x'),
    }
    # Parameters declared in every routine
    expected_params = {
        "driver": ['a', 'b'],
        "another_driver": ['a', 'b'],
        "kernel1": ['a', 'b'],
        "kernel2": ['a_new', 'b'],
        "device1": ['a', 'b'],
        "device2": ['a', 'b'],
    }

    check_arguments_and_parameter(
        scheduler=scheduler, subroutine_arguments=expected_routine_args,
        call_arguments=expected_call_args, parameter_variables=expected_params
    )
    # Deliberately mismatching runtime values
    compile_and_test(scheduler=scheduler, tmp_path=tmp_path, a=5, b=1)


@pytest.mark.parametrize('frontend', available_frontends())
def test_parametrise_non_driver_entry_points(tmp_path, testdir, frontend, config):
    """
    Check parametrisation starting from explicit entry points.

    The entry points are ``kernel1`` and ``kernel2`` rather than the default
    drivers. The drivers stay untouched. Only the entry points and their
    callees are parametrised.
    """
    dic2p = {'a': 12, 'b': 11}
    a, b = dic2p['a'], dic2p['b']

    scheduler = Scheduler(
        paths=[testdir/'sources/projParametrise'], config=config,
        seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
    )
    scheduler.process(
        transformation=ParametriseTransformation(dic2p=dic2p, entry_points=("kernel1", "kernel2"))
    )

    # Expected dummy arguments of every routine after parametrisation
    expected_routine_args = {
        "driver": ['a', 'b', 'c', 'd'],
        "another_driver": ['a', 'b', 'c'],
        "kernel1": ['parametrised_a', 'parametrised_b', 'c'],
        "kernel2": ['a_new', 'parametrised_b', 'd'],
        "device1": ['a', 'd', 'x', 'y'],
        "device2": ['a', 'd', 'x'],
    }
    # Expected actual arguments of the calls to each callee
    expected_call_args = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b', 'd'),
        "device1": ('a_new', 'd', 'x', 'k2_tmp'),
        "device2": ('a', 'd', 'x'),
    }
    # Parameters declared in every routine
    expected_params = {
        "driver": [],
        "another_driver": [],
        "kernel1": ['a', 'b'],
        "kernel2": ['b'],
        "device1": ['b'],
        "device2": ['b'],
    }

    check_arguments_and_parameter(
        scheduler=scheduler, subroutine_arguments=expected_routine_args,
        call_arguments=expected_call_args, parameter_variables=expected_params
    )
    compile_and_test(scheduler=scheduler, tmp_path=tmp_path, a=a, b=b)


@pytest.mark.parametrize('frontend', available_frontends())
def test_declare_constant_scalars(frontend):
    """
    Test :any:`declare_fixed_value_scalars_as_constants`.

    Before the utility runs, only ``param`` is a parameter and the body has
    17 assignments. Afterwards, ``x``, ``z``, ``b``, ``l1`` and ``l3`` are also
    parameters initialised with their assigned values, and 12 assignments
    remain. ``e`` and ``f`` also receive a single constant assignment, but they
    are passed to a call / inline call and stay variables.
    """
    fcode = """
subroutine transform_declare_constant_scalars(invar, ret)
  implicit none
  integer, intent(in) :: invar
  integer, intent(out) :: ret
  integer, parameter :: param = 1
  integer :: a, b, c, d, e, f
  real :: x, y, z
  logical :: l1, l2, l3
  real :: arr1(2), arr2(3)

  arr1(1) = 1.0
  arr1(2) = 2.0
  arr2 = (/ 1.0, 2.0, 3.0 /)
  ret = 0
  x = 3.5
  y = invar * 2.0
  z = 4.5 + 3.5
  a = ret
  b = 10
  C = invar*invar
  e = 1
  F = 1
  L1 = .false.
  l3 = .true.
  if (c .gt. 0) then
    L2 = .true.
  else
    l2 = .false.
  endif

  call some_routine(e)
  arr1(1) = some_inline_func(f)

contains
  function some_inline_func(a) result(res)
    integer, intent(in) :: a
    integer :: res
    res = a
  end function some_inline_func
end subroutine transform_declare_constant_scalars
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Expected initial value of every parameter
    init_dict = {
        'x': '3.5', 'z': '4.5 + 3.5', 'b': '10', 'l1': '.false.', 'l3': '.true.', 'param': '1'
    }

    def _check_state(num_assignments, num_params):
        assert len(FindNodes(ir.Assignment).visit(routine.body)) == num_assignments
        params = [var for var in routine.variables if var.type.parameter]
        assert len(params) == num_params
        for param in params:
            assert param.initial == parse_expr(init_dict[str(param.name)])

    _check_state(num_assignments=17, num_params=1)

    # Declare as constants the scalars that are written only once, with a fixed value
    declare_fixed_value_scalars_as_constants(routine)

    _check_state(num_assignments=12, num_params=6)
loki-ecmwf-0.3.6/loki/transformations/tests/test_utilities.py0000664000175000017500000006065415167130205024731 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Subroutine, Dimension, fgen
from loki.expression import symbols as sym, parse_expr
from loki.frontend import available_frontends, OMNI
from loki.logging import WARNING
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, FindInlineCalls,
    SubstituteExpressions, pragmas_attached
)
from loki.types import BasicType

from loki.transformations.utilities import (
    single_variable_declaration, recursive_expression_map_update,
    convert_to_lower_case, replace_intrinsics, rename_variables,
    get_integer_variable, get_loop_bounds, is_driver_loop,
    find_driver_loops, get_local_arrays, check_routine_sequential,
    substitute_variables_for_definitions, is_pragma_driver_loop
)


@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'Makes variable declaration already unique')]))
def test_transform_utilities_single_variable_declaration(frontend):
    """
    Test :any:`single_variable_declaration` with an explicit list of variables.

    Only the listed variables are moved into declarations of their own. All
    other symbols stay grouped in the remaining declarations.
    """
    fcode = """
subroutine foo(a, x, y)
    integer, intent(in) :: a
    real, intent(inout):: x(a), y(a, a)
    integer :: i1, i2, i3, i4
    real :: r1, r2, r3, r4
    x = a
    y = a
end subroutine foo
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    single_variable_declaration(routine=routine, variables=('y', 'i1', 'i3', 'r1', 'r2', 'r3', 'r4'))

    declarations = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    assert declarations[0].symbols == ('a',)
    assert [smbl.name for smbl in declarations[1].symbols] == ['x']
    assert [smbl.name for smbl in declarations[2].symbols] == ['y']
    # non-listed integers stay grouped, listed ones follow individually
    assert declarations[3].symbols == ('i2', 'i4')
    assert declarations[4].symbols == ('i1',)
    assert declarations[5].symbols == ('i3',)
    # all reals were listed, so each gets its own declaration
    assert declarations[6].symbols == ('r1',)
    assert declarations[7].symbols == ('r2',)
    assert declarations[8].symbols == ('r3',)
    assert declarations[9].symbols == ('r4',)


@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'Makes variable declaration already unique')]))
def test_transform_utilities_single_variable_declarations(frontend):
    """
    Test :any:`single_variable_declaration` for all variables, both without
    and with ``group_by_shape``.
    """
    fcode = """
subroutine foo(a, x, y)
    integer, intent(in) :: a
    real, intent(inout):: x(a), y(a, a)
    integer :: i1, i2, i3, i4
    real :: r1, r2, r3, r4
    real :: x1, x2(a), x3(a), x4(a, a)
    x = a
    y = a
end subroutine foo
"""

    def _assert_uniform_declarations(declarations):
        # Every declaration may only group symbols of identical type (and identical shape for arrays)
        for decl in declarations:
            types = [smbl.type for smbl in decl.symbols]
            assert all(dtype == types[0] for dtype in types)
            if isinstance(decl.symbols[0], sym.Array):
                shapes = [smbl.shape for smbl in decl.symbols]
                assert all(shape == shapes[0] for shape in shapes)

    # variables=None and group_by_shape=False: every declaration holds exactly one symbol
    routine = Subroutine.from_source(fcode, frontend=frontend)
    single_variable_declaration(routine=routine)

    declarations = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    assert len(declarations) == 15
    for decl in declarations:
        assert len(decl.symbols) == 1

    # variables=None and group_by_shape=True: declarations are split so that each one
    # only groups symbols of identical type and shape
    routine = Subroutine.from_source(fcode, frontend=frontend)
    single_variable_declaration(routine=routine, group_by_shape=True)

    declarations = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    assert len(declarations) == 8
    _assert_uniform_declarations(declarations)

    # variables=('x2', 'r3') and group_by_shape=True: 'x2' and 'r3' get a declaration of
    # their own, the remaining declarations only group symbols of identical type and shape
    routine = Subroutine.from_source(fcode, frontend=frontend)
    single_variable_declaration(routine=routine, variables=('x2', 'r3'), group_by_shape=True)

    declarations = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    assert len(declarations) == 10
    assert declarations[5].symbols == ('r3',)
    assert [smbl.name for smbl in declarations[8].symbols] == ['x2']
    _assert_uniform_declarations(declarations)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_convert_to_lower_case(frontend):
    """
    Test :any:`convert_to_lower_case`.

    After the conversion, every variable and every inline call in a mixed-case
    routine has a lower-case name and a lower-case string representation.
    """
    fcode = """
subroutine my_NOT_ALL_lowercase_ROUTINE(VAR1, another_VAR, lower_case, MiXeD_CasE)
    implicit none
    integer, intent(in) :: VAR1, another_VAR
    integer, intent(inout) :: lower_case(ANOTHER_VAR)
    integer, intent(inout) :: MiXeD_CasE(Var1, ANOTHER_VAR)
    integer :: J, k

    do k=1,ANOTHER_VAR
        do J=1,VAR1
            mixed_CASE(J, K) = J + K
        end do
    end do

    do K=1,ANOTHER_VAR
        LOWER_CASE(MIXEd_cASE(1, K)) = K - 1
    end do

    miXed_CasE(1, 1) = Max(mIn(sQrT(9.0), 2.0), 1.0)
end subroutine my_NOT_ALL_lowercase_ROUTINE
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    convert_to_lower_case(routine)

    def _is_all_lower(node):
        return node.name.islower() and str(node).islower()

    assert all(_is_all_lower(var) for var in FindVariables(unique=True).visit(routine.ir))
    assert all(_is_all_lower(call) for call in FindInlineCalls().visit(routine.ir))


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilities_recursive_expression_map_update(frontend, tmp_path):
    """
    Test :any:`recursive_expression_map_update`.

    A plain substitution map that renames ``my_obj`` to ``obj`` leaves nested
    references, such as the array bounds, untouched. After the recursive
    update, every occurrence gets renamed.
    """
    fcode = """
module some_mod
    implicit none

    type some_type
        integer :: m, n
        real, allocatable :: a(:, :)
    contains
        procedure, pass :: my_add
    end type some_type
contains
    function my_add(self, data, val)
        class(some_type), intent(inout) :: self
        real, intent(in) :: data(:,:)
        real, value :: val
        real :: my_add(:,:)
        my_add(:,:) = self%a(:,:) + data(:,:) + val
    end function my_add

    subroutine do(my_obj)
        type(some_type), intent(inout) :: my_obj
        my_obj%a = my_obj%my_add(MY_OBJ%a(1:my_obj%m, 1:MY_OBJ%n), 1.)
    end subroutine do
end module some_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['do']

    def _refers_to_my_obj(var):
        return var == 'my_obj' or var.parent == 'my_obj'

    # Rename the derived-type argument and all of its direct members
    obj_var = routine.variable_map['my_obj']
    expr_map = {obj_var: obj_var.clone(name='obj')}
    for var in FindVariables().visit(routine.body):
        if var.parent == 'my_obj':
            expr_map[var] = var.clone(name=f'obj%{var.basename}', parent=var.parent.clone(name='obj'))

    # The replacement expressions still reference "my_obj"...
    assert any(_refers_to_my_obj(var) for var in FindVariables().visit(list(expr_map.values())))

    # ...so applying the map only substitutes partially
    cloned = routine.clone()
    cloned.body = SubstituteExpressions(expr_map).visit(cloned.body)
    assert fgen(cloned.body.body[0]).lower() == 'obj%a = obj%my_add(obj%a(1:my_obj%m, 1:my_obj%n), 1.)'

    expr_map = recursive_expression_map_update(expr_map)

    # After the recursive update no replacement references "my_obj" anymore...
    assert not any(_refers_to_my_obj(var) for var in FindVariables().visit(list(expr_map.values())))

    # ...and the substitution is complete
    assert fgen(routine.body.body[0]).lower() == 'my_obj%a = my_obj%my_add(my_obj%a(1:my_obj%m, 1:my_obj%n), 1.)'
    routine.body = SubstituteExpressions(expr_map).visit(routine.body)
    assert fgen(routine.body.body[0]) == 'obj%a = obj%my_add(obj%a(1:obj%m, 1:obj%n), 1.)'

@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'Argument mismatch for "min"')]))
def test_transform_utilites_replace_intrinsics(frontend):
    """
    Test :any:`replace_intrinsics`.

    Intrinsic functions are renamed via ``function_map`` and intrinsic
    symbols via ``symbol_map``. This covers both the routine body and the
    initial value of a parameter declaration.
    """
    fcode = """
subroutine replace_intrinsics()
    implicit none
    real :: a, b, eps
    real, parameter :: param = min(0.1, epsilon(param)*1000.)

    eps = param * 10.
    eps = 0.1
    b = max(10., eps)
    a = min(1. + b, 1. - eps)

end subroutine replace_intrinsics
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    symbol_map = {'epsilon': 'DBL_EPSILON'}
    function_map = {'min': 'fmin', 'max': 'fmax'}
    replace_intrinsics(routine, symbol_map=symbol_map, function_map=function_map)
    inline_calls = FindInlineCalls(unique=False).visit(routine.ir)
    assert inline_calls[0].name == 'fmin'
    assert inline_calls[1].name == 'fmax'
    assert inline_calls[2].name == 'fmin'
    variables = FindVariables(unique=False).visit(routine.ir)
    assert 'DBL_EPSILON' in variables
    assert 'epsilon' not in variables
    # check whether the replacement was also applied to the parameter's initial value in its declaration
    assert 'DBL_EPSILON' in FindVariables().visit(routine.variable_map['param'].initial)

@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_rename_variables(frontend):
    """
    Test :any:`rename_variables` for a dummy argument, a local scalar and a
    local array.

    The renaming must apply to all variable uses, to the routine arguments and
    to the symbol table.
    """
    fcode = """
subroutine rename_variables(some_arg, rename_arg)
    implicit none
    integer, intent(inout) :: some_arg, rename_arg
    integer :: some_var, rename_var
    integer :: i, j
    real :: some_array(10, 10), rename_array(10, 10)

    do i=1,10
        some_var = i
        rename_var = i + 1
        do J=1,10
            some_array(i, j) = 10. * some_arg * rename_arg
	        rename_array(i, j) = 5. * some_arg * rename_arg
        end do
    end do

end subroutine rename_variables
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    symbol_map = {'rename_var': 'renamed_var',
                  'rename_arg': 'renamed_arg',
                  'rename_array': 'renamed_array'}
    rename_variables(routine, symbol_map=symbol_map)
    variables = [var.name for var in FindVariables(unique=False).visit(routine.ir)]
    assert 'renamed_var' in variables
    assert 'rename_var' not in variables
    assert 'renamed_arg' in variables
    assert 'rename_arg' not in variables
    assert 'renamed_array' in variables
    assert 'rename_array' not in variables
    # check routine arguments
    assert 'renamed_arg' in routine.arguments
    assert 'rename_arg' not in routine.arguments
    # check symbol table for every renamed symbol (argument, local array, local scalar)
    assert 'renamed_arg' in routine.symbol_attrs
    assert 'rename_arg' not in routine.symbol_attrs
    assert 'renamed_array' in routine.symbol_attrs
    assert 'rename_array' not in routine.symbol_attrs
    assert 'renamed_var' in routine.symbol_attrs
    assert 'rename_var' not in routine.symbol_attrs

@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_utilites_rename_variables_extended(frontend):
    """
    Test :any:`rename_variables` on a size argument used in array shapes and
    on a derived-type argument.

    Member accesses and type-bound procedure calls on the derived-type
    argument must be renamed as well.
    """
    fcode = """
subroutine rename_variables_extended(KLON, ARR, TT)
    implicit none

    INTEGER, INTENT(IN) :: KLON
    REAL, INTENT(INOUT) :: ARR(KLON)
    REAL :: MY_TMP(KLON)
    TYPE(SOME_TYPE), INTENT(INOUT) :: TT
    TYPE(OTHER_TYPE) :: TMP_TT

    TMP_TT%SOME_MEMBER = TT%SOME_MEMBER + TT%PROC_FUNC(5.0)
    CALL TT%NESTED%PROC_SUB(TT%NESTED%VAR)
    TT%VAL = TMP_TT%VAL

end subroutine rename_variables_extended
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    rename_variables(routine, symbol_map={'klon': 'ncol', 'tt': 'arg_tt'})

    def _assert_renamed(names, renames):
        for new_name, old_name in renames:
            assert new_name in names
            assert old_name not in names

    # Dummy arguments
    arg_names = [arg.name.lower() for arg in routine.arguments]
    _assert_renamed(arg_names, (('ncol', 'klon'), ('arg_tt', 'tt')))

    # Array shapes that depend on the renamed size argument
    for array_name in ('arr', 'my_tmp'):
        assert routine.variable_map[array_name].shape == ('ncol',)

    # All variable uses, including members and type-bound procedures
    var_names = [var.name.lower() for var in FindVariables(unique=False).visit(routine.ir)]
    _assert_renamed(var_names, (
        ('ncol', 'klon'),
        ('arg_tt', 'tt'),
        ('arg_tt%some_member', 'tt%some_member'),
        ('arg_tt%proc_func', 'tt%proc_func'),
        ('arg_tt%nested', 'tt%nested'),
        ('arg_tt%nested%proc_sub', 'tt%nested%proc_sub'),
        ('arg_tt%nested%var', 'tt%nested%var'),
    ))

    # Symbol table, checked by full key and by the key's root before any '%'
    attr_keys = list(routine.symbol_attrs)
    symbol_names = tuple(key.lower() for key in attr_keys) + \
        tuple(key.split('%')[0].lower() for key in attr_keys)
    _assert_renamed(symbol_names, (('ncol', 'klon'), ('arg_tt', 'tt')))


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_get_integer_variable(frontend):
    """
    Test :any:`get_integer_variable`.

    The cases are a declared argument, a declared local with an explicit kind
    (looked up in different case), and a name that is not declared at all.
    Each lookup must give an integer scalar.
    """

    fcode = """
subroutine test_get_integer_variable(n)
  integer, intent(inout) :: n
  integer(kind=4) :: i

  n = n + 2 * i
end subroutine test_get_integer_variable
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _lookup_integer(name):
        var = get_integer_variable(routine, name)
        assert isinstance(var, sym.Scalar)
        assert var.type.dtype == BasicType.INTEGER
        return var

    assert _lookup_integer('n').type.intent == 'inout'
    assert _lookup_integer('I').type.kind == '4'
    _lookup_integer('k')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_get_loop_bounds(frontend, tmp_path):
    """
    Test :any:`get_loop_bounds`.

    Bounds may be given as dummy arguments or as derived-type members, and
    the lower bound may be a literal constant. Bounds that cannot be resolved
    must raise a :class:`RuntimeError`.
    """

    fcode = """
module test_get_loop_bounds_mod
implicit none
type my_dim
  integer(kind=8) :: a, b
end type my_dim
contains

subroutine test_get_loop_bounds(dim, n, start, end, arr)
  type(my_dim), intent(in) :: dim
  integer, intent(in) :: n, start, end
  real, intent(inout) :: arr(n)
  integer :: i, j, k

  do i=start, end
    arr(i) = 2. * arr(i)
  end do

  do j=dim%a, dim%b
    arr(j) = 2. * arr(j)
  end do

  do k=1,n
    arr(k) = 2. * arr(k)
  end do

end subroutine test_get_loop_bounds
end module test_get_loop_bounds_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_get_loop_bounds']

    dim_args = Dimension(name='x', size='n', index='i', bounds=('start', 'end'))
    dim_unknown = Dimension(name='y', size='n', index='i', bounds=('a', 'b'))
    dim_derived = Dimension(name='y', size='n', index='i', bounds=('dim%a', 'dim%b'))
    dim_const = Dimension(name='a', size='n', index='k', lower='1', upper='n')

    def _assert_integer_scalar(var, **type_attrs):
        assert isinstance(var, sym.Scalar)
        assert var.type.dtype == BasicType.INTEGER
        for attr_name, expected in type_attrs.items():
            assert getattr(var.type, attr_name) == expected

    # Bounds given as dummy arguments
    lower, upper = get_loop_bounds(routine, dim_args)  # pylint: disable=unbalanced-tuple-unpacking
    _assert_integer_scalar(lower, intent='in')
    _assert_integer_scalar(upper, intent='in')

    # Bounds that are not declared in the routine
    with pytest.raises(RuntimeError):
        _, _ = get_loop_bounds(routine, dim_unknown)  # pylint: disable=unbalanced-tuple-unpacking

    # Type-bound symbol resolution
    lower, upper = get_loop_bounds(routine, dim_derived)  # pylint: disable=unbalanced-tuple-unpacking
    _assert_integer_scalar(lower, kind='8')
    _assert_integer_scalar(upper, kind='8')

    # Natural constant lower bound
    lower, upper = get_loop_bounds(routine, dim_const)  # pylint: disable=unbalanced-tuple-unpacking
    assert isinstance(lower, sym.IntLiteral) and lower == 1
    assert upper == 'n'
    _assert_integer_scalar(upper, intent='in')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilities_find_driver_loops(frontend):
    """
    Test :any:`is_driver_loop` and :any:`find_driver_loops`.

    A loop counts as a driver loop if it carries a ``!$loki driver-loop``
    pragma or if it calls one of the given targets.
    """

    fcode = """
subroutine test_find_driver_loops(n, start, end, arr)
  integer, intent(in) :: n, start, end
  real, intent(inout) :: arr(n)
  integer :: i, j

  !$loki driver-loop
  do i=start, end
    arr(i) = 2. * arr(i)
  end do

  do j=start, end
    arr(j) = 2. * arr(j)
  end do

  do i=start, end
    call make_mine_a_double(arr(i))
  end do
end subroutine test_find_driver_loops
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    with pragmas_attached(routine, node_type=ir.Loop):
        # Test is_driver_loop utility
        loops = FindNodes(ir.Loop).visit(routine.body)
        assert len(loops) == 3
        assert is_driver_loop(loops[0], targets=())
        assert not is_driver_loop(loops[1], targets=())
        # the call loop only qualifies once its callee is listed as a target
        assert not is_driver_loop(loops[2], targets=())
        assert is_driver_loop(loops[2], targets=('make_mine_a_double', ))

        # Test find_driver_loops utility
        driver_loops = find_driver_loops(routine.body, targets=('make_mine_a_double',))
        assert len(driver_loops) == 2
        assert driver_loops[0].variable == 'i'
        assert isinstance(driver_loops[0].body[0], ir.Assignment)
        assert driver_loops[1].variable == 'i'
        assert isinstance(driver_loops[1].body[0], ir.CallStatement)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilities_find_driver_loops_multiple_nested(frontend, caplog):
    """
    Test :any:`find_driver_loops` on deeply nested loop structures that mix
    pragma-marked driver loops with loops calling ``target_kernel``.

    The expected driver loops are identified by their position in the flat
    :any:`FindNodes` traversal, and exactly one warning about a nested
    pragma-marked driver loop must be logged.
    """

    fcode = """
subroutine test_find_driver_loops(n, start, end, arr)
  integer, intent(in) :: n, start, end
  real, intent(inout) :: arr(1000)
  integer :: i, j, k, l, m, p

  !$loki driver-loop
  do i=start,end       ! driver loop 0 (loop 0)
    arr(i) = 2. * arr(i)
  end do

  do j=start,end
    arr(j) = 2. * arr(j)
  end do

  do i=start, end
    !$loki driver-loop
    do j=start,end     ! driver loop 1 (loop 3)
      arr(j) = 2. * arr(j)
    end do

    do j=start,end     ! driver loop 2 (loop 4)
        call target_kernel(arr(j))
    end do
  end do

  do i=start,end
    do j=start,end     ! driver loop 4 (loop 6)
        do l=start,end
            call target_kernel(arr(i+j+l))
        end do
    end do

    !$loki driver-loop
    do j=start,end     ! driver loop 3 (loop 8)
        arr(j+l) = 2. * arr(j+l)
    end do
  end do


  do i=start,end        ! driver loop 5 (loop 9)

    call target_kernel(arr(i))

    do j=start,end
      arr(j) = 2. * arr(j)
    end do

    call target_kernel(arr(i))

  end do

  do j=start,end
    do i=start,end       ! skipped loop
      call target_kernel(arr(i))
    end do

    !$loki driver-loop
    do i=start,end       ! driver loop 6 (loop 13)
      arr(i) = 2. * arr(i)
    end do

    call target_kernel(arr(j))
  end do

  do j=start,end
  do m=start,end
    do i=start,end
      do k=start,end ! driver loop 7 (loop 17)
        do l=start,end
          call target_kernel(arr(l))
        end do
      end do
      !$loki driver-loop
      do k=start,end       ! driver loop 8 (loop 19)
        arr(i) = 2. * arr(i)
      end do
      do k=start,end ! driver loop 9 (loop 20)
        do l=start,end
          do p=start,end
            call target_kernel(arr(p))
          end do
        end do
      end do
    end do
  end do
  end do

end subroutine test_find_driver_loops
"""
    caplog.set_level(WARNING)
    routine = Subroutine.from_source(fcode, frontend=frontend)

    with pragmas_attached(routine, node_type=ir.Loop):
        all_loops = FindNodes(ir.Loop).visit(routine.body)
        assert len(all_loops) == 23

        # The first loop carries an explicit driver-loop pragma, the second does not
        assert is_pragma_driver_loop(all_loops[0])
        assert is_driver_loop(all_loops[0], targets=())
        assert not is_pragma_driver_loop(all_loops[1])

        # Without call targets, only the pragma-marked loop among these qualifies
        for position, expected in ((1, False), (2, False), (3, True), (4, False)):
            assert bool(is_driver_loop(all_loops[position], targets=())) is expected
        # A loop calling a listed target qualifies once the target is given
        assert is_driver_loop(all_loops[4], targets=('target_kernel', ))

        found = find_driver_loops(routine.body, targets=('target_kernel',))
        assert len(caplog.records) == 1
        assert "Nested pragma marked driver loop inside loop" in caplog.records[0].message

        expected_positions = (0, 3, 4, 6, 8, 9, 13, 17, 19, 20)
        assert len(found) == len(expected_positions)
        for position in expected_positions:
            assert all_loops[position] in found


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_get_local_arrays(frontend, tmp_path):
    """
    Test :any:`get_local_arrays` utility.

    Only accesses to the locally declared array ``local`` are reported; the
    dummy argument ``arr`` is not. ``unique`` controls whether repeated
    accesses are de-duplicated, and the utility can be applied to a slice of
    the body or to the specification part.
    """

    fcode_mod = """
module global_var_mod
    integer, parameter :: arr_size = 20
    real :: myarray(arr_size)
end module global_var_mod
"""

    fcode = """
module test_get_local_arrays_mod
implicit none
type my_dim
  integer :: a(2)
end type my_dim
contains

subroutine test_get_local_arrays(n, dims, start, end, arr)
  use global_var_mod, only: myarray
  integer, intent(in) :: n, start, end
  type(my_dim), intent(in) :: dims
  real, intent(inout) :: arr(dims%a(2))
  real :: local(n), tmp
  integer :: i

  tmp = 2.0

  do i=start, end
    local(i) = tmp * ARR(i)
  end do

  do i=start, end
    ARR(i) = tmp * local(i)
  end do
end subroutine test_get_local_arrays
end module test_get_local_arrays_mod
"""
    global_mod = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path], definitions=(global_mod,))
    routine = module['test_get_local_arrays']

    local_arrs = get_local_arrays(routine, routine.body, unique=True)
    assert len(local_arrs) == 1
    assert local_arrs[0] == 'local(i)'

    local_arrs = get_local_arrays(routine, routine.body, unique=False)
    assert len(local_arrs) == 2
    assert all(l == 'local(i)' for l in local_arrs)

    # Restrict the search to the last body node (the second loop)
    local_arrs = get_local_arrays(routine, routine.body.body[-1:], unique=False)
    assert len(local_arrs) == 1
    assert local_arrs[0] == 'local(i)'

    # Test for component arrays on arguments in spec
    local_arrs = get_local_arrays(routine, routine.spec, unique=True)
    assert len(local_arrs) == 1
    assert local_arrs[0] == 'local(n)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_check_routine_sequential(frontend, tmp_path):
    """
    Test :any:`check_routine_sequential` utility.

    Of the three routines, only the one annotated with ``!$loki routine seq``
    is reported as sequential; ``!$acc routine seq`` and
    ``!$acc routine vector`` are not.
    """

    fcode = """
module test_check_routine_sequential_mod
implicit none
contains

  subroutine test_acc_seq(i)
    integer, intent(inout) :: i
!$acc routine seq
    i = i + 1
  end subroutine test_acc_seq

  subroutine test_loki_seq(i)
    integer, intent(inout) :: i
!$loki routine seq
    i = i + 1
  end subroutine test_loki_seq

  subroutine test_acc_vec(i)
    integer, intent(inout) :: i
!$acc routine vector
    i = i + 1
  end subroutine test_acc_vec

end module test_check_routine_sequential_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    expected = {
        'test_acc_seq': False,
        'test_loki_seq': True,
        'test_acc_vec': False,
    }
    for routine_name, is_sequential in expected.items():
        assert bool(check_routine_sequential(module[routine_name])) is is_sequential


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_substitute_variables_for_definitions(tmp_path, frontend):
    """
    Test :any:`substitute_variables_for_definitions` utility.

    Each variable is replaced by the right-hand side of its assignment in the
    routine body. Undefined variables are returned unchanged, and substitution
    is a single step (``j`` becomes ``i``, not ``2``).
    """

    fcode_mod = """
module some_mod
    type some_type
      integer :: a
      integer :: b
      integer :: n
    end type
end module some_mod
"""

    fcode = """
subroutine test_substitute_variables_for_definitions(start, end, arr, derived_var)
  use some_mod, only: some_type
  integer, intent(in) ::start, end
  real, intent(inout) :: arr(derived_var%n)
  type(some_type), intent(in) :: derived_var
  integer :: a, b, i, j, n

  n = derived_var%n
  a = derived_var%a + 1
  i = 2
  j = i

end subroutine test_substitute_variables_for_definitions
"""
    some_mod = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path], definitions=some_mod)
    variable_map = routine.variable_map

    # A single variable is replaced by its defining expression
    single = substitute_variables_for_definitions(routine, variables=variable_map['n'])
    assert len(single) == 1
    assert single[0] == 'derived_var%n'

    # Several variables at once; ``b`` has no definition and stays as is
    names = ['n', 'a', 'i', 'j', 'b']
    several = substitute_variables_for_definitions(
        routine, variables=[variable_map[name] for name in names]
    )
    assert several == ['derived_var%n', parse_expr('derived_var%a + 1'), '2', 'i', 'b']
loki-ecmwf-0.3.6/loki/transformations/tests/test_loop_blocking.py0000664000175000017500000005210415167130205025526 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import pytest
import numpy as np

from loki import Subroutine
from loki.jit_build import jit_compile, clean_test
from loki.expression import symbols as sym, Array
from loki.frontend import available_frontends
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, pragmas_attached
)

from loki.transformations.utilities import find_driver_loops
from loki.transformations.loop_blocking import (
    LoopSplittingVariables, split_loop, block_loop_arrays
)
from loki.types import SymbolAttributes, BasicType
from loki.logging import ERROR

"""
Tests for loop splitting and blocking utilities.
"""

# The number of additional variables created during loop splitting,
# e.g. new loop index for outer loop, new loop bounds.
# For a loop over ``i`` these appear to be the seven names the tests in this
# module assert on: i_loop_block_size, i_loop_num_blocks, i_loop_block_idx,
# i_loop_local, i_loop_block_start, i_loop_block_end, i_loop_iter_num
# (NOTE(review): confirm against split_loop if that helper changes).
# The "SLIT" spelling is kept because the tests below reference this name.
LOKI_LOOP_SLIT_VAR_ADDITION = 7


def loop_symbols(node):
    """
    Return ``(variable, start, stop, step)`` for every :any:`Loop` found in
    ``node``, in :any:`FindNodes` traversal order.
    """
    symbols = []
    for loop in FindNodes(ir.Loop).visit(node):
        bounds = loop.bounds
        symbols.append((loop.variable, bounds.start, bounds.stop, bounds.step))
    return symbols


def assignment_symbols(node):
    """Return the ``(lhs, rhs)`` pair of every :any:`Assignment` found in ``node``."""
    pairs = []
    for assignment in FindNodes(ir.Assignment).visit(node):
        pairs.append((assignment.lhs, assignment.rhs))
    return pairs


def array_access_symbols(node):
    """Return ``(name, dimensions)`` for every :any:`Array` variable found in ``node``."""
    arrays = [variable for variable in FindVariables().visit(node) if isinstance(variable, Array)]
    return [(array.name, array.dimensions) for array in arrays]


def test_loop_splitting_vars(caplog):
    """
    Test construction of :any:`LoopSplittingVariables` for each accepted
    block-size kind (Python ``int``, scalar variable, ``IntLiteral``) and the
    error path for an unsupported argument (a plain string), which must raise
    ``ValueError`` and log exactly one error record.
    """
    loop_var = sym.Variable(name='i', type=SymbolAttributes(BasicType.INTEGER))

    # Test init with int block size
    blk_size = 1
    loop_splitting_vars = LoopSplittingVariables(loop_var, blk_size)
    loop_var = loop_splitting_vars.loop_var
    assert isinstance(loop_var, sym.Scalar) and loop_var.type.dtype == BasicType.INTEGER

    # Test init with scalar variable block size
    blk_size = sym.Variable(name='i', type=SymbolAttributes(BasicType.INTEGER))
    loop_splitting_vars = LoopSplittingVariables(loop_var, blk_size)
    loop_var = loop_splitting_vars.loop_var
    assert isinstance(loop_var, sym.Scalar) and loop_var.type.dtype == BasicType.INTEGER

    # Test init with integer literal block size
    blk_size = sym.IntLiteral(1)
    loop_splitting_vars = LoopSplittingVariables(loop_var, blk_size)
    loop_var = loop_splitting_vars.loop_var
    assert isinstance(loop_var, sym.Scalar) and loop_var.type.dtype == BasicType.INTEGER

    # Test init with bad value: a plain string is rejected with an error log
    blk_size = 'i'
    with caplog.at_level(ERROR):
        with pytest.raises(ValueError):
            loop_splitting_vars = LoopSplittingVariables(loop_var, blk_size)
        assert len(caplog.records) == 1
        assert ("LoopSplittingVariables: Block size argument must be an integer constant or a " +
                "scalar variable" in caplog.records[0].message)


@pytest.mark.parametrize('frontend', available_frontends())
def test_1d_splitting(frontend):
    """
    Split a pragma-marked 1D driver loop into an outer loop over blocks and an
    inner loop within each block, then check the generated loops, helper
    variables and assignments.
    """
    fcode = """
subroutine test_1d_splitting(a, b, n)
  implicit none
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a(n)
  real(kind=8), intent(inout) :: b(n)
  integer :: i
  !$loki driver-loop
  do i=1,n
    a(i) = real(i, kind=8)
  end do
end subroutine test_1d_splitting
    """
    block_size = 117
    routine = Subroutine.from_source(fcode, frontend=frontend)
    loop_count_before = len(FindNodes(ir.Loop).visit(routine.ir))
    var_count_before = len(routine.variable_map)

    with pragmas_attached(routine, ir.Loop):
        driver_loops = find_driver_loops(routine.body, targets=None)
    split_loop(routine, driver_loops[0], block_size)

    # Splitting adds exactly one (outer) loop ...
    loop_count_after = len(FindNodes(ir.Loop).visit(routine.ir))
    expected_loops = loop_count_before + 1
    assert loop_count_after == expected_loops, \
        f"Total number of loops transformation is: {loop_count_after} but expected {expected_loops}"

    # ... and a fixed number of helper variables
    var_count_after = len(routine.variable_map)
    expected_vars = var_count_before + LOKI_LOOP_SLIT_VAR_ADDITION
    assert var_count_after == expected_vars, (
        f"Total number of variables after loop splitting is: {var_count_after} "
        f"but expected {expected_vars}"
    )

    expected_loop_symbols = [
        ('i_loop_block_idx', 1, 'i_loop_num_blocks', None),
        ('i_loop_local', 1, 'i_loop_block_end - i_loop_block_start + 1', None)
    ]
    assert loop_symbols(routine.ir) == expected_loop_symbols
    # The block size is stored as the initial value of a dedicated variable
    assert routine.variable_map['i_loop_block_size'].type.initial == block_size

    expected_assignments = [
        ('i_loop_num_blocks', '1 + (-1 + n) / i_loop_block_size'),
        ('i_loop_block_start', '(i_loop_block_idx - 1)*i_loop_block_size + 1'),
        ('i_loop_block_end', 'min(i_loop_block_idx*i_loop_block_size, n)'),
        ('i_loop_iter_num', 'i_loop_block_start + i_loop_local - 1'),
        ('i', 'i_loop_iter_num'), ('a(i)', 'real(i, kind=8)')
    ]
    assert assignment_symbols(routine.body) == expected_assignments


@pytest.mark.parametrize('frontend', available_frontends())
def test_1d_splitting_multi_var(frontend):
    """
    Split a 1D driver loop whose body touches more than one array (``c`` at a
    fixed index and ``a`` at the loop index) and check the generated loops,
    helper variables and assignments.
    """
    fcode = """
subroutine test_1d_splitting_multi_var(a, b, n)
  implicit none
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a(n)
  real(kind=8), intent(inout) :: b(n)
  real(kind=8) :: c(n)
  integer :: i
  !$loki driver-loop
  do i=1,n
    c(1) = c(1) + i
    a(i) = real(i)
  end do
end subroutine test_1d_splitting_multi_var
    """
    block_size = 117
    routine = Subroutine.from_source(fcode, frontend=frontend)
    loops_before, vars_before = len(FindNodes(ir.Loop).visit(routine.ir)), len(routine.variable_map)

    with pragmas_attached(routine, ir.Loop):
        target_loop = find_driver_loops(routine.body, targets=None)[0]
    split_loop(routine, target_loop, block_size)

    expected_loops = loops_before + 1
    actual_loops = len(FindNodes(ir.Loop).visit(routine.ir))
    assert actual_loops == expected_loops, \
        f"Total number of loops transformation is: {actual_loops} but expected {expected_loops}"

    expected_vars = vars_before + LOKI_LOOP_SLIT_VAR_ADDITION
    actual_vars = len(routine.variable_map)
    assert actual_vars == expected_vars, (
        f"Total number of variables after loop splitting is: {actual_vars} "
        f"but expected {expected_vars}"
    )

    assert loop_symbols(routine.ir) == [
        ('i_loop_block_idx', 1, 'i_loop_num_blocks', None),
        ('i_loop_local', 1, 'i_loop_block_end - i_loop_block_start + 1', None)
    ]
    # The block size is stored as the initial value of a dedicated variable
    block_size_var = routine.variable_map['i_loop_block_size']
    assert block_size_var.type.initial == block_size

    assert assignment_symbols(routine.body) == [
        ('i_loop_num_blocks', '1 + (-1 + n) / i_loop_block_size'),
        ('i_loop_block_start', '(i_loop_block_idx - 1)*i_loop_block_size + 1'),
        ('i_loop_block_end', 'min(i_loop_block_idx*i_loop_block_size, n)'),
        ('i_loop_iter_num', 'i_loop_block_start + i_loop_local - 1'),
        ('i', 'i_loop_iter_num'), ('c(1)', 'c(1) + i'), ('a(i)', 'real(i)')
    ]


@pytest.mark.parametrize('frontend', available_frontends())
def test_2d_splitting(frontend):
    """
    Split a driver loop that also writes a column slice ``b(:,i)`` of a 2D
    array, and check the generated loops, helper variables and assignments.
    """
    fcode = """
    subroutine test_2d_splitting(a, b, n)
      implicit none
      integer, intent(in) :: n
      real(kind=8), intent(inout) :: a(n)
      real(kind=8), intent(inout) :: b(n,n)
      real(kind=8) :: c(n)
      integer :: i
      !$loki driver-loop
      do i=1,n
        a(i) = i
        c(1) = a(i)
        b(:,i) = a(i)
      end do
    end subroutine test_2d_splitting
        """
    block_size = 117
    routine = Subroutine.from_source(fcode, frontend=frontend)
    n_loops_orig = len(FindNodes(ir.Loop).visit(routine.ir))
    n_vars_orig = len(routine.variable_map)

    with pragmas_attached(routine, ir.Loop):
        driver_loops = find_driver_loops(routine.body, targets=None)
    split_loop(routine, driver_loops[0], block_size)

    n_loops_split = len(FindNodes(ir.Loop).visit(routine.ir))
    assert n_loops_split == n_loops_orig + 1, \
        f"Total number of loops transformation is: {n_loops_split} but expected {n_loops_orig + 1}"
    n_vars_split = len(routine.variable_map)
    assert n_vars_split == n_vars_orig + LOKI_LOOP_SLIT_VAR_ADDITION, (
        f"Total number of variables after loop splitting is: {n_vars_split} "
        f"but expected {n_vars_orig + LOKI_LOOP_SLIT_VAR_ADDITION}"
    )

    outer_and_inner = [
        ('i_loop_block_idx', 1, 'i_loop_num_blocks', None),
        ('i_loop_local', 1, 'i_loop_block_end - i_loop_block_start + 1', None)
    ]
    assert loop_symbols(routine.ir) == outer_and_inner
    # The block size is stored as the initial value of a dedicated variable
    assert routine.variable_map['i_loop_block_size'].type.initial == block_size

    generated_assignments = [
        ('i_loop_num_blocks', '1 + (-1 + n) / i_loop_block_size'),
        ('i_loop_block_start', '(i_loop_block_idx - 1)*i_loop_block_size + 1'),
        ('i_loop_block_end', 'min(i_loop_block_idx*i_loop_block_size, n)'),
        ('i_loop_iter_num', 'i_loop_block_start + i_loop_local - 1'),
        ('i', 'i_loop_iter_num'), ('a(i)', 'i'), ('c(1)', 'a(i)'), ('b(:, i)', 'a(i)')
    ]
    assert assignment_symbols(routine.body) == generated_assignments



@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('block_size', [117])
@pytest.mark.parametrize('n', [500])
def test_3d_splitting(tmp_path, frontend, block_size, n):
    """
    Split a driver loop that writes 1D, 2D and 3D arrays, JIT-compile the
    result and compare the output against NumPy references.
    """
    fcode = """
    subroutine test_3d_splitting(a, b, c, n)
      implicit none
      integer, intent(in) :: n
      real(kind=8), intent(inout) :: a(n)
      real(kind=8), intent(inout) :: b(2,n)
      real(kind=8), intent(inout) :: c(2,2,n)
      integer :: i
      !$loki driver-loop
      do i=1,n
        a(i) = i
        b(:,i) = a(i)
        c(:,:,i) = a(i)
      end do
    end subroutine test_3d_splitting
        """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    loops_before = len(FindNodes(ir.Loop).visit(routine.ir))
    vars_before = len(routine.variable_map)

    with pragmas_attached(routine, ir.Loop):
        driver_loops = find_driver_loops(routine.body, targets=None)
    split_loop(routine, driver_loops[0], block_size)

    loops_after = len(FindNodes(ir.Loop).visit(routine.ir))
    assert loops_after == loops_before + 1, \
        f"Total number of loops transformation is: {loops_after} but expected {loops_before + 1}"
    vars_after = len(routine.variable_map)
    assert vars_after == vars_before + LOKI_LOOP_SLIT_VAR_ADDITION, (
        f"Total number of variables after loop splitting is: {vars_after} "
        f"but expected {vars_before + LOKI_LOOP_SLIT_VAR_ADDITION}"
    )

    # Compile the transformed routine and run it on zero-initialised arrays
    filepath = tmp_path / (f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    a = np.zeros(n, order='F')
    b = np.zeros((2, n), order='F')
    c = np.zeros((2, 2, n), order='F')
    function(a, b, c, n)

    # Every slice along the last axis holds the (1-based) loop index
    a_ref = np.linspace(1, n, n)
    b_ref = np.tile(a_ref, (2, 1))
    c_ref = np.tile(a_ref, (2, 2, 1))
    assert np.array_equal(a, a_ref), "a should be equal to a_ref=(1, 2, ..., n)"
    assert np.array_equal(b, b_ref), "b should equal b_ref"
    assert np.array_equal(c, c_ref), "c should equal c_ref"

    clean_test(filepath)

"""
--------------------------------------------------------------------------------
Blocking tests

Tests that variables are correctly blocked, and that blocked loops produce
the correct output.
--------------------------------------------------------------------------------
"""

@pytest.mark.parametrize('frontend', available_frontends())
def test_1d_blocking(frontend):
    """
    Split a 1D driver loop and block the array ``a`` it accesses.

    After :any:`block_loop_arrays`, the inner loop must only access the local
    block copy ``a_block``, which is copied in from ``a`` before the inner
    loop and copied back out to ``a`` after it.
    """
    fcode = """
subroutine test_1d_blocking(a, b, n)
  implicit none
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a(n)
  real(kind=8), intent(inout) :: b(n)
  integer :: i
  !$loki driver-loop
  do i=1,n
    a(i) = real(i)
  end do
end subroutine test_1d_blocking
    """
    block_size = 117
    routine = Subroutine.from_source(fcode, frontend=frontend)
    # Count all loops in the routine, not only driver loops, so the comparison
    # with the total loop count after splitting is like-for-like.
    num_loops = len(FindNodes(ir.Loop).visit(routine.ir))
    with pragmas_attached(routine, ir.Loop):
        loops = find_driver_loops(routine.body,
                                  targets=None)

    num_vars = len(routine.variable_map)

    splitting_vars, inner_loop, outer_loop = split_loop(routine, loops[0], block_size)
    loops = FindNodes(ir.Loop).visit(routine.ir)
    assert len(loops) == num_loops + 1, \
        f"Total number of loops transformation is: {len(loops)} but expected {num_loops + 1}"
    assert len(routine.variable_map) == num_vars + LOKI_LOOP_SLIT_VAR_ADDITION, (
        f"Total number of variables after loop splitting is: {len(routine.variable_map)} "
        f"but expected {num_vars + LOKI_LOOP_SLIT_VAR_ADDITION}"
    )
    num_vars = len(routine.variable_map)

    blocking_indices = ['i']
    block_loop_arrays(routine, splitting_vars, inner_loop, outer_loop, blocking_indices)
    # No array inside the inner loop may still be indexed by the original loop variable
    for var in FindVariables().visit(inner_loop.body):
        if isinstance(var, Array):
            for idx in blocking_indices:
                assert idx not in var.dimensions, "The variable should be blocked and the local variable used"

    assert len(routine.variable_map) == num_vars+1, "Expected 1 loop blocking to be added"

    assert loop_symbols(routine.ir) == [
        ('i_loop_block_idx', 1, 'i_loop_num_blocks', None),
        ('i_loop_local', 1, 'i_loop_block_end - i_loop_block_start + 1', None)
    ]
    # Ensure inner loop bound has been parametrised correctly
    assert routine.variable_map['i_loop_block_size'].type.initial == block_size

    assert assignment_symbols(routine.body) == [
        ('i_loop_num_blocks', '1 + (-1 + n) / i_loop_block_size'),
        ('i_loop_block_start', '(i_loop_block_idx - 1)*i_loop_block_size + 1'),
        ('i_loop_block_end', 'min(i_loop_block_idx*i_loop_block_size, n)'),
        ('a_block(1:i_loop_block_end - i_loop_block_start + 1)', 'a(i_loop_block_start:i_loop_block_end)'),
        ('i_loop_iter_num', 'i_loop_block_start + i_loop_local - 1'),
        ('i', 'i_loop_iter_num'), ('a_block(i_loop_local)', 'real(i)'),
        ('a(i_loop_block_start:i_loop_block_end)', 'a_block(1:i_loop_block_end - i_loop_block_start + 1)')
    ]
    assert array_access_symbols(inner_loop.body) == [('a_block', ('i_loop_local',))]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('block_size', [117])
@pytest.mark.parametrize('n', [500])
def test_1d_blocking_multi_intent(tmp_path, frontend, block_size, n):
    """
    Split and block a 1D loop that reads the ``intent(in)`` array ``a`` and
    updates the ``intent(inout)`` array ``b``, then JIT-compile the result and
    check the output against a NumPy reference.
    """
    fcode = """
subroutine test_1d_blocking_multi_intent(a, b, n)
  implicit none
  integer, intent(in) :: n
  real(kind=8), intent(in) :: a(n)
  real(kind=8), intent(inout) :: b(n)
  integer :: i
  !$loki driver-loop
  do i=1,n
    b(i) = b(i) + a(i)*a(i)
  end do
end subroutine test_1d_blocking_multi_intent
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    # Count all loops in the routine, not only driver loops, so the comparison
    # with the total loop count after splitting is like-for-like.
    num_loops = len(FindNodes(ir.Loop).visit(routine.ir))
    with pragmas_attached(routine, ir.Loop):
        loops = find_driver_loops(routine.body,
                                  targets=None)

    num_vars = len(routine.variable_map)
    splitting_vars, inner_loop, outer_loop = split_loop(routine, loops[0], block_size)
    loops = FindNodes(ir.Loop).visit(routine.ir)
    assert len(loops) == num_loops + 1, \
        f"Total number of loops transformation is: {len(loops)} but expected {num_loops + 1}"
    assert len(routine.variable_map) == num_vars + LOKI_LOOP_SLIT_VAR_ADDITION, (
        f"Total number of variables after loop splitting is: {len(routine.variable_map)} "
        f"but expected {num_vars + LOKI_LOOP_SLIT_VAR_ADDITION}"
    )

    num_vars = len(routine.variable_map)
    blocking_indices = ['i']
    block_loop_arrays(routine, splitting_vars, inner_loop, outer_loop, blocking_indices)

    assert len(routine.variable_map) == num_vars+2, "Expected 2 loop blocking to be added"
    # No array inside the inner loop may still be indexed by the original loop variable
    for var in FindVariables().visit(inner_loop.body):
        if isinstance(var, Array):
            for idx in blocking_indices:
                assert idx not in var.dimensions, "The variable should be blocked and the local variable used"
    # NOTE(review): blocking is applied a second time here with no follow-up
    # assertion; presumably this checks that re-applying it is harmless for the
    # compiled result below. Confirm intent.
    block_loop_arrays(routine, splitting_vars, inner_loop, outer_loop, blocking_indices=['i'])

    filepath = tmp_path / (f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    a = np.linspace(1, n, n)
    b = np.ones(n, order='F')
    a_ref = np.linspace(1, n, n)
    # Reference is computed before the call because ``function`` updates ``b`` in place
    b_ref = b + a*a
    function(a, b, n)
    assert np.array_equal(a, a_ref), "a should be equal to a_ref=(1, 2, ..., n)"
    assert np.array_equal(b, b_ref), "b should equal to (2, 5, ..., 1 + n^2)"
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('block_size', [117])
@pytest.mark.parametrize('n', [500])
def test_2d_blocking(tmp_path, frontend, block_size, n):
    """
    Split and block a driver loop that writes a 1D array and a column slice of
    a 2D array, then JIT-compile the result and compare the output against
    NumPy references.
    """
    fcode = """
    subroutine test_2d_blocking(a, b, n)
      implicit none
      integer, intent(in) :: n
      real(kind=8), intent(inout) :: a(n)
      real(kind=8), intent(inout) :: b(n,n)
      real(kind=8) :: c(n)
      integer :: i
      !$loki driver-loop
      do i=1,n
        a(i) = i
        c(1) = a(i)
        b(:,i) = a(i)
      end do
    end subroutine test_2d_blocking
        """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    n_loops_orig = len(FindNodes(ir.Loop).visit(routine.ir))
    n_vars_orig = len(routine.variable_map)

    with pragmas_attached(routine, ir.Loop):
        driver_loops = find_driver_loops(routine.body, targets=None)
    splitting_vars, inner_loop, outer_loop = split_loop(routine, driver_loops[0], block_size)

    n_loops_split = len(FindNodes(ir.Loop).visit(routine.ir))
    assert n_loops_split == n_loops_orig + 1, \
        f"Total number of loops transformation is: {n_loops_split} but expected {n_loops_orig + 1}"
    n_vars_split = len(routine.variable_map)
    assert n_vars_split == n_vars_orig + LOKI_LOOP_SLIT_VAR_ADDITION, (
        f"Total number of variables after loop splitting is: {n_vars_split} "
        f"but expected {n_vars_orig + LOKI_LOOP_SLIT_VAR_ADDITION}"
    )

    # Blocking over ``i`` is expected to introduce two new variables
    blocking_indices = ['i']
    block_loop_arrays(routine, splitting_vars, inner_loop, outer_loop, blocking_indices)
    assert len(routine.variable_map) == n_vars_split + 2, "Expected 2 loop blocking to be added"

    inner_arrays = [v for v in FindVariables().visit(inner_loop.body) if isinstance(v, Array)]
    for array in inner_arrays:
        for idx in blocking_indices:
            assert idx not in array.dimensions, "The variable should be blocked and the local variable used"

    # Compile the transformed routine and run it on zero-initialised arrays
    filepath = tmp_path / (f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    a = np.zeros(n, order='F')
    b = np.zeros((n, n), order='F')
    function(a, b, n)

    a_ref = np.linspace(1, n, n)
    b_ref = np.tile(a_ref, (n, 1))
    assert np.array_equal(a, a_ref), "a should be equal to a_ref=(1, 2, ..., n)"
    assert np.array_equal(b, b_ref), "b[:,1] should equal a and a_ref"

    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('block_size', [117])
@pytest.mark.parametrize('n', [500])
def test_3d_blocking(tmp_path, frontend, block_size, n):
    """
    Split and block a driver loop that writes 1D, 2D and 3D arrays, then
    JIT-compile the result and compare the output against NumPy references.
    """
    fcode = """
    subroutine test_3d_blocking(a, b, c, n)
      implicit none
      integer, intent(in) :: n
      real(kind=8), intent(inout) :: a(n)
      real(kind=8), intent(inout) :: b(2,n)
      real(kind=8), intent(inout) :: c(2,2,n)
      integer :: i
      !$loki driver-loop
      do i=1,n
        a(i) = i
        b(:,i) = a(i)
        c(:,:,i) = a(i)
      end do
    end subroutine test_3d_blocking
        """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    loops_before = len(FindNodes(ir.Loop).visit(routine.ir))
    vars_before = len(routine.variable_map)

    with pragmas_attached(routine, ir.Loop):
        driver_loops = find_driver_loops(routine.body, targets=None)
    splitting_vars, inner_loop, outer_loop = split_loop(routine, driver_loops[0], block_size)

    loops_after = len(FindNodes(ir.Loop).visit(routine.ir))
    assert loops_after == loops_before + 1, \
        f"Total number of loops transformation is: {loops_after} but expected {loops_before + 1}"
    vars_after_split = len(routine.variable_map)
    assert vars_after_split == vars_before + LOKI_LOOP_SLIT_VAR_ADDITION, (
        f"Total number of variables after loop splitting is: {vars_after_split} "
        f"but expected {vars_before + LOKI_LOOP_SLIT_VAR_ADDITION}"
    )

    # Blocking over ``i`` is expected to introduce three new variables
    blocking_indices = ['i']
    block_loop_arrays(routine, splitting_vars, inner_loop, outer_loop, blocking_indices)
    assert len(routine.variable_map) == vars_after_split + 3, "Expected 3 loop blocking to be added"

    inner_arrays = [v for v in FindVariables().visit(inner_loop.body) if isinstance(v, Array)]
    for array in inner_arrays:
        for idx in blocking_indices:
            assert idx not in array.dimensions, "The variable should be blocked and the local variable used"

    # Compile the transformed routine and run it on zero-initialised arrays
    filepath = tmp_path / (f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    a = np.zeros(n, order='F')
    b = np.zeros((2, n), order='F')
    c = np.zeros((2, 2, n), order='F')
    function(a, b, c, n)

    # Every slice along the last axis holds the (1-based) loop index
    a_ref = np.linspace(1, n, n)
    b_ref = np.tile(a_ref, (2, 1))
    c_ref = np.tile(a_ref, (2, 2, 1))
    assert np.array_equal(a, a_ref), "a should be equal to a_ref=(1, 2, ..., n)"
    assert np.array_equal(b, b_ref), "b should equal b_ref"
    assert np.array_equal(c, c_ref), "c should equal c_ref"

    clean_test(filepath)
loki-ecmwf-0.3.6/loki/transformations/tests/test_argument_shape.py0000664000175000017500000005164015167130205025713 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki import  Subroutine, Scheduler, Sourcefile, flatten
from loki.frontend import available_frontends, OMNI, HAVE_FP, FP
from loki.ir import CallStatement, FindNodes
from loki.transformations import (
    ArgumentArrayShapeAnalysis, ExplicitArgumentArrayShapeTransformation
)
from loki.expression import symbols as sym
from loki.types import BasicType


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Module-scoped fixture ``here``: the directory containing this test file."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.mark.parametrize('frontend', available_frontends())
def test_argument_shape_simple(frontend):
    """
    Test to ensure that implicit array argument shapes are correctly derived
    from the calling context, so that the driver-level shapes are propagated
    into the kernel routines.
    """

    fcode_driver = """
  SUBROUTINE trafo_driver(nlon, nlev, a, b, c)
    ! Driver routine with explicit array shapes
    INTEGER, INTENT(IN)   :: nlon, nlev  ! Dimension sizes
    INTEGER, PARAMETER    :: n = 5
    REAL, INTENT(INOUT)   :: a(nlon)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,n)

    call trafo_kernel(a, b, c)
  END SUBROUTINE trafo_driver
    """

    fcode_kernel = """
  SUBROUTINE trafo_kernel(a, b, c)
    ! Kernel routine with implicit shape array arguments
    REAL, INTENT(INOUT)   :: a(:)
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

  END SUBROUTINE trafo_kernel
    """

    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    driver.enrich(kernel)  # Attach kernel source to driver call

    # Ensure initial call uses implicit argument shapes
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 1 and calls[0].routine
    assert len(calls[0].routine.arguments) == 3
    assert calls[0].routine.arguments[0].shape == (':', )
    assert calls[0].routine.arguments[1].shape == (':', ':')
    assert calls[0].routine.arguments[2].shape == (':', ':')

    arg_shape_trafo = ArgumentArrayShapeAnalysis()
    arg_shape_trafo.apply(driver, role='driver')

    assert kernel.arguments[0].shape == ('nlon',)
    assert kernel.arguments[1].shape == ('nlon', 'nlev')
    # OMNI is expected to substitute the value of the parameter ``n``; other
    # frontends keep the symbol. The conditional must select the expected value
    # and must not wrap the whole comparison. Otherwise the non-OMNI branch
    # asserts a non-empty tuple, which is always true.
    expected_c_shape = ('nlon', 5) if frontend == OMNI else ('nlon', 'n')
    assert kernel.arguments[2].shape == expected_c_shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_argument_shape_nested(frontend):
    """
    Test to ensure that implicit array argument shapes are propagated
    through multiple subroutine calls.

    The expected shape of ``c`` depends on the frontend: OMNI is expected to
    report the literal value (``5``) of the ``PARAMETER`` ``n``, while the
    other frontends are expected to keep the symbolic name ``n``.
    """

    fcode_driver = """
  SUBROUTINE trafo_driver(nlon, nlev, a, b, c)
    ! Driver routine with explicit array shapes
    INTEGER, INTENT(IN)   :: nlon, nlev  ! Dimension sizes
    INTEGER, PARAMETER    :: n = 5
    REAL, INTENT(INOUT)   :: a(nlon)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,n)

    call trafo_kernel_a(a, b, c)
  END SUBROUTINE trafo_driver
    """

    fcode_kernel_a = """
  SUBROUTINE trafo_kernel_a(a, b, c)
    REAL, INTENT(INOUT)   :: a(:)
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

    CALL trafo_kernel_b(b, c)
  END SUBROUTINE trafo_kernel_a
    """

    fcode_kernel_b = """
  SUBROUTINE trafo_kernel_b(b, c)
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

  END SUBROUTINE trafo_kernel_b
    """

    kernel_b = Subroutine.from_source(fcode_kernel_b, frontend=frontend)
    kernel_a = Subroutine.from_source(fcode_kernel_a, frontend=frontend)
    kernel_a.enrich(kernel_b)  # Attach kernel source to call
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    driver.enrich(kernel_a)  # Attach kernel source to call

    # Ensure initial call uses implicit argument shapes
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 1 and calls[0].routine
    assert len(calls[0].routine.arguments) == 3
    assert tuple(a.shape for a in calls[0].routine.arguments) == ((':', ), (':', ':'), (':', ':'))

    calls = FindNodes(CallStatement).visit(kernel_a.body)
    assert len(calls) == 1 and calls[0].routine
    assert len(calls[0].routine.arguments) == 2
    assert tuple(a.shape for a in calls[0].routine.arguments) == ((':', ':'), (':', ':'))

    # Apply the shape propagation in a manual forward pass
    arg_shape_trafo = ArgumentArrayShapeAnalysis()
    arg_shape_trafo.apply(driver, role='driver')
    arg_shape_trafo.apply(kernel_a, role='kernel')

    # Bind the frontend-dependent expectation to a name first. Writing
    # ``assert x == A if cond else B`` parses as ``assert (x == A) if cond else B``,
    # which for non-OMNI frontends only asserts the (always truthy) tuple ``B``.
    c_shape = ('nlon', 5) if frontend == OMNI else ('nlon', 'n')

    assert kernel_a.arguments[0].shape == ('nlon',)
    assert kernel_a.arguments[1].shape == ('nlon', 'nlev')
    assert kernel_a.arguments[2].shape == c_shape

    assert kernel_b.arguments[0].shape == ('nlon', 'nlev')
    assert kernel_b.arguments[1].shape == c_shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_argument_shape_multiple(frontend):
    """
    Test to ensure that multiple call paths are also honoured correctly.

    Note that conflicting array shape information is currently not
    detected, since the transformation only replaces deferred array
    dimensions (":").

    The expected shape of ``c`` depends on the frontend: OMNI is expected to
    report the literal value (``5``) of the ``PARAMETER`` ``n``, while the
    other frontends are expected to keep the symbolic name ``n``.
    """

    fcode_driver = """
  SUBROUTINE trafo_driver(nlon, nlev, a, b, c)
    INTEGER, INTENT(IN)   :: nlon, nlev  ! Dimension sizes
    INTEGER, PARAMETER    :: n = 5
    REAL, INTENT(INOUT)   :: a(nlon)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,n)

    call trafo_kernel_a1(a, b, c)

    call trafo_kernel_a2(b, c)

    call trafo_kernel_a3(nlon, nlev, b, c)
  END SUBROUTINE trafo_driver
    """

    fcode_kernel_a1 = """
  SUBROUTINE trafo_kernel_a1(a, b, c)
    ! First-level kernel call, as before
    REAL, INTENT(INOUT)   :: a(:)
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

    CALL trafo_kernel_b(b, c)
  END SUBROUTINE trafo_kernel_a1
    """

    fcode_kernel_a2 = """
  SUBROUTINE trafo_kernel_a2(b, c)
    ! First-level kernel call that agrees with kernel_a1
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

    CALL trafo_kernel_b(b, c)
  END SUBROUTINE trafo_kernel_a2
    """

    fcode_kernel_a3 = """
  SUBROUTINE trafo_kernel_a3(nlon, nlev, b, c)
    ! First-level kernel call that disagrees with kernel_a1
    INTEGER, INTENT(IN) :: nlon, nlev
    REAL :: b(nlev, nlon), c(nlev, nlev)

    CALL trafo_kernel_b(b, c)
  END SUBROUTINE trafo_kernel_a3
    """

    fcode_kernel_b = """
  SUBROUTINE trafo_kernel_b(b, c)
    ! Second-level kernel call
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

  END SUBROUTINE trafo_kernel_b
    """

    kernel_b = Subroutine.from_source(fcode_kernel_b, frontend=frontend)
    kernel_a1 = Subroutine.from_source(fcode_kernel_a1, frontend=frontend)
    kernel_a1.enrich(kernel_b)  # Attach kernel source to call
    kernel_a2 = Subroutine.from_source(fcode_kernel_a2, frontend=frontend)
    kernel_a2.enrich(kernel_b)  # Attach kernel source to call
    kernel_a3 = Subroutine.from_source(fcode_kernel_a3, frontend=frontend)
    kernel_a3.enrich(kernel_b)  # Attach kernel source to call
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    driver.enrich(kernel_a1)  # Attach kernel source to call
    driver.enrich(kernel_a2)  # Attach kernel source to call
    driver.enrich(kernel_a3)  # Attach kernel source to call

    # Ensure initial call uses implicit argument shapes
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 3 and all(c.routine for c in calls)
    assert tuple(a.shape for a in calls[0].routine.arguments) == ((':', ), (':', ':'), (':', ':'))
    assert tuple(a.shape for a in calls[1].routine.arguments) == ((':', ':'), (':', ':'))
    assert tuple(a.shape for a in calls[2].routine.arguments[2:]) == (('nlev', 'nlon'), ('nlev', 'nlev'))

    # Apply the legal shape propagation in a manual forward pass
    arg_shape_trafo = ArgumentArrayShapeAnalysis()
    arg_shape_trafo.apply(driver, role='driver')
    arg_shape_trafo.apply(kernel_a1, role='kernel')
    arg_shape_trafo.apply(kernel_a2, role='kernel')
    arg_shape_trafo.apply(kernel_b, role='kernel')

    # Bind the frontend-dependent expectation to a name first. Writing
    # ``assert x == A if cond else B`` parses as ``assert (x == A) if cond else B``,
    # which for non-OMNI frontends only asserts the (always truthy) tuple ``B``.
    c_shape = ('nlon', 5) if frontend == OMNI else ('nlon', 'n')

    # Check that the agreeable argument shapes indeed propagate
    assert kernel_a1.arguments[0].shape == ('nlon',)
    assert kernel_a1.arguments[1].shape == ('nlon', 'nlev')
    assert kernel_a1.arguments[2].shape == c_shape

    assert kernel_a2.arguments[0].shape == ('nlon', 'nlev')
    assert kernel_a2.arguments[1].shape == c_shape

    assert kernel_b.arguments[0].shape == ('nlon', 'nlev')
    assert kernel_b.arguments[1].shape == c_shape

    # Now we apply conflicting information and ensure that it completes
    # and does not override the derived shape.
    # TODO: We should eventually provide an option to fail here, so that
    # conflicting shape info can be detected and dealt with, but that's
    # for the future. A failure condition can then be inserted here.

    arg_shape_trafo.apply(kernel_a3, role='kernel')
    assert kernel_b.arguments[0].shape == ('nlon', 'nlev')
    assert kernel_b.arguments[1].shape == c_shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_argument_shape_transformation(frontend):
    """
    Test that ensures that explicit argument shapes are indeed inserted
    in a multi-layered call tree.

    The expected dimensions of ``c`` depend on the frontend: OMNI is expected
    to report the literal value (``5``) of the ``PARAMETER`` ``n`` (so no
    ``n`` argument is added), while the other frontends are expected to keep
    the symbolic name ``n`` and add it as an argument.
    """

    fcode_driver = """
  SUBROUTINE trafo_driver(nlon, nlev, a, b, c)
    INTEGER, INTENT(IN)   :: nlon, nlev  ! Dimension sizes
    INTEGER, PARAMETER    :: n = 5
    REAL, INTENT(INOUT)   :: a(nlon)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,n)

    call trafo_kernel_a1(a, b, c)

    call trafo_kernel_a2(b, c)
  END SUBROUTINE trafo_driver
    """

    fcode_kernel_a1 = """
  SUBROUTINE trafo_kernel_a1(a, b, c)
    ! First-level kernel call, as before
    REAL, INTENT(INOUT)   :: a(:)
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

    CALL trafo_kernel_b(b, c)
  END SUBROUTINE trafo_kernel_a1
    """

    fcode_kernel_a2 = """
  SUBROUTINE trafo_kernel_a2(b, c)
    ! First-level kernel call that agrees with kernel_a1
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

    CALL trafo_kernel_b(b, c)
  END SUBROUTINE trafo_kernel_a2
    """

    fcode_kernel_b = """
  SUBROUTINE trafo_kernel_b(b, c)
    ! Second-level kernel call
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

  END SUBROUTINE trafo_kernel_b
    """

    # Manually create subroutines and attach call-signature info
    kernel_b = Subroutine.from_source(fcode_kernel_b, frontend=frontend)
    kernel_a1 = Subroutine.from_source(fcode_kernel_a1, frontend=frontend)
    kernel_a1.enrich(kernel_b)  # Attach kernel source to call
    kernel_a2 = Subroutine.from_source(fcode_kernel_a2, frontend=frontend)
    kernel_a2.enrich(kernel_b)  # Attach kernel source to call
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    driver.enrich(kernel_a1)  # Attach kernel source to call
    driver.enrich(kernel_a2)  # Attach kernel source to call

    # Ensure initial call uses implicit argument shapes
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 2 and all(c.routine for c in calls)
    assert tuple(a.shape for a in calls[0].routine.arguments) == ((':', ), (':', ':'), (':', ':'))
    assert tuple(a.shape for a in calls[1].routine.arguments) == ((':', ':'), (':', ':'))

    # Apply the legal shape propagation in a manual forward pass
    arg_shape_analysis = ArgumentArrayShapeAnalysis()
    arg_shape_analysis.apply(driver)
    arg_shape_analysis.apply(kernel_a1)
    arg_shape_analysis.apply(kernel_a2)
    arg_shape_analysis.apply(kernel_b)

    # Apply the insertion of explicit array argument shapes in a backward pass
    arg_shape_trafo = ExplicitArgumentArrayShapeTransformation()
    arg_shape_trafo.apply(kernel_b)
    arg_shape_trafo.apply(kernel_a2)
    arg_shape_trafo.apply(kernel_a1)
    arg_shape_trafo.apply(driver)

    # Bind the frontend-dependent expectation to a name first. Writing
    # ``assert x == A if cond else B`` parses as ``assert (x == A) if cond else B``,
    # which for non-OMNI frontends only asserts the (always truthy) tuple ``B``.
    c_dims = ('nlon', 5) if frontend == OMNI else ('nlon', 'n')

    # Check that argument shapes have been applied and that every dimension
    # size variable used in them has become an argument
    assert kernel_a1.arguments[0].dimensions == ('nlon',)
    assert kernel_a1.arguments[1].dimensions == ('nlon', 'nlev')
    assert kernel_a1.arguments[2].dimensions == c_dims
    assert 'nlon' in kernel_a1.arguments
    assert 'nlev' in kernel_a1.arguments
    assert 'n' in kernel_a1.arguments or frontend == OMNI

    assert kernel_a2.arguments[0].dimensions == ('nlon', 'nlev')
    assert kernel_a2.arguments[1].dimensions == c_dims
    assert 'nlon' in kernel_a2.arguments
    assert 'nlev' in kernel_a2.arguments
    assert 'n' in kernel_a2.arguments or frontend == OMNI

    assert kernel_b.arguments[0].dimensions == ('nlon', 'nlev')
    assert kernel_b.arguments[1].dimensions == c_dims
    assert 'nlon' in kernel_b.arguments
    assert 'nlev' in kernel_b.arguments
    assert 'n' in kernel_b.arguments or frontend == OMNI

    # And finally, check that scalar dimension size variables have been added to calls
    size_vars = ('nlon', 'nlev') if frontend == OMNI else ('nlon', 'nlev', 'n')
    for v in size_vars:
        assert (v, v) in FindNodes(CallStatement).visit(kernel_a1.body)[0].kwarguments
        assert (v, v) in FindNodes(CallStatement).visit(kernel_a2.body)[0].kwarguments
        assert (v, v) in FindNodes(CallStatement).visit(driver.body)[0].kwarguments
        assert (v, v) in FindNodes(CallStatement).visit(driver.body)[1].kwarguments


@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'OMNI module type definitions not available')]))
def test_argument_shape_transformation_import(frontend, here, tmp_path):
    """
    Test that ensures that explicit argument shapes are indeed inserted
    in a multi-layered call tree driven by the :any:`Scheduler`, where some
    dimension sizes come from an imported module.

    Dimension variables that are imported from a module (such as ``n`` in
    ``kernel_a`` and ``kernel_b``) are expected not to become arguments, while
    scalar sizes that are only known from the caller are added as arguments.
    """

    config = {
         'default': {
             'mode': 'idem',
             'role': 'kernel',
             'expand': True,
             'strict': True
         },
         'routines': {
             'driver': {'role': 'driver'}
         }
    }

    header = [here/'sources/projArgShape/var_module_mod.F90']
    headers = [Sourcefile.from_file(filename=h, frontend=frontend) for h in header]
    definitions = flatten(h.modules for h in headers)
    scheduler = Scheduler(paths=here/'sources/projArgShape', config=config, frontend=frontend,
                          definitions=definitions, xmods=[tmp_path])
    scheduler.process(transformation=ArgumentArrayShapeAnalysis())
    scheduler.process(transformation=ExplicitArgumentArrayShapeTransformation())

    item_map = {item.name: item for item in scheduler.items}
    driver = item_map['driver_mod#driver'].source['driver']
    kernel_a = item_map['kernel_a_mod#kernel_a'].source['kernel_a']
    kernel_a1 = item_map['kernel_a1_mod#kernel_a1'].source['kernel_a1']
    kernel_b = item_map['kernel_b_mod#kernel_b'].source['kernel_b']

    # Check that argument shapes have been applied
    assert kernel_a.arguments[0].dimensions == ('nlon',)
    assert kernel_a.arguments[1].dimensions == ('nlon', 'nlev')
    assert kernel_a.arguments[2].dimensions == ('nlon', 'n')
    assert 'nlon' in kernel_a.arguments
    assert 'nlev' in kernel_a.arguments
    assert 'n' not in kernel_a.arguments

    assert kernel_b.arguments[0].dimensions == ('nlon', 'nlev')
    assert kernel_b.arguments[1].dimensions == ('nlon', 'n')
    assert 'nlon' in kernel_b.arguments
    assert 'nlev' in kernel_b.arguments
    assert 'n' not in kernel_b.arguments

    assert kernel_a1.arguments[0].dimensions == ('nlon', 'nlev')
    assert kernel_a1.arguments[1].dimensions == ('nlon', 'n')
    assert 'nlon' in kernel_a1.arguments
    assert 'nlev' in kernel_a1.arguments
    assert 'n' in kernel_a1.arguments

    # And finally, check that scalar dimension size variables have been added to calls
    for v in ('nlon', 'nlev'):
        assert (v, v) in FindNodes(CallStatement).visit(driver.body)[0].kwarguments
        assert (v, v) in FindNodes(CallStatement).visit(driver.body)[1].kwarguments
    for v in ('nlon', 'nlev', 'n'):
        assert (v, v) in FindNodes(CallStatement).visit(kernel_a.body)[0].kwarguments


@pytest.mark.skipif(not HAVE_FP, reason="Assumed size declarations only supported for FP")
@pytest.mark.parametrize('transform', [True, False])
def test_argument_size_assumed_size(transform):
    """
    Check that assumed-size (``*``) dummy arguments are given concrete
    extents derived from the call site, i.e. that driver-level sizes flow
    into the kernel, and (when ``transform`` is set) that the explicit shape
    transformation writes them back into declarations and call signatures.
    """

    fcode_driver = """
  SUBROUTINE trafo_driver(nlon, nlev, a, b, c, d, e)
    ! Driver routine with explicit array shapes
    INTEGER, INTENT(IN)   :: nlon, nlev  ! Dimension sizes
    INTEGER, PARAMETER    :: n = 5
    REAL, INTENT(INOUT)   :: a(nlon)
    REAL, INTENT(INOUT)   :: b(nlon,nlev,n)
    REAL, INTENT(INOUT)   :: c(nlon,nlev,n)
    REAL, INTENT(INOUT)   :: d(nlon,nlev)
    REAL, INTENT(INOUT)   :: e(2,4,nlon,nlev)

    call trafo_kernel(nlon, a, b, c, d(:,1:2), e)
  END SUBROUTINE trafo_driver
    """

    fcode_kernel = """
  SUBROUTINE trafo_kernel(nlon, a, b, c, d, e)
    ! Kernel routine with implicit shape array arguments
    INTEGER, INTENT(IN)   :: nlon
    REAL, INTENT(INOUT)   :: a(*)
    REAL, INTENT(INOUT)   :: b(nlon,*)
    REAL, INTENT(INOUT)   :: c(nlon,0:*)
    REAL, INTENT(INOUT)   :: d(*)
    REAL, INTENT(INOUT)   :: e(2,4,3:*)

  END SUBROUTINE trafo_kernel
    """

    def check_range(dim, lower, upper):
        # A derived range dimension must carry the expected bounds
        assert isinstance(dim, sym.RangeIndex)
        assert dim.lower == lower
        assert dim.upper == upper

    kernel = Subroutine.from_source(fcode_kernel, frontend=FP)
    driver = Subroutine.from_source(fcode_driver, frontend=FP)
    driver.enrich(kernel)  # Attach kernel source to driver call

    # Before any analysis the callee still reports its assumed-size shapes
    driver_calls = FindNodes(CallStatement).visit(driver.body)
    assert len(driver_calls) == 1 and driver_calls[0].routine
    callee_args = driver_calls[0].routine.arguments
    assert len(callee_args) == 6
    assert callee_args[0] == 'nlon'
    assert callee_args[1].shape == ('*', )
    assert callee_args[2].shape == ('nlon', '*')
    assert callee_args[3].shape == ('nlon', '0:*')
    assert callee_args[4].shape == ('*',)
    assert callee_args[5].shape == (2, 4, '3:*',)

    ArgumentArrayShapeAnalysis().apply(driver, role='driver')

    analysed = kernel.arguments
    assert analysed[1].shape == ('nlon',)
    assert analysed[2].shape == ('nlon', 'nlev * n')

    assert analysed[3].shape[0] == 'nlon'
    check_range(analysed[3].shape[1], 0, '-1 + n*nlev')

    assert analysed[4].shape == ('2*nlon',)

    assert analysed[5].shape[0] == 2
    assert analysed[5].shape[1] == 4
    check_range(analysed[5].shape[2], 3, '2 + nlev*nlon')

    if not transform:
        return

    explicit_trafo = ExplicitArgumentArrayShapeTransformation()
    explicit_trafo.apply(kernel)
    explicit_trafo.apply(driver)

    # The driver-side call now passes the extra size variables by keyword
    driver_calls = FindNodes(CallStatement).visit(driver.body)
    assert len(driver_calls) == 1
    the_call = driver_calls[0]
    assert len(the_call.arguments) + len(the_call.kwarguments) == 8
    first_kw, second_kw = the_call.kwarguments[0][1], the_call.kwarguments[1][1]
    assert first_kw in ['n', 'nlev']
    assert second_kw in ['n', 'nlev']
    assert first_kw != second_kw

    # The kernel gained two integer size arguments
    arguments = kernel.arguments

    assert len(arguments) == 8
    new_args = (arguments[6], arguments[7])
    for new_arg in new_args:
        assert new_arg in ['n', 'nlev']
    assert new_args[0] != new_args[1]
    for new_arg in new_args:
        assert new_arg.type.dtype == BasicType.INTEGER

    # The array declarations now carry the derived explicit dimensions
    assert arguments[1].dimensions == ('nlon',)
    assert arguments[2].dimensions == ('nlon', 'nlev * n')

    assert arguments[3].dimensions[0] == 'nlon'
    check_range(arguments[3].dimensions[1], 0, '-1 + n*nlev')

    assert arguments[4].dimensions == ('2*nlon',)

    assert arguments[5].dimensions[0] == 2
    assert arguments[5].dimensions[1] == 4
    check_range(arguments[5].dimensions[2], 3, '2 + nlev*nlon')
loki-ecmwf-0.3.6/loki/transformations/tests/test_transform_loop.py0000664000175000017500000014243515167130205025760 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# pylint: disable=too-many-lines
import itertools
import pytest
import numpy as np

from loki import Subroutine
from loki.jit_build import jit_compile, clean_test
from loki.frontend import available_frontends
from loki.ir import (
    is_loki_pragma, pragmas_attached, FindNodes, Loop, Conditional,
    Assignment, FindVariables, nodes as ir
)

from loki.transformations.transform_loop import (
    do_loop_interchange, do_loop_fusion, do_loop_fission, do_loop_unroll,
    TransformLoopsTransformation
)


def loop_variables(node):
    """Return the induction variable of every :any:`Loop` in ``node``, in visit order."""
    found_loops = FindNodes(Loop).visit(node)
    return [lp.variable for lp in found_loops]


def loop_symbols(node):
    """
    Return ``(variable, start, stop, step)`` for every :any:`Loop` in ``node``,
    in visit order.
    """
    symbols = []
    for lp in FindNodes(Loop).visit(node):
        bounds = lp.bounds
        symbols.append((lp.variable, bounds.start, bounds.stop, bounds.step))
    return symbols


def assignment_symbols(node):
    """Return ``(lhs, rhs)`` for every :any:`Assignment` in ``node``, in visit order."""
    assignments = FindNodes(Assignment).visit(node)
    return [(stmt.lhs, stmt.rhs) for stmt in assignments]


def conditional_strs(node):
    """
    Return the condition expression of every :any:`Conditional` in ``node``.

    The expressions themselves are returned (not converted to ``str``); the
    tests compare them against strings.
    """
    conditionals = FindNodes(Conditional).visit(node)
    return [branch.condition for branch in conditionals]


def pragma_strs(node):
    """Return the ``content`` of every :any:`Pragma` node in ``node``, in visit order."""
    pragma_nodes = FindNodes(ir.Pragma).visit(node)
    return [p.content for p in pragma_nodes]


def variable_shape(routine, name):
    """
    Look up ``name`` in ``routine.variable_map`` and return its shape as a tuple.

    Returns ``None`` when the variable has no (or an empty) shape.
    """
    shape = routine.variable_map[name].shape
    if not shape:
        return None
    return tuple(shape)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_interchange_plain(frontend):
    """
    Interchange a two-deep loop nest marked with ``loop-interchange`` (no
    explicit order) and check that an unmarked nest is left untouched.
    """
    fcode = """
subroutine transform_loop_interchange_plain(a, m, n)
  integer, intent(out) :: a(m, n)
  integer, intent(in) :: m, n
  integer :: i, j

  !$loki loop-interchange
  do i=1,n
    do j=1,m
      a(j, i) = i + j
    end do
  end do

  ! This loop is to make sure everything else stays as is
  do i=1,n
    do j=1,m
      a(j, i) = a(j, i) - 2
    end do
  end do
end subroutine transform_loop_interchange_plain
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    expected_assignments = [
        ('a(j, i)', 'i + j'), ('a(j, i)', 'a(j, i) - 2')
    ]

    original_body = routine.body
    assert len(FindNodes(Loop).visit(original_body)) == 4
    assert loop_variables(original_body) == ['i', 'j', 'i', 'j']
    assert assignment_symbols(original_body) == expected_assignments

    do_loop_interchange(routine)

    # Only the annotated nest has its loops swapped
    new_body = routine.body
    assert len(FindNodes(Loop).visit(new_body)) == 4
    assert loop_variables(new_body) == ['j', 'i', 'i', 'j']
    assert loop_symbols(new_body) == [
        ('j', 1, 'm', None), ('i', 1, 'n', None),
        ('i', 1, 'n', None), ('j', 1, 'm', None)
    ]
    assert assignment_symbols(new_body) == expected_assignments


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_interchange(frontend):
    """
    Interchange a three-deep loop nest into an explicitly requested order
    ``(j, i, k)``. The check covers the loop structure and that the
    per-loop pragmas stay in their original positions.
    """
    fcode = """
subroutine transform_loop_interchange(a, m, n, nclv)
  integer, intent(out) :: a(m, n, nclv)
  integer, intent(in) :: m, n, nclv
  integer :: i, j, k

!$loki loop-interchange (j, i, k)
!$loki some-pragma
  do k=1,nclv
!$loki more-pragma
    do i=1,n
!$loki other-pragma
      do j=1,m
        a(j, i, k) = i + j + k
      end do
    end do
  end do

  ! This loop is to make sure everything else stays as is
  do k=1,nclv
    do i=1,n
      do j=1,m
        a(j, i, k) = a(j, i, k) - 3
      end do
    end do
  end do
end subroutine transform_loop_interchange
    """
    expected_assignments = [
        ('a(j, i, k)', 'i + j + k'), ('a(j, i, k)', 'a(j, i, k) - 3')
    ]
    pragma_order = ('some-pragma', 'more-pragma', 'other-pragma')

    def check_pragmas(loop_nodes):
        # The first three loops must carry the pragmas in their original order
        with pragmas_attached(routine, Loop):
            for loop_node, keyword in zip(loop_nodes, pragma_order):
                assert is_loki_pragma(loop_node.pragma, starts_with=keyword)

    routine = Subroutine.from_source(fcode, frontend=frontend)

    before = FindNodes(Loop).visit(routine.body)
    assert len(before) == 6
    assert loop_variables(routine.body) == ['k', 'i', 'j', 'k', 'i', 'j']
    assert assignment_symbols(routine.body) == expected_assignments
    check_pragmas(before)

    do_loop_interchange(routine)

    after = FindNodes(Loop).visit(routine.body)
    assert len(after) == 6
    assert loop_variables(routine.body) == ['j', 'i', 'k', 'k', 'i', 'j']
    assert loop_symbols(routine.body) == [
        ('j', 1, 'm', None), ('i', 1, 'n', None), ('k', 1, 'nclv', None),
        ('k', 1, 'nclv', None), ('i', 1, 'n', None), ('j', 1, 'm', None)
    ]
    assert assignment_symbols(routine.body) == expected_assignments
    check_pragmas(after)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_interchange_project(frontend):
    """
    Interchange a triangular two-deep nest with ``project_bounds=True`` so
    that the new inner loop bound is projected from the old bounds.
    """
    fcode = """
subroutine transform_loop_interchange_project(a, m, n)
  integer, intent(inout) :: a(m, n)
  integer, intent(in) :: m, n
  integer :: i, j

  !$loki loop-interchange
  do i=1,n
    do j=i,m
      a(j, i) = i + j
    end do
  end do
end subroutine transform_loop_interchange_project
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    expected_assignments = [('a(j, i)', 'i + j')]

    body = routine.body
    assert len(FindNodes(Loop).visit(body)) == 2
    assert loop_variables(body) == ['i', 'j']
    assert loop_symbols(body) == [('i', 1, 'n', None), ('j', 'i', 'm', None)]
    assert assignment_symbols(body) == expected_assignments

    do_loop_interchange(routine, project_bounds=True)

    body = routine.body
    assert len(FindNodes(Loop).visit(body)) == 2
    assert loop_variables(body) == ['j', 'i']
    assert loop_symbols(body) == [('j', 1, 'm', None), ('i', 1, 'min(n, j)', None)]
    assert assignment_symbols(body) == expected_assignments


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('insert_loc', (False, True))
def test_transform_loop_fuse_ordering(frontend, insert_loc):
    """
    Apply loop interchange followed by loop fusion to a fusion group whose
    members are separated by an unrelated loop, and check where the fused
    loop is placed.

    Without ``insert-loc``, the fused ``i`` loop is expected to be the first
    loop in the routine (and to not touch ``c``). With ``insert-loc`` on the
    last group member, the fused loop is expected at that member's position,
    so the first loop is the intervening ``j`` loop that writes ``c``.
    """
    fcode = f"""
subroutine transform_loop_fuse_ordering(a, b, c, n, m)
  integer, intent(out) :: a(m, n), b(m, n), c(m)
  integer, intent(in) :: n, m
  integer :: i

  !$loki loop-fusion group(1)
  !$loki loop-interchange
  do j=1,m
    do i=1,n
      a(j, i) = i + j
    enddo
  end do

  !$loki loop-fusion group(1)
  do i=1,n
    do j=1,m
      a(j, i) = i + j
    enddo
  end do

  do j=1,m
    c(j) = j
  enddo

  !$loki loop-fusion group(1) {'insert-loc' if insert_loc else ''}
  do i=1,n-1
    do j=1,m
      b(j, i) = n-i+1 + j
    enddo
  end do
end subroutine transform_loop_fuse_ordering
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Loop).visit(routine.body)) == 7
    do_loop_interchange(routine)
    do_loop_fusion(routine)

    # Three two-deep nests fuse into one outer loop holding three inner
    # loops, next to the untouched ``c`` loop: 4 + 1 loops remain
    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 5

    first_loop = loops[0]
    first_loop_vars = [var.name.lower() for var in FindVariables().visit(first_loop.body)]
    expected_variable = 'j' if insert_loc else 'i'
    assert first_loop.variable.name.lower() == expected_variable
    assert ('c' in first_loop_vars) == insert_loc

@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_matching(frontend):
    """
    Fuse two loops that share the exact same iteration space ``1..n`` into
    a single loop tagged with the default fusion group.
    """
    fcode = """
subroutine transform_loop_fuse_matching(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: i

  !$loki loop-fusion
  do i=1,n
    a(i) = i
  end do

  !$loki loop-fusion
  do i=1,n
    b(i) = n-i+1
  end do
end subroutine transform_loop_fuse_matching
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Loop).visit(routine.body)) == 2

    do_loop_fusion(routine)

    fused_body = routine.body
    expected_assignments = [('a(i)', 'i'), ('b(i)', 'n - i + 1')]
    assert loop_symbols(fused_body) == [('i', 1, 'n', None)]
    assert assignment_symbols(fused_body) == expected_assignments
    assert pragma_strs(fused_body) == ['fused-loop group(default)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_subranges(frontend):
    """
    Fuse loops with differing iteration spaces using an explicit
    ``range(1:n)`` annotation; each sub-range body ends up guarded by a
    conditional inside the single fused loop.
    """
    fcode = """
subroutine transform_loop_fuse_subranges(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: i, j

  a(:) = 0
  b(:) = 0

  !$loki loop-fusion
  do i=1,n
    a(i) = a(i) + i
  end do

  !$loki loop-fusion range(1:n)
  do j=1,15
    b(j) = b(j) + n-j+1
  end do

  !$loki loop-fusion range(1:n)
  do i=16,n
    b(i) = b(i) + n-i+1
  end do
end subroutine transform_loop_fuse_subranges
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Loop).visit(routine.body)) == 3

    do_loop_fusion(routine)

    fused_body = routine.body
    expected_assignments = [
        ('a(:)', '0'), ('b(:)', '0'), ('a(i)', 'a(i) + i'),
        ('b(i)', 'b(i) + n - i + 1'), ('b(i)', 'b(i) + n - i + 1')
    ]
    assert loop_symbols(fused_body) == [('i', 1, 'n', None)]
    assert conditional_strs(fused_body) == ['i <= 15', 'i >= 16']
    assert assignment_symbols(fused_body) == expected_assignments
    assert pragma_strs(fused_body) == ['fused-loop group(default)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_groups(frontend):
    """
    Fuse loops belonging to two independent fusion groups (``g1`` and
    ``loop-group2``); each group produces its own fused loop and pragma.
    """
    fcode = """
subroutine transform_loop_fuse_groups(a, b, c, n)
  integer, intent(out) :: a(n), b(n), c(n)
  integer, intent(in) :: n
  integer :: i

  c(1) = 1

  !$loki loop-fusion group(g1)
  do i=1,n
    a(i) = i
  end do

  !$loki loop-fusion group(g1)
  do i=1,n
    b(i) = n-i+1
  end do

  !$loki loop-fusion group(loop-group2)
  do i=1,n
    a(i) = a(i) + 1
  end do

  !$loki loop-fusion group(loop-group2)
  do i=1,n
    b(i) = b(i) + 1
  end do

  !$loki loop-fusion group(g1) range(1:n)
  do i=2,n
    c(i) = c(i-1) + 1
  end do
end subroutine transform_loop_fuse_groups
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Loop).visit(routine.body)) == 5

    do_loop_fusion(routine)

    fused_body = routine.body
    expected_loops = [('i', 1, 'n', None)] * 2
    expected_assignments = [
        ('c(1)', '1'), ('a(i)', 'i'), ('b(i)', 'n - i + 1'),
        ('c(i)', 'c(i - 1) + 1'), ('a(i)', 'a(i) + 1'), ('b(i)', 'b(i) + 1')
    ]
    expected_pragmas = ['fused-loop group(g1)', 'fused-loop group(loop-group2)']

    assert loop_symbols(fused_body) == expected_loops
    assert conditional_strs(fused_body) == ['i >= 2']
    assert assignment_symbols(fused_body) == expected_assignments
    assert pragma_strs(fused_body) == expected_pragmas


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_failures(frontend):
    """
    Loop fusion must raise ``RuntimeError`` when members of the same group
    declare conflicting ``range`` annotations.
    """
    fcode = """
subroutine transform_loop_fuse_failures(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: i

  !$loki loop-fusion group(1) range(1:n)
  do i=1,n
    a(i) = i
  end do

  !$loki loop-fusion group(1) range(0:n-1)
  do i=0,n-1
    b(i+1) = n-i
  end do
end subroutine transform_loop_fuse_failures
"""
    conflicting = Subroutine.from_source(fcode, frontend=frontend)
    with pytest.raises(RuntimeError):
        do_loop_fusion(conflicting)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_alignment(frontend):
    """
    Fuse two loops with offset iteration spaces (``1..n`` and ``0..n-1``)
    and no range annotation: the fused loop covers ``0..n`` and each body is
    guarded by a conditional restricting it to its original range.
    """
    fcode = """
subroutine transform_loop_fuse_alignment(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: i

  !$loki loop-fusion group(1)
  do i=1,n
    a(i) = i
  end do

  !$loki loop-fusion group(1)
  do i=0,n-1
    b(i+1) = n-i
  end do
end subroutine transform_loop_fuse_alignment
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Loop).visit(routine.body)) == 2

    do_loop_fusion(routine)

    fused_body = routine.body
    assert loop_symbols(fused_body) == [('i', 0, 'n', None)]
    assert conditional_strs(fused_body) == ['i >= 1', 'i <= n - 1']
    assert assignment_symbols(fused_body) == [('a(i)', 'i'), ('b(i + 1)', 'n - i')]
    assert pragma_strs(fused_body) == ['fused-loop group(1)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_nonmatching_lower(frontend):
    """
    Fuse two loops whose lower bounds differ symbolically (``1`` vs ``nclv``)
    without a range annotation: the fused loop starts at ``min(1, nclv)``
    and both bodies are guarded by their own lower bound.
    """
    fcode = """
subroutine transform_loop_fuse_nonmatching_lower(a, b, nclv, klev)
  integer, intent(out) :: a(klev), b(klev)
  integer, intent(in) :: nclv, klev
  integer :: jl

  !$loki loop-fusion group(1)
  do jl=1,klev
    a(jl) = jl
  end do

  !$loki loop-fusion group(1)
  do jl=nclv,klev
    b(jl) = jl - nclv
  end do
end subroutine transform_loop_fuse_nonmatching_lower
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Loop).visit(routine.body)) == 2

    do_loop_fusion(routine)

    fused_body = routine.body
    assert loop_symbols(fused_body) == [('jl', 'min(1, nclv)', 'klev', None)]
    assert conditional_strs(fused_body) == ['jl >= 1', 'jl >= nclv']
    assert assignment_symbols(fused_body) == [('a(jl)', 'jl'), ('b(jl)', 'jl - nclv')]
    assert pragma_strs(fused_body) == ['fused-loop group(1)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_nonmatching_lower_annotated(frontend):
    """Fuse loops with different lower bounds when one carries ``range(1:klev)``.

    The annotation fixes the fused iteration space to ``1..klev``, so only the
    second body needs a ``jl >= nclv`` guard.
    """
    fortran = """
subroutine transform_loop_fuse_nonmatching_lower_annotated(a, b, nclv, klev)
  integer, intent(out) :: a(klev), b(klev)
  integer, intent(in) :: nclv, klev
  integer :: jl

  !$loki loop-fusion group(1)
  do jl=1,klev
    a(jl) = jl
  end do

  !$loki loop-fusion group(1) range(1:klev)
  do jl=nclv,klev
    b(jl) = jl - nclv
  end do
end subroutine transform_loop_fuse_nonmatching_lower_annotated
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 2

    do_loop_fusion(kernel)

    expected_loops = [('jl', 1, 'klev', None)]
    expected_assigns = [('a(jl)', 'jl'), ('b(jl)', 'jl - nclv')]
    assert loop_symbols(kernel.body) == expected_loops
    assert conditional_strs(kernel.body) == ['jl >= nclv']
    assert assignment_symbols(kernel.body) == expected_assigns
    assert pragma_strs(kernel.body) == ['fused-loop group(1)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_nonmatching_upper(frontend):
    """Fuse loops with different upper bounds (``klev`` and ``klev+1``).

    The fused loop must run to ``1 + klev``, and the shorter body must be guarded by
    ``jl <= klev``.
    """
    fortran = """
subroutine transform_loop_fuse_nonmatching_upper(a, b, klev)
  integer, intent(out) :: a(klev), b(klev+1)
  integer, intent(in) :: klev
  integer :: jl

  !$loki loop-fusion group(1)
  do jl=1,klev
    a(jl) = jl
  end do

  !$loki loop-fusion group(1)
  do jl=1,klev+1
    b(jl) = 2*jl
  end do
end subroutine transform_loop_fuse_nonmatching_upper
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 2

    do_loop_fusion(kernel)

    expected_loops = [('jl', 1, '1 + klev', None)]
    expected_assigns = [('a(jl)', 'jl'), ('b(jl)', '2*jl')]
    assert loop_symbols(kernel.body) == expected_loops
    assert conditional_strs(kernel.body) == ['jl <= klev']
    assert assignment_symbols(kernel.body) == expected_assigns
    assert pragma_strs(kernel.body) == ['fused-loop group(1)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_collapse(frontend):
    """Fuse two identical ``collapse(2)`` loop nests into a single two-level nest.

    The pragmas name no group, so the result is tagged ``group(default)``.
    """
    fortran = """
subroutine transform_loop_fuse_collapse(a, b, klon, klev)
  integer, intent(inout) :: a(klon, klev), b(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl

!$loki loop-fusion collapse(2)
  do jk=1,klev
    do jl=1,klon
      a(jl, jk) = jk
    end do
  end do

!$loki loop-fusion collapse(2)
  do jk=1,klev
    do jl=1,klon
      b(jl, jk) = jl + jk
    end do
  end do
end subroutine transform_loop_fuse_collapse
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 4

    do_loop_fusion(kernel)

    expected_loops = [('jk', 1, 'klev', None), ('jl', 1, 'klon', None)]
    expected_assigns = [('a(jl, jk)', 'jk'), ('b(jl, jk)', 'jl + jk')]
    assert loop_symbols(kernel.body) == expected_loops
    assert assignment_symbols(kernel.body) == expected_assigns
    assert pragma_strs(kernel.body) == ['fused-loop group(default)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_collapse_nonmatching(frontend):
    """Fuse ``collapse(2)`` nests whose bounds differ at each nesting level.

    Each fused level must run to the larger bound, and each body must be guarded
    wherever its own bound is smaller.
    """
    fortran = """
subroutine transform_loop_fuse_collapse_nonmatching(a, b, klon, klev)
  integer, intent(inout) :: a(klon, klev+1), b(klon+1, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl

!$loki loop-fusion collapse(2)
  do jk=1,klev+1
    do jl=1,klon
      a(jl, jk) = jk
    end do
  end do

!$loki loop-fusion collapse(2)
  do jk=1,klev
    do jl=1,klon+1
      b(jl, jk) = jl + jk
    end do
  end do
end subroutine transform_loop_fuse_collapse_nonmatching
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 4

    do_loop_fusion(kernel)

    expected_loops = [('jk', 1, '1 + klev', None), ('jl', 1, '1 + klon', None)]
    expected_guards = ['jl <= klon', 'jk <= klev']
    expected_assigns = [('a(jl, jk)', 'jk'), ('b(jl, jk)', 'jl + jk')]
    assert loop_symbols(kernel.body) == expected_loops
    assert conditional_strs(kernel.body) == expected_guards
    assert assignment_symbols(kernel.body) == expected_assigns
    assert pragma_strs(kernel.body) == ['fused-loop group(default)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fuse_collapse_range(tmp_path, frontend):
    """Fuse ``collapse(2)`` nests when the second nest carries an explicit ``range``.

    The routine is JIT-compiled before and after fusion, and both builds must produce
    the same ``a`` and ``b``. ``b`` is checked only from its 15th column onwards,
    because its outer loop starts at ``start = 15``.
    """
    fortran = """
subroutine transform_loop_fuse_collapse_range(a, b, klon, klev)
  integer, intent(inout) :: a(klon, klev+1), b(klon+1, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl, start = 15

!$loki loop-fusion collapse(2)
  do jk=1,klev+1
    do jl=1,klon
      a(jl, jk) = jk
    end do
  end do

!$loki loop-fusion collapse(2) range(1:1+klev,1:klon+1)
  do jk=start,klev
    do jl=1,klon+1
      b(jl, jk) = jl + jk
    end do
  end do
end subroutine transform_loop_fuse_collapse_range
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)
    klon, klev = 32, 100

    def run_and_check(compiled):
        # Run one compiled build on fresh zeroed buffers and verify its output
        a = np.zeros(shape=(klon, klev+1), order='F', dtype=np.int32)
        b = np.zeros(shape=(klon+1, klev), order='F', dtype=np.int32)
        compiled(a=a, b=b, klon=klon, klev=klev)
        expected_a = np.array([list(range(1, klev+2))] * klon, order='F')
        expected_b = np.array([[jl + jk for jk in range(15, klev+1)]
                               for jl in range(1, klon+2)], order='F')
        assert np.all(a == expected_a)
        assert np.all(b[..., 14:] == expected_b)

    # Reference build of the untransformed routine
    ref_path = tmp_path/(f'{kernel.name}_{frontend}.f90')
    reference = jit_compile(kernel, filepath=ref_path, objname=kernel.name)
    run_and_check(reference)

    # Fuse the two collapsed nests and inspect the resulting structure
    assert len(FindNodes(Loop).visit(kernel.body)) == 4
    do_loop_fusion(kernel)
    fused_loops = FindNodes(Loop).visit(kernel.body)
    assert len(fused_loops) == 2
    assert all(lp.bounds.start == 1 for lp in fused_loops)
    assert sum(lp.bounds.stop == '1 + klev' for lp in fused_loops) == 1
    assert sum(lp.bounds.stop == 'klon + 1' for lp in fused_loops) == 1
    assert len(FindNodes(Conditional).visit(kernel.body)) == 2

    # Fused build must reproduce the reference results
    fused_path = tmp_path/(f'{kernel.name}_fused_{frontend}.f90')
    fused = jit_compile(kernel, filepath=fused_path, objname=kernel.name)
    run_and_check(fused)

    clean_test(ref_path)
    clean_test(fused_path)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_single(frontend):
    """Split one loop into two identical-bound loops with a single ``loop-fission`` pragma."""
    fortran = """
subroutine transform_loop_fission_single(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: j

  do j=1,n
    a(j) = j
    !$loki loop-fission
    b(j) = n-j
  end do
end subroutine transform_loop_fission_single
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 1

    do_loop_fission(kernel)

    assert loop_symbols(kernel.body) == [('j', 1, 'n', None)] * 2
    assert assignment_symbols(kernel.body) == [('a(j)', 'j'), ('b(j)', 'n - j')]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_nested(frontend):
    """Split a loop at a pragma inside a conditional.

    The enclosing ``if (j <= n)`` must be duplicated into both resulting loops.
    """
    fortran = """
subroutine transform_loop_fission_nested(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: j, k

  do j=1,n+1
    if (j <= n) then
      a(j) = j
!$loki loop-fission
      b(j) = n-j
    end if
  end do
end subroutine transform_loop_fission_nested
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    assert len(FindNodes(Loop).visit(kernel.body)) == 1
    assert len(FindNodes(Conditional).visit(kernel.body)) == 1

    do_loop_fission(kernel)

    assert loop_symbols(kernel.body) == [('j', 1, 'n + 1', None)] * 2
    assert conditional_strs(kernel.body) == ['j <= n'] * 2
    assert assignment_symbols(kernel.body) == [('a(j)', 'j'), ('b(j)', 'n - j')]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_nested_promote(frontend):
    """Split a loop with ``promote(zqxfg)``.

    ``zqxfg`` must gain a trailing dimension of ``1 + n`` and every use of it must
    be indexed by the loop variable ``j``.
    """
    fortran = """
subroutine transform_loop_fission_nested_promote(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: j, k, zqxfg(5)

  do j=1,n+1
    zqxfg(2) = j
!$loki loop-fission promote(zqxfg)
    if (j <= n) then
      if (zqxfg(2) <= n) then
        a(j) = zqxfg(2)
        b(j) = n-zqxfg(2)
      end if
    end if
  end do
end subroutine transform_loop_fission_nested_promote
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    assert len(FindNodes(Loop).visit(kernel.body)) == 1
    assert len(FindNodes(Conditional).visit(kernel.body)) == 2
    assert len(FindNodes(Assignment).visit(kernel.body)) == 3

    do_loop_fission(kernel)

    expected_assigns = [
        ('zqxfg(2, j)', 'j'),
        ('a(j)', 'zqxfg(2, j)'),
        ('b(j)', 'n - zqxfg(2, j)'),
    ]
    assert loop_symbols(kernel.body) == [('j', 1, 'n + 1', None)] * 2
    assert conditional_strs(kernel.body) == ['j <= n', 'zqxfg(2, j) <= n']
    assert assignment_symbols(kernel.body) == expected_assigns
    assert variable_shape(kernel, 'zqxfg') == ('5', '1 + n')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_collapse(frontend):
    """Apply plain and ``collapse(2)`` fission pragmas to a two-level loop nest.

    The promoted temporaries ``tmp`` and ``tmp2`` must gain the loop dimensions
    they are carried across (``1 + n`` and ``(n, 1 + n)`` respectively).
    """
    fortran = """
subroutine transform_loop_fission_collapse(a, n)
  integer, intent(out) :: a(n, n+1)
  integer, intent(in) :: n
  integer :: j, k, tmp, tmp2

  tmp = 0
  do j=1,n+1
    tmp = j
    tmp2 = 0
!$loki loop-fission promote(tmp)
    do k=1,n
      tmp2 = tmp + k
!$loki loop-fission collapse(2) promote(tmp2)
      a(k, j) = tmp2
!$loki loop-fission
      a(k, j) = a(k, j) - 1
!$loki loop-fission collapse(2)
      a(k, j) = -1 + a(k, j)
    end do
    tmp = 0
  end do
end subroutine transform_loop_fission_collapse
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    assert len(FindNodes(Loop).visit(kernel.body)) == 2
    assert len(FindNodes(Assignment).visit(kernel.body)) == 8

    do_loop_fission(kernel)

    outer = ('j', 1, 'n + 1', None)
    inner = ('k', 1, 'n', None)
    expected_loops = [outer, outer, inner, outer, inner, inner, outer, inner]
    expected_assigns = [
        ('tmp(:)', '0'),
        ('tmp(j)', 'j'),
        ('tmp2(:, j)', '0'),
        ('tmp2(k, j)', 'tmp(j) + k'),
        ('a(k, j)', 'tmp2(k, j)'),
        ('a(k, j)', 'a(k, j) - 1'),
        ('a(k, j)', '-1 + a(k, j)'),
        ('tmp(j)', '0'),
    ]
    assert loop_symbols(kernel.body) == expected_loops
    assert assignment_symbols(kernel.body) == expected_assigns
    assert variable_shape(kernel, 'tmp') == ('1 + n',)
    assert variable_shape(kernel, 'tmp2') == ('n', '1 + n')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_multiple(frontend):
    """Split one loop into three with two ``loop-fission`` pragmas."""
    fortran = """
subroutine transform_loop_fission_multiple(a, b, c, n)
  integer, intent(out) :: a(n), b(n), c(n)
  integer, intent(in) :: n
  integer :: j

  do j=1,n
    a(j) = j
    !$loki loop-fission
    b(j) = n-j
    !$loki loop-fission
    c(j) = a(j) + b(j)
  end do
end subroutine transform_loop_fission_multiple
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 1

    do_loop_fission(kernel)

    expected_assigns = [('a(j)', 'j'), ('b(j)', 'n - j'), ('c(j)', 'a(j) + b(j)')]
    assert loop_symbols(kernel.body) == [('j', 1, 'n', None)] * 3
    assert assignment_symbols(kernel.body) == expected_assigns


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_promote(tmp_path, frontend):
    """Split a loop with ``promote(tmp)`` and compare compiled results.

    ``tmp`` must become an array of shape ``(n,)``. The JIT-compiled routine must
    produce the same ``a`` and ``b`` before and after fission.
    """
    fortran = """
subroutine transform_loop_fission_promote(a, b, n)
  integer, intent(out) :: a(n), b(n)
  integer, intent(in) :: n
  integer :: j, tmp

  do j=1,n
    a(j) = j
    tmp = j - 1
    !$loki loop-fission promote(tmp)
    b(j) = n-tmp
  end do
end subroutine transform_loop_fission_promote
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)
    n = 100

    def run_and_check(compiled):
        # Fresh zeroed buffers, then compare against the closed-form results
        a = np.zeros(shape=(n,), dtype=np.int32)
        b = np.zeros(shape=(n,), dtype=np.int32)
        compiled(a=a, b=b, n=n)
        assert np.all(a == range(1,n+1))
        assert np.all(b == range(n, 0, -1))

    # Reference build of the untransformed routine
    ref_path = tmp_path/(f'{kernel.name}_{frontend}.f90')
    reference = jit_compile(kernel, filepath=ref_path, objname=kernel.name)
    run_and_check(reference)

    # Split the loop and inspect the result
    assert len(FindNodes(Loop).visit(kernel.body)) == 1
    do_loop_fission(kernel)

    split_loops = FindNodes(Loop).visit(kernel.body)
    assert len(split_loops) == 2
    for lp in split_loops:
        assert lp.bounds.start == 1
        assert lp.bounds.stop == 'n'
    assert kernel.variable_map['tmp'].shape == ('n',)

    # Fissioned build must reproduce the reference results
    split_path = tmp_path/(f'{kernel.name}_fissioned_{frontend}.f90')
    fissioned = jit_compile(kernel, filepath=split_path, objname=kernel.name)
    run_and_check(fissioned)

    clean_test(ref_path)
    clean_test(split_path)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_promote_conflicting_lengths(frontend):
    """Promote the same variable in two loops with different trip counts.

    The loops run to ``n`` and ``n + 1``; ``tmp`` must be sized for the larger one,
    ``1 + n``.
    """
    fortran = """
subroutine transform_loop_fission_promote_conflicting_lengths(a, b, n)
  integer, intent(out) :: a(n), b(n+1)
  integer, intent(in) :: n
  integer :: j, tmp

  do j=1,n
    tmp = j - 1
    !$loki loop-fission promote(tmp)
    a(j) = tmp + 1
  end do

  do j=1,n+1
    tmp = j - 1
    !$loki loop-fission promote(tmp)
    b(j) = n-tmp
  end do
end subroutine transform_loop_fission_promote_conflicting_lengths
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 2

    do_loop_fission(kernel)

    expected_loops = [('j', 1, 'n', None)] * 2 + [('j', 1, 'n + 1', None)] * 2
    expected_assigns = [
        ('tmp(j)', 'j - 1'),
        ('a(j)', 'tmp(j) + 1'),
        ('tmp(j)', 'j - 1'),
        ('b(j)', 'n - tmp(j)'),
    ]
    assert loop_symbols(kernel.body) == expected_loops
    assert assignment_symbols(kernel.body) == expected_assigns
    assert variable_shape(kernel, 'tmp') == ('1 + n',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_promote_array(frontend):
    """Promote an array temporary across a fission point.

    The pragma names the variable in upper case (``ZSUPSAT``); it must still promote
    ``zsupsat`` by the outer loop dimension ``klev``.
    """
    fortran = """
subroutine transform_loop_fission_promote_array(a, klon, klev)
  integer, intent(inout) :: a(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl, zsupsat(klon)

  do jk=1,klev
    zsupsat(:) = 0
    do jl=1,klon
        zsupsat(jl) = jl
    end do
    !$loki loop-fission promote(ZSUPSAT)
    a(:, jk) = zsupsat(:)
  end do
end subroutine transform_loop_fission_promote_array
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 2

    do_loop_fission(kernel)

    outer = ('jk', 1, 'klev', None)
    expected_loops = [outer, ('jl', 1, 'klon', None), outer]
    expected_assigns = [
        ('zsupsat(:, jk)', '0'),
        ('zsupsat(jl, jk)', 'jl'),
        ('a(:, jk)', 'zsupsat(:, jk)'),
    ]
    assert loop_symbols(kernel.body) == expected_loops
    assert assignment_symbols(kernel.body) == expected_assigns
    assert variable_shape(kernel, 'zsupsat') == ('klon', 'klev')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_promote_multiple(frontend):
    """Promote several variables listed in one ``promote(...)`` clause.

    Both ``zsupsat`` and the scalar ``tmp`` must gain the outer dimension ``klev``.
    """
    fortran = """
subroutine transform_loop_fission_promote_multiple(a, klon, klev)
  integer, intent(inout) :: a(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl, zsupsat(klon), tmp

  do jk=1,klev
    zsupsat(:) = 0
    do jl=1,klon
        zsupsat(jl) = jl
    end do
    tmp = jk
    !$loki loop-fission promote(ZSUPSAT, tmp)
    a(:, jk) = zsupsat(:) + tmp
  end do
end subroutine transform_loop_fission_promote_multiple
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 2

    do_loop_fission(kernel)

    outer = ('jk', 1, 'klev', None)
    expected_loops = [outer, ('jl', 1, 'klon', None), outer]
    expected_assigns = [
        ('zsupsat(:, jk)', '0'),
        ('zsupsat(jl, jk)', 'jl'),
        ('tmp(jk)', 'jk'),
        ('a(:, jk)', 'zsupsat(:, jk) + tmp(jk)'),
    ]
    assert loop_symbols(kernel.body) == expected_loops
    assert assignment_symbols(kernel.body) == expected_assigns
    assert variable_shape(kernel, 'zsupsat') == ('klon', 'klev')
    assert variable_shape(kernel, 'tmp') == ('klev',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_multiple_promote(frontend):
    """Split one loop at three points, each with its own promotion list.

    The pragma arguments vary in case and padding (``ZSUPSAT``, `` zQxN ``). Both
    arrays must be promoted by the outer dimension ``klev``.
    """
    fortran = """
subroutine transform_loop_fission_multiple_promote(a, b, klon, klev, nclv)
  integer, intent(inout) :: a(klon, klev), b(klon, klev, nclv)
  integer, intent(in) :: klon, klev, nclv
  integer :: jm, jk, jl, zsupsat(klon), zqxn(klon, nclv)

  do jk=1,klev
    zsupsat(:) = 0
    do jl=1,klon
        zsupsat(jl) = jl
    end do
    !$loki loop-fission
    do jm=1,nclv
        do jl=1,klon
            zqxn(jl, jm) = jm+jl
        end do
    end do
    !$loki loop-fission promote(ZSUPSAT)
    a(:, jk) = zsupsat(:)
    !$loki loop-fission promote( zQxN )
    do jm=1,nclv
        b(:, jk, jm) = zqxn(:, jm)
    end do
  end do
end subroutine transform_loop_fission_multiple_promote
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 5

    do_loop_fission(kernel)

    jk_loop = ('jk', 1, 'klev', None)
    jl_loop = ('jl', 1, 'klon', None)
    jm_loop = ('jm', 1, 'nclv', None)
    expected_loops = [jk_loop, jl_loop, jk_loop, jm_loop, jl_loop, jk_loop, jk_loop, jm_loop]
    expected_assigns = [
        ('zsupsat(:, jk)', '0'),
        ('zsupsat(jl, jk)', 'jl'),
        ('zqxn(jl, jm, jk)', 'jm + jl'),
        ('a(:, jk)', 'zsupsat(:, jk)'),
        ('b(:, jk, jm)', 'zqxn(:, jm, jk)'),
    ]
    assert loop_symbols(kernel.body) == expected_loops
    assert assignment_symbols(kernel.body) == expected_assigns
    assert variable_shape(kernel, 'zsupsat') == ('klon', 'klev')
    assert variable_shape(kernel, 'zqxn') == ('klon', 'nclv', 'klev')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_promote_read_after_write(tmp_path, frontend):
    """Split a loop with no ``promote`` clause and check the temporaries are promoted anyway.

    ``zsupsat`` and ``tmp`` are written before the split and read after it; both
    must gain the ``klev`` dimension. The JIT-compiled routine must produce the same
    ``a`` before and after fission.
    """
    fortran = """
subroutine transform_loop_fission_promote_read_after_write(a, klon, klev)
  integer, intent(inout) :: a(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl, zsupsat(klon), tmp

  do jk=1,klev
    zsupsat(:) = 0
    do jl=1,klon
        zsupsat(jl) = jl
    end do
    tmp = jk
    !$loki loop-fission
    a(:, jk) = zsupsat(:) + tmp
  end do
end subroutine transform_loop_fission_promote_read_after_write
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)
    klon, klev = 32, 100

    def run_and_check(compiled):
        # Fresh zeroed buffer, then compare against a(jl, jk) == jl + jk
        a = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
        compiled(a=a, klon=klon, klev=klev)
        expected = np.array([[jl + jk for jk in range(1, klev+1)]
                             for jl in range(1, klon+1)], order='F')
        assert np.all(a == expected)

    # Reference build of the untransformed routine
    ref_path = tmp_path/(f'{kernel.name}_{frontend}.f90')
    reference = jit_compile(kernel, filepath=ref_path, objname=kernel.name)
    run_and_check(reference)

    # Split the loop and inspect the result
    assert len(FindNodes(Loop).visit(kernel.body)) == 2
    do_loop_fission(kernel)

    split_loops = FindNodes(Loop).visit(kernel.body)
    assert len(split_loops) == 3
    assert all(lp.bounds.start == 1 for lp in split_loops)
    assert sum(lp.bounds.stop == 'klev' for lp in split_loops) == 2
    assert kernel.variable_map['zsupsat'].shape == ('klon', 'klev')
    assert kernel.variable_map['tmp'].shape == ('klev',)

    # Fissioned build must reproduce the reference results
    split_path = tmp_path/(f'{kernel.name}_fissioned_{frontend}.f90')
    fissioned = jit_compile(kernel, filepath=split_path, objname=kernel.name)
    run_and_check(fissioned)

    clean_test(ref_path)
    clean_test(split_path)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fission_promote_multiple_read_after_write(tmp_path, frontend):
    """Check automatic promotion across several fission points.

    Includes ``zqxn``, whose declared shape ``(nclv, klon)`` is the reverse of its
    iteration order. ``zqxn`` must become ``(nclv, klon, klev)`` and ``zsupsat``
    must become ``(klon, klev)``. The JIT-compiled routine must produce the same
    ``a`` and ``b`` before and after fission.
    """
    fortran = """
subroutine transform_loop_fission_promote_mult_r_a_w(a, b, klon, klev, nclv)
  integer, intent(inout) :: a(klon, klev), b(klon, klev, nclv)
  integer, intent(in) :: klon, klev, nclv
  integer :: jm, jk, jl, zsupsat(klon), zqxn(nclv, klon)
  ! Note the shape of zqxn, which is the reverse of the iteration space

  do jk=1,klev
    zsupsat(:) = 0
    do jl=1,klon
        zsupsat(jl) = jl
    end do
    !$loki loop-fission
    do jm=1,nclv
        do jl=1,klon
            zqxn(jm, jl) = jm+jl
        end do
    end do
    !$loki loop-fission
    a(:, jk) = zsupsat(:)
    !$loki loop-fission
    do jm=1,nclv
        b(:, jk, jm) = zqxn(jm, :)
    end do
  end do
end subroutine transform_loop_fission_promote_mult_r_a_w
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)
    klon, klev, nclv = 32, 100, 5

    def run_and_check(compiled):
        # Fresh zeroed buffers, then compare against the closed-form results
        a = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
        b = np.zeros(shape=(klon, klev, nclv), order='F', dtype=np.int32)
        compiled(a=a, b=b, klon=klon, klev=klev, nclv=nclv)
        expected_a = np.array([[jl] * klev for jl in range(1, klon+1)], order='F')
        expected_b = np.array([[[jl + jm for jm in range(1, nclv+1)]] * klev
                               for jl in range(1, klon+1)], order='F')
        assert np.all(a == expected_a)
        assert np.all(b == expected_b)

    # Reference build of the untransformed routine
    ref_path = tmp_path/(f'{kernel.name}_{frontend}.f90')
    reference = jit_compile(kernel, filepath=ref_path, objname=kernel.name)
    run_and_check(reference)

    # Split the loop nest and inspect the result
    assert len(FindNodes(Loop).visit(kernel.body)) == 5
    do_loop_fission(kernel)

    split_loops = FindNodes(Loop).visit(kernel.body)
    assert len(split_loops) == 8
    assert all(lp.bounds.start == 1 for lp in split_loops)
    assert sum(lp.bounds.stop == 'klev' for lp in split_loops) == 4
    assert sum(lp.bounds.stop == 'klon' for lp in split_loops) == 2
    assert sum(lp.bounds.stop == 'nclv' for lp in split_loops) == 2
    assert kernel.variable_map['zsupsat'].shape == ('klon', 'klev')
    assert kernel.variable_map['zqxn'].shape == ('nclv', 'klon', 'klev')

    # Fissioned build must reproduce the reference results
    split_path = tmp_path/(f'{kernel.name}_fissioned_{frontend}.f90')
    fissioned = jit_compile(kernel, filepath=split_path, objname=kernel.name)
    run_and_check(fissioned)

    clean_test(ref_path)
    clean_test(split_path)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_fusion_fission(tmp_path, frontend):
    """Apply loop fusion and then loop fission to the same routine.

    Fusion must reduce four loops to three; fission must then yield four loops with
    ``zsupsat`` promoted to ``(klon, klev)``. The JIT-compiled routine must produce
    the same ``a`` and ``b`` before and after both transformations.
    """
    fortran = """
subroutine transform_loop_fusion_fission(a, b, klon, klev)
  integer, intent(inout) :: a(klon, klev), b(klon, klev)
  integer, intent(in) :: klon, klev
  integer :: jk, jl, zsupsat(klon)

!$loki loop-fusion
  do jk=1,klev
    do jl=1,klon
      a(jl, jk) = jk
    end do
  end do

!$loki loop-fusion
  do jk=1,klev
    do jl=1,klon
      zsupsat(jl) = jl
    end do
    !$loki loop-fission promote(zsupsat)
    b(:, jk) = a(:, jk) + zsupsat(:)
  end do
end subroutine transform_loop_fusion_fission
"""
    kernel = Subroutine.from_source(fortran, frontend=frontend)
    klon, klev = 32, 100

    def run_and_check(compiled):
        # Fresh zeroed buffers, then compare against the closed-form results
        a = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
        b = np.zeros(shape=(klon, klev), order='F', dtype=np.int32)
        compiled(a=a, b=b, klon=klon, klev=klev)
        expected_a = np.array([list(range(1, klev+1))] * klon, order='F')
        expected_b = np.array([[jl + jk for jk in range(1, klev+1)]
                               for jl in range(1, klon+1)], order='F')
        assert np.all(a == expected_a)
        assert np.all(b == expected_b)

    # Reference build of the untransformed routine
    ref_path = tmp_path/(f'{kernel.name}_{frontend}.f90')
    reference = jit_compile(kernel, filepath=ref_path, objname=kernel.name)
    run_and_check(reference)

    # Fuse, then split, checking the loop count at each stage
    assert len(FindNodes(Loop).visit(kernel.body)) == 4
    do_loop_fusion(kernel)
    assert len(FindNodes(Loop).visit(kernel.body)) == 3
    do_loop_fission(kernel)

    final_loops = FindNodes(Loop).visit(kernel.body)
    assert len(final_loops) == 4
    assert all(lp.bounds.start == 1 for lp in final_loops)
    assert sum(lp.bounds.stop == 'klev' for lp in final_loops) == 2
    assert sum(lp.bounds.stop == 'klon' for lp in final_loops) == 2
    assert kernel.variable_map['zsupsat'].shape == ('klon', 'klev')

    # Transformed build must reproduce the reference results
    split_path = tmp_path/(f'{kernel.name}_fissioned_{frontend}.f90')
    fissioned = jit_compile(kernel, filepath=split_path, objname=kernel.name)
    run_and_check(fissioned)

    clean_test(ref_path)
    clean_test(split_path)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll(frontend):
    """Fully unroll a loop with literal bounds ``1..10`` into ten assignments."""
    fortran = """
subroutine test_transform_loop_unroll(s)
    implicit none
    integer :: a
    integer, intent(inout) :: s

    !Loop A
    !$loki loop-unroll
    do a=1, 10
        s = s + a + 1
    end do

end subroutine test_transform_loop_unroll
 """
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 1

    do_loop_unroll(kernel)

    expected_assigns = [('s', f's + {a} + 1') for a in range(1, 11)]
    assert not FindNodes(Loop).visit(kernel.body)
    assert len(FindNodes(Assignment).visit(kernel.body)) == 10
    assert assignment_symbols(kernel.body) == expected_assigns


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_step(frontend):
    """Unroll a loop with a non-unit step: ``-2..7`` by ``2`` gives five iterations."""
    fortran = """
subroutine test_transform_loop_unroll_step(s)
    implicit none
    integer :: a
    integer, intent(inout) :: s

    !Loop A
    !$loki loop-unroll
    do a=-2, 7, 2
        s = s + a + 1
    end do

end subroutine test_transform_loop_unroll_step
 """
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 1

    do_loop_unroll(kernel)

    expected_assigns = [
        ('s', 's + -2 + 1'),
        ('s', 's + 0 + 1'),
        ('s', 's + 2 + 1'),
        ('s', 's + 4 + 1'),
        ('s', 's + 6 + 1'),
    ]
    assert not FindNodes(Loop).visit(kernel.body)
    assert len(FindNodes(Assignment).visit(kernel.body)) == 5
    assert assignment_symbols(kernel.body) == expected_assigns


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_non_literal_range(frontend):
    """Leave a loop with a non-literal upper bound in place and remove its pragma."""
    fortran = """
subroutine test_transform_loop_unroll_non_literal_range(s)
    implicit none
    integer :: a, i
    integer, intent(inout) :: s

    i = 10

    !Loop A
    !$loki loop-unroll
    do a=1, i
        s = s + a + 1
    end do

end subroutine test_transform_loop_unroll_non_literal_range
 """
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 1

    do_loop_unroll(kernel)

    assert loop_symbols(kernel.body) == [('a', 1, 'i', None)]
    assert assignment_symbols(kernel.body) == [('i', '10'), ('s', 's + a + 1')]
    assert not pragma_strs(kernel.body)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_nested(frontend):
    """Unroll a two-level nest with no depth limit into 10 * 5 = 50 assignments."""
    fortran = """
subroutine test_transform_loop_unroll_nested(s)
    implicit none
    integer :: a, b
    integer, intent(inout) :: s

    !Loop A
    !$loki loop-unroll
    do a=1, 10
        !Loop B
        do b=1, 5
            s = s + a + b + 1
        end do
    end do

end subroutine test_transform_loop_unroll_nested
 """
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 2

    do_loop_unroll(kernel)

    assert not FindNodes(Loop).visit(kernel.body)
    assert len(FindNodes(Assignment).visit(kernel.body)) == 50
    head = [('s', f's + 1 + {b} + 1') for b in (1, 2, 3)]
    tail = [('s', f's + 10 + {b} + 1') for b in (3, 4, 5)]
    assert assignment_symbols(kernel.body)[:3] == head
    assert assignment_symbols(kernel.body)[-3:] == tail


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_nested_restricted_depth(frontend):
    """Unroll only the outer loop with ``depth(1)``.

    The result must be ten copies of the inner ``b`` loop, and the pragma must be
    removed.
    """
    fortran = """
subroutine test_transform_loop_unroll_nested_restricted_depth(s)
    implicit none
    integer :: a, b
    integer, intent(inout) :: s

    !Loop A
    !$loki loop-unroll depth(1)
    do a=1, 10
        !Loop B
        do b=1, 5
            s = s + a + b + 1
        end do
    end do

end subroutine test_transform_loop_unroll_nested_restricted_depth
 """
    kernel = Subroutine.from_source(fortran, frontend=frontend)

    loops_before = FindNodes(Loop).visit(kernel.body)
    assert len(loops_before) == 2

    do_loop_unroll(kernel)

    head = [('s', f's + {a} + b + 1') for a in (1, 2, 3)]
    assert loop_symbols(kernel.body) == [('b', 1, 5, None)] * 10
    assert len(FindNodes(Assignment).visit(kernel.body)) == 10
    assert assignment_symbols(kernel.body)[:3] == head
    assert not pragma_strs(kernel.body)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_nested_restricted_depth_unrollable(frontend):
    """Apply ``depth(1)`` unrolling to an outer loop whose bound is not a literal.

    The outer ``a`` loop (``1..i``) stays in place. The inner ``b`` loop (literal
    ``1..5``) is unrolled into five assignments inside it.
    """
    # The embedded subroutine is named after this test. It previously reused the
    # name of test_transform_loop_unroll_nested_restricted_depth (copy-paste).
    fcode = """
subroutine test_transform_loop_unroll_nested_restricted_depth_unrollable(s)
    implicit none
    integer :: a, b, i
    integer, intent(inout) :: s

    i = 10

    !Loop A
    !$loki loop-unroll depth(1)
    do a=1, i
        !Loop B
        do b=1, 5
            s = s + a + b + 1
        end do
    end do

end subroutine test_transform_loop_unroll_nested_restricted_depth_unrollable
 """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(Loop).visit(routine.body)) == 2
    do_loop_unroll(routine)
    assert loop_symbols(routine.body) == [('a', 1, 'i', None)]
    assert len(FindNodes(Assignment).visit(routine.body)) == 6
    assert assignment_symbols(routine.body) == [
        ('i', '10'), ('s', 's + a + 1 + 1'), ('s', 's + a + 2 + 1'),
        ('s', 's + a + 3 + 1'), ('s', 's + a + 4 + 1'), ('s', 's + a + 5 + 1')
    ]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_nested_counters(tmp_path, frontend):
    """Fully unroll a triangular loop nest whose inner bound is the outer counter.

    The inner ``b`` loop runs ``1..a``. The nest must become straight-line
    assignments with no loops left, and the JIT-compiled routine must accumulate the
    same ``s`` before and after unrolling.
    """
    fcode = """
subroutine test_transform_loop_unroll_nested_counters(s)
    implicit none

    integer :: a, b
    integer, intent(inout) :: s

    !Loop A
    !$loki loop-unroll
    do a=1, 10
        !Loop B
        do b=1, a
            s = s + a + b + 1
        end do
    end do

end subroutine test_transform_loop_unroll_nested_counters
 """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'{routine.name}_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    # One increment of s per (a, b) pair with 1 <= b <= a <= 10; computed once and
    # reused for both the reference and the unrolled run.
    increments = [a + b + 1 for (a, b) in itertools.product(range(1, 11), range(1, 11)) if b <= a]
    expected = sum(increments)

    # Test the reference solution. Use np.int32, as for the other Fortran
    # ``integer`` buffers in this file.
    s = np.array(0, dtype=np.int32)
    function(s=s)
    assert s == expected

    # Apply transformation; one assignment must remain per increment.
    # Separate asserts so a failure shows which check broke.
    assert len(FindNodes(Loop).visit(routine.body)) == 2
    do_loop_unroll(routine)
    assert len(FindNodes(Loop).visit(routine.body)) == 0
    assert len(FindNodes(Assignment).visit(routine.body)) == len(increments)

    unrolled_filepath = tmp_path / f'{routine.name}_unrolled_{frontend}.f90'
    unrolled_function = jit_compile(routine, filepath=unrolled_filepath, objname=routine.name)

    # Test transformation
    s = np.array(0, dtype=np.int32)
    unrolled_function(s=s)
    assert s == expected

    clean_test(filepath)
    clean_test(unrolled_filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_nested_neighbours(frontend):
    """
    Unroll only the outermost loop (``depth(1)``) of a nest holding two
    sibling loops, where only the first sibling carries its own unroll pragma.

    Loop A disappears. Loop B is unrolled inside every copy of A's body.
    Loop C survives once per copy of A's body.
    """
    fcode = """
subroutine test_transform_loop_unroll_nested_neighbours(s)
    implicit none

    integer :: a, b, c
    integer, intent(inout) :: s

    !Loop A
    !$loki loop-unroll depth(1)
    do a=1, 10
        !Loop B
        !$loki loop-unroll
        do b=1, 5
            s = s + a + b + 1
        end do
        !Loop C
        do c=1, 5
            s = s + a + c + 1
        end do
    end do

end subroutine test_transform_loop_unroll_nested_neighbours
 """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    loops_before = FindNodes(Loop).visit(routine.body)
    assert len(loops_before) == 3

    do_loop_unroll(routine)

    # Ten copies of A's body, each retaining an untouched loop C
    expected_loops = [('c', 1, 5, None)] * 10
    assert loop_symbols(routine.body) == expected_loops

    # Per copy: five unrolled B assignments plus the single assignment in C
    assignments = FindNodes(Assignment).visit(routine.body)
    assert len(assignments) == 10 * (5 + 1)

    # First copy of A's body (a = 1): B unrolled over b = 1..5, then C's body
    expected_head = [('s', f's + 1 + {b} + 1') for b in range(1, 6)]
    expected_head.append(('s', 's + 1 + c + 1'))
    assert assignment_symbols(routine.body)[:6] == expected_head


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('loop_interchange', [False, True])
@pytest.mark.parametrize('loop_fusion', [False, True])
@pytest.mark.parametrize('loop_fission', [False, True])
@pytest.mark.parametrize('loop_unroll', [False, True])
def test_transform_loop_transformation(frontend, loop_interchange, loop_fusion, loop_fission, loop_unroll):
    """
    Run :any:`TransformLoopsTransformation` with every on/off combination of
    its four loop transformations.

    Each enabled transformation must consume its pragma. It must also change
    the total loop count by the expected amount.
    """
    fcode = """
subroutine transform_loop()
  integer, parameter :: m = 8
  integer, parameter :: n = 16

  integer :: array(m,n)
  integer :: a(n), b(n)
  integer :: i, j, s

  !$loki loop-interchange
  do i=1,n
    do j=1,m
      array(j, i) = i + j
    end do
  end do

  !$loki loop-fusion
  do i=1,n
    a(i) = i
  end do

  !$loki loop-fusion
  do i=1,n
    b(i) = n-i+1
  end do

  do j=1,n
    a(j) = j
    !$loki loop-fission
    b(j) = n-j
  end do

  !$loki loop-unroll
  do i=1, 10
      s = s + i + 1
  end do
end subroutine transform_loop
    """

    routine = Subroutine.from_source(fcode, frontend=frontend)
    transform = TransformLoopsTransformation(
        loop_interchange=loop_interchange, loop_fusion=loop_fusion,
        loop_fission=loop_fission, loop_unroll=loop_unroll
    )

    # Counts before the transformation; adjusted below per enabled option
    expected_pragmas = len(FindNodes(ir.Pragma).visit(routine.body))
    expected_loops = len(FindNodes(ir.Loop).visit(routine.body))

    transform.apply(routine)
    pragmas = FindNodes(ir.Pragma).visit(routine.body)
    loops = FindNodes(ir.Loop).visit(routine.body)

    def pragma_left(keyword):
        """Whether any remaining pragma still mentions ``keyword``."""
        return any(keyword in pragma.content for pragma in pragmas)

    if loop_interchange:
        # Outer and inner loops swap: j outermost, i inside it
        expected_pragmas -= 1
        assert loops[0].variable == 'j'
        assert not pragma_left('loop-interchange')
        assert FindNodes(ir.Loop).visit(loops[0].body)[0].variable == 'i'

    if loop_fusion:
        # The two fusion loops merge into one holding both assignments
        expected_pragmas -= 1
        expected_loops -= 1
        assert not pragma_left('loop-fusion')
        assert len(FindNodes(ir.Assignment).visit(loops[2].body)) == 2

    if loop_fission:
        # Splitting at the pragma yields one additional loop
        expected_pragmas -= 1
        expected_loops += 1
        assert not pragma_left('loop-fission')

    if loop_unroll:
        # The unrolled loop vanishes entirely
        expected_pragmas -= 1
        expected_loops -= 1
        assert not pragma_left('loop-unroll')

    assert len(loops) == expected_loops
    assert len(pragmas) == expected_pragmas


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_before_fuse(frontend):
    """
    Fully unrolling a two-level nest must replicate the inner ``loop-fusion``
    loop and its pragma once per unrolled iteration.

    Loop fusion must then merge all replicas back into a single loop.
    """
    fcode = """
    subroutine test_loop_unroll_before_fuse(n, map, a, b)
       integer, intent(in) :: n
       integer, intent(in) :: map(3,3)
       real, intent(inout) :: a(n)
       real, intent(in) :: b(:)

       integer :: i,j,k

       !$loki loop-unroll
       do k=1,3
          !$loki loop-unroll
          do j=1,3
            !$loki loop-fusion
            do i=1,n
              a(i) = a(i) + b(map(j,k))
            enddo
          enddo
       enddo

    end subroutine test_loop_unroll_before_fuse
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 3

    do_loop_unroll(routine)

    # 3 x 3 unrolled iterations, each leaving behind one copy of the i-loop
    remaining = FindNodes(ir.Loop).visit(routine.body)
    assert len(remaining) == 9
    assert all(loop.variable == 'i' for loop in remaining)

    # Only the fusion pragmas remain, one per replicated i-loop
    fusion_pragmas = FindNodes(ir.Pragma).visit(routine.body)
    assert len(fusion_pragmas) == 9
    assert all(pragma.content == 'loop-fusion' for pragma in fusion_pragmas)

    do_loop_fusion(routine)
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 1
loki-ecmwf-0.3.6/loki/transformations/tests/test_transform_derived_types.py0000664000175000017500000014656615167130205027666 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki import (
    Sourcefile, Scheduler, ProcedureItem,
    ProcedureDeclaration, BasicType, CaseInsensitiveDict, SGraph
)
from loki.expression import Scalar, Array
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    FindNodes, FindVariables, FindInlineCalls, CallStatement
)

from loki.transformations.transform_derived_types import (
    DerivedTypeArgumentsTransformation,
    TypeboundProcedureCallTransformation
)
from loki.transformations.sanitise import do_resolve_associates
#pylint: disable=too-many-lines

@pytest.fixture(scope='module', name='here')
def fixture_here():
    """
    Directory containing this test module, used to locate test sources.
    """
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(name='config')
def fixture_config():
    """
    Default scheduler configuration with basic options.

    By default every item is a strict, expanding kernel with imports enabled.
    The ``driver`` routine is overridden to take the driver role.
    """
    default_options = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': True,
        'enable_imports': True,
    }
    routine_options = {
        'driver': {'role': 'driver', 'expand': True},
    }
    return {'default': default_options, 'routines': routine_options}


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_analysis(frontend, tmp_path):
    """
    Check the analysis stage of :any:`DerivedTypeArgumentsTransformation`.

    The recorded expansion map must list only the derived-type members the
    kernel uses. The recorded entries must be fully typed variable nodes that
    are detached from any scope.
    """
    # ``!$loki dimension`` annotations are only emitted for non-OMNI frontends
    dim_n = '!$loki dimension(n)' if frontend is not OMNI else ''
    dim_mn = '!$loki dimension(m, n)' if frontend is not OMNI else ''
    fcode = f"""
module transform_derived_type_arguments_mod

    implicit none

    type some_derived_type
{dim_n}
        real, allocatable :: a(:)
{dim_mn}
        real, allocatable :: b(:,:)
{dim_mn}
        real, allocatable :: c(:,:)
    end type some_derived_type

contains

    subroutine kernel(m, n, P_a, P_b, Q, R)
        integer                , intent(in)    :: m, n
        real, intent(inout)                    :: P_a(n), P_b(m, n)
        type(some_derived_type), intent(in)    :: Q
        type(some_derived_type), intent(out)   :: R
        integer :: j, k

        do j=1,n
            R%a(j) = P_a(j) + Q%a(j)
            do k=1,m
                R%b(k, j) = P_b(k, j) - Q%b(k, j) - Q%c(k, j)
            end do
        end do
    end subroutine kernel
end module transform_derived_type_arguments_mod
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    item = ProcedureItem(name='transform_derived_type_arguments_mod#kernel', source=source)
    source['kernel'].apply(DerivedTypeArgumentsTransformation(), role='kernel', item=item)

    trafo_data = item.trafo_data[DerivedTypeArgumentsTransformation._key]

    # Q uses a, b and c; R only uses a and b
    assert trafo_data == {
        'expansion_map': {
            'q': ('q%a', 'q%b', 'q%c'),
            'r': ('r%a', 'r%b'),
        },
        'orig_argnames': ('m', 'n', 'p_a', 'p_b', 'q', 'r')
    }

    # Every recorded member must be an actual variable node with proper type
    # information, but not attached to any scope
    recorded = [member for members in trafo_data['expansion_map'].values() for member in members]
    for member in recorded:
        assert isinstance(member, (Scalar, Array))
        assert member.scope is None
        assert member.type.dtype != BasicType.DEFERRED


@pytest.mark.parametrize('all_derived_types', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion_trivial_derived_type(frontend, all_derived_types, tmp_path):
    """
    A derived type whose members are plain ``real`` scalars is only expanded
    when the transformation is created with ``all_derived_types=True``.

    Otherwise the call and the kernel signature stay unchanged.
    """
    fcode = """
module transform_derived_type_arguments_mod

    implicit none

    type some_derived_type
        real :: a
        real :: b
    end type some_derived_type

contains

    subroutine caller(z)
        integer, intent(in) :: z
        type(some_derived_type) :: t_io
        type(some_derived_type) :: t_in, t_out
        integer :: m, n
        integer :: i, j

        m = 100
        n = 10

        t_in%a = real(m-1)
        t_in%b = real(n-1)

        call kernel(m, n, t_io%a, t_io%b, t_in, t_out)

    end subroutine caller

    subroutine kernel(m, n, P_a, P_b, Q, R)
        integer                , intent(in)    :: m, n
        real, intent(inout)                    :: P_a, P_b
        type(some_derived_type), intent(in)    :: Q
        type(some_derived_type), intent(out)   :: R
        integer :: j, k

        R%a = P_a + Q%a
        R%b = P_b - Q%b
    end subroutine kernel
end module transform_derived_type_arguments_mod
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    caller_item = ProcedureItem(
        name='transform_derived_type_arguments_mod#caller', source=source, config={'role': 'driver'}
    )
    kernel_item = ProcedureItem(
        name='transform_derived_type_arguments_mod#kernel', source=source, config={'role': 'kernel'}
    )
    graph = SGraph.from_dict({caller_item: [kernel_item]})

    # Apply transformation bottom-up: the kernel before its caller
    transformation = DerivedTypeArgumentsTransformation(all_derived_types=all_derived_types)
    for item in (kernel_item, caller_item):
        transformation.apply(item.scope_ir, role=item.role, item=item, sub_sgraph=graph.get_sub_sgraph(item))

    if not all_derived_types:
        # Only derived types with pointer/allocatable/derived type members are
        # expanded, so this type is left alone and nothing changes
        call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in', 't_out')
        kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q', 'R')
    else:
        # Every derived type is expanded, regardless of its member types
        call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in%a', 't_in%b', 't_out%a', 't_out%b')
        kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q_a', 'Q_b', 'R_a', 'R_b')

    caller = source['caller']
    kernel = source['kernel']

    call = FindNodes(CallStatement).visit(caller.ir)[0]
    assert call.name == 'kernel'
    assert call.arguments == call_args
    assert kernel.arguments == kernel_args
    assert all(arg.type.intent for arg in kernel.arguments)

    # Rescoping must not have copied the kernel's intent onto the caller's
    # local variables that share a name with the kernel's shape symbols
    for name in ('m', 'n'):
        assert caller.variable_map[name].type.intent is None


@pytest.mark.parametrize('all_derived_types', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion_trivial_derived_type_scheduler(
        frontend, all_derived_types, config, here, tmp_path
):
    """
    Scheduler-driven variant of
    ``test_transform_derived_type_arguments_expansion_trivial_derived_type``.

    It runs on the ``sources/projDerivedTypes`` project next to this module.
    """
    scheduler = Scheduler(
        paths=[here / 'sources/projDerivedTypes'], config=config, seed_routines=['driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # Apply transformation
    scheduler.process(transformation=DerivedTypeArgumentsTransformation(all_derived_types=all_derived_types))

    if not all_derived_types:
        # Only derived types with pointer/allocatable/derived type members are
        # expanded, so nothing changes here
        call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in', 't_out')
        kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q', 'R')
    else:
        # Every derived type is expanded, regardless of its member types
        call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in%a', 't_in%b', 't_out%a', 't_out%b')
        kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q_a', 'Q_b', 'R_a', 'R_b')

    driver = scheduler["driver_mod#driver"].ir
    kernel = scheduler["kernel_mod#kernel"].ir

    call = FindNodes(CallStatement).visit(driver.body)[0]
    assert call.name == 'kernel'
    assert call.arguments == call_args
    assert kernel.arguments == kernel_args
    assert all(arg.type.intent for arg in kernel.arguments)

    # Rescoping must not have copied the kernel's intent onto the driver's
    # local variables that share a name with the kernel's shape symbols
    for name in ('m', 'n'):
        assert driver.variable_map[name].type.intent is None


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion(frontend, tmp_path):
    """
    Expand derived-type arguments with allocatable array members in a
    single-module caller/kernel pair.

    The call in ``caller`` must pass the individual members. ``kernel`` must
    receive them as separate ``<arg>_<member>`` dummy arguments.
    """
    # ``!$loki dimension`` annotations are only emitted for non-OMNI frontends
    dim_n = '!$loki dimension(n)' if frontend is not OMNI else ''
    dim_mn = '!$loki dimension(m, n)' if frontend is not OMNI else ''
    fcode = f"""
module transform_derived_type_arguments_mod

    implicit none

    type some_derived_type
{dim_n}
        real, allocatable :: a(:)
{dim_mn}
        real, allocatable :: b(:,:)
    end type some_derived_type

contains

    subroutine caller(z)
        integer, intent(in) :: z
        type(some_derived_type) :: t_io
        type(some_derived_type), allocatable :: t_in(:), t_out(:)
        integer :: m, n
        integer :: i, j

        m = 100
        n = 10

        allocate(t_io%a(n))
        allocate(t_io%b(m, n))

        do j=1,n
            t_io%a(j) = real(j)
            t_io%b(:, j) = real(j)
        end do

        allocate(t_in(z), t_out(z))

        do i=1,z
            allocate(t_in(i)%a(n))
            allocate(t_in(i)%b(m, n))
            allocate(t_out(i)%a(n))
            allocate(t_out(i)%b(m, n))

            do j=1,n
                t_in(i)%a(j) = real(i-1)
                t_in(i)%b(:, j) = real(i-1)
            end do
        end do

        do i=1,z
            call kernel(m, n, t_io%a, t_io%b, t_in(i), t_out(i))
        end do

        deallocate(t_io%a)
        deallocate(t_io%b)

        do i=1,z
            deallocate(t_in(i)%a)
            deallocate(t_in(i)%b)
            deallocate(t_out(i)%a)
            deallocate(t_out(i)%b)
        end do

        deallocate(t_in, t_out)
    end subroutine caller

    subroutine kernel(m, n, P_a, P_b, Q, R)
        integer                , intent(in)    :: m, n
        real, intent(inout)                    :: P_a(n), P_b(m, n)
        type(some_derived_type), intent(in)    :: Q
        type(some_derived_type), intent(out)   :: R
        integer :: j, k

        do j=1,n
            R%a(j) = P_a(j) + Q%a(j)
            do k=1,m
                R%b(k, j) = P_b(k, j) - Q%b(k, j)
            end do
        end do
    end subroutine kernel
end module transform_derived_type_arguments_mod
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    module_name = 'transform_derived_type_arguments_mod'
    caller_item = ProcedureItem(name=f'{module_name}#caller', source=source, config={'role': 'driver'})
    kernel_item = ProcedureItem(name=f'{module_name}#kernel', source=source, config={'role': 'kernel'})
    graph = SGraph.from_dict({caller_item: [kernel_item]})

    # Apply transformation bottom-up: the kernel before its caller
    transformation = DerivedTypeArgumentsTransformation()
    for item in (kernel_item, caller_item):
        transformation.apply(item.ir, role=item.role, item=item, sub_sgraph=graph.get_sub_sgraph(item))

    caller = source['caller']
    kernel = source['kernel']

    # Derived type arguments are flattened at the call site...
    expected_call_args = (
        'm', 'n', 't_io%a', 't_io%b', 't_in(i)%a', 't_in(i)%b',
        't_out(i)%a', 't_out(i)%b'
    )
    # ...and in the kernel's dummy argument list
    expected_kernel_args = ('m', 'n', 'P_a(n)', 'P_b(m, n)', 'Q_a(:)', 'Q_b(:, :)', 'R_a(:)', 'R_b(:, :)')

    call = FindNodes(CallStatement).visit(caller.ir)[0]
    assert call.name == 'kernel'
    assert call.arguments == expected_call_args
    assert kernel.arguments == expected_kernel_args
    assert all(arg.type.intent for arg in kernel.arguments)

    # Rescoping must not have copied the kernel's intent onto the caller's
    # local variables that share a name with the kernel's shape symbols
    for name in ('m', 'n'):
        assert caller.variable_map[name].type.intent is None


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_inline_call(frontend, tmp_path):
    """
    Verify correct expansion of inline calls to functions.

    The derived-type arguments of ``kernel`` (a function) are expanded in its
    signature. The inline call inside ``driver`` must be rewritten to pass
    the individual members.
    """
    fcode_my_mod = """
module my_mod
    implicit none
    type my_type
        integer, allocatable :: a(:)
        integer :: b
    end type my_type
contains
    function kernel(r, s) result(t)
        type(my_type), intent(in) :: r, s
        real :: t
        t = sum(r%a + s%a) + r%b + s%b
    end function kernel
end module my_mod
    """.strip()

    fcode_driver = """
subroutine driver(arr, n, s, t)
    use my_mod, only: my_type, kernel
    implicit none
    type(my_type), intent(in) :: arr(n), s
    integer, intent(in) :: n
    real, intent(inout) :: t(n)
    integer :: j
    do j=1,n
        t(j) = kernel(arr(j), s)
    end do
end subroutine driver
    """.strip()

    source_my_mod = Sourcefile.from_source(fcode_my_mod, frontend=frontend, xmods=[tmp_path])
    source_driver = Sourcefile.from_source(
        fcode_driver, frontend=frontend, definitions=source_my_mod.definitions, xmods=[tmp_path]
    )

    kernel = ProcedureItem('my_mod#kernel', config={'role': 'kernel'}, source=source_my_mod)
    driver = ProcedureItem('#driver', config={'role': 'driver'}, source=source_driver)
    graph = SGraph.from_dict({driver: [kernel]})

    # Process the callee before the caller
    transformation = DerivedTypeArgumentsTransformation()
    key = transformation._key
    transformation.apply(kernel.ir, item=kernel, role=kernel.role)
    transformation.apply(driver.ir, item=driver, role=driver.role, sub_sgraph=graph.get_sub_sgraph(driver))

    assert kernel.trafo_data[key] == {
        'orig_argnames': ('r', 's'),
        'expansion_map': {'r': ('r%a', 'r%b'), 's': ('s%a', 's%b')}
    }
    assert kernel.ir.arguments == ('r_a(:)', 'r_b', 's_a(:)', 's_b')

    inline_calls = list(FindInlineCalls().visit(driver.ir.body))
    assert len(inline_calls) == 1
    assert inline_calls[0].parameters == ('arr(j)%a', 'arr(j)%b', 's%a', 's%b')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_multilevel(frontend, tmp_path):
    """
    Verify correct behaviour of the derived type argument flattening when
    used in multi-level call trees.

    The tree must be traversed from the leaf upwards. That way every use of
    derived type members is seen by the calling subroutine.
    """
    fcode = """
module transform_derived_type_arguments_multilevel
    implicit none

    type some_type
        real, allocatable :: a(:), b(:), c(:)
    end type some_type

contains

    subroutine caller(n, obj)
        integer, intent(in) :: n
        type(some_type), intent(inout) :: obj

        call setup_obj(obj, n)
    end subroutine caller

    subroutine setup_obj(obj, n)
        type(some_type), intent(inout) :: obj
        integer, intent(in) :: n

        call deallocate_obj(obj)

        allocate(obj%a(n))
        allocate(obj%b(n))
    end subroutine setup_obj

    subroutine deallocate_obj(obj)
        type(some_type), intent(inout) :: obj

        if(allocated(obj%a)) deallocate(obj%a)
        if(allocated(obj%b)) deallocate(obj%b)
        if(allocated(obj%c)) deallocate(obj%c)
    end subroutine deallocate_obj

end module transform_derived_type_arguments_multilevel
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    orig_args = {
        'caller': ('n', 'obj'),
        'setup_obj': ('obj', 'n'),
        'deallocate_obj': ('obj',),
    }

    # setup_obj itself never touches obj%c; that member only reaches its
    # signature via the callee deallocate_obj
    transformed_args = {
        'caller': ('n', 'obj'),
        'setup_obj': ('obj_a(:)', 'obj_b(:)', 'obj_c(:)', 'n'),
        'deallocate_obj': ('obj_a(:)', 'obj_b(:)', 'obj_c(:)'),
    }

    for routine in source.subroutines:
        assert routine.arguments == orig_args[routine.name.lower()]

    module_name = 'transform_derived_type_arguments_multilevel'
    roles = (('caller', 'driver'), ('setup_obj', 'kernel'), ('deallocate_obj', 'kernel'))
    call_tree = [
        ProcedureItem(name=f'{module_name}#{name}', source=source, config={'role': role})
        for name, role in roles
    ]
    caller, setup_obj, deallocate_obj = call_tree
    graph = SGraph.from_dict({caller: [setup_obj], setup_obj: [deallocate_obj]})

    # Apply transformation leaf-first
    transformation = DerivedTypeArgumentsTransformation()
    for item in reversed(call_tree):
        transformation.apply(item.ir, role=item.role, item=item, sub_sgraph=graph.get_sub_sgraph(item))

    key = transformation._key
    for item in call_tree:
        if item.role != 'driver':
            assert item.trafo_data[key]['orig_argnames'] == orig_args[item.ir.name.lower()]
        else:
            # No trafo data is expected for the driver
            assert not item.trafo_data[key]

    for routine in source.subroutines:
        assert routine.arguments == transformed_args[routine.name.lower()]
        assert all(arg.type.intent for arg in routine.arguments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion_nested(frontend, tmp_path):
    """
    Expand derived-type arguments across a multi-module call tree with nested
    derived types.

    ``bucket_type`` holds components of type ``some_derived_type`` and
    ``constants_type``. The test checks:

    * the recorded analysis for ``kernel`` and ``layer``;
    * the expanded dummy arguments of each kernel-role routine;
    * the rewritten call arguments in ``layer`` and ``caller``;
    * the imports added to ``layer``, ``setup`` and ``teardown``.
    """
    fcode_header = f"""
module header_mod
    implicit none
    integer, parameter :: jprb = selected_real_kind(13, 300)
    integer, parameter :: NUMBER_TWO = 2

    type some_derived_type
{'!$loki dimension(n)' if frontend is not OMNI else ''}
        real(kind=jprb), allocatable :: a(:)
{'!$loki dimension(m, n)' if frontend is not OMNI else ''}
        real(kind=jprb), allocatable :: b(:,:)
    end type some_derived_type

    type constants_type
        real(kind=jprb) :: c
        real(kind=jprb), allocatable :: other(:)
    end type constants_type
end module header_mod
    """.strip()

    fcode_bucket = """
module bucket_mod
    use header_mod, only: some_derived_type, constants_type, number_two
    implicit none
    integer, parameter :: NUMBER_FIVE = 5

    type bucket_type
        type(some_derived_type) :: a
        type(some_derived_type) :: b(NUMBER_FIVE)
        type(constants_type) :: constants(number_two)
    end type bucket_type

contains

    subroutine setup(t, m, n)
        type(some_derived_type), intent(inout) :: t
        integer, intent(in) :: m, n

        allocate(t%a(n))
        allocate(t%b(m, n))
    end subroutine setup

    subroutine teardown(t)
        type(some_derived_type), intent(inout) :: t
        deallocate(t%a)
        deallocate(t%b)
    end subroutine teardown

    subroutine init(t, m, n)
        use header_mod, only: jprb
        type(some_derived_type), intent(inout) :: t
        integer, intent(in) :: m, n
        integer j

        do j=1,n
            t%a(j) = real(j, kind=jprb)
            t%b(:, j) = real(j, kind=jprb)
        end do
    end subroutine init

    subroutine kernel(m, n, P_a, P_b, Q, R, c)
        use header_mod, only: jprb
        integer                , intent(in)    :: m, n
        real(kind=jprb), intent(in)            :: P_a(n), P_b(m, n), c
        type(some_derived_type), intent(in)    :: Q
        type(some_derived_type), intent(out)   :: R
        integer :: j, k

        do j=1,n
            R%a(j) = P_a(j) + Q%a(j) + c
            do k=1,m
                R%b(k, j) = P_b(k, j) - Q%b(k, j) + c
            end do
        end do
    end subroutine kernel
end module bucket_mod
    """.strip()

    fcode_layer = """
module layer_mod
contains
    subroutine layer(m, n, P_a, P_b, Q, R)
        use bucket_mod, only: bucket_type
        use header_mod, only: some_derived_type
        implicit none
        integer                , intent(in) :: m, n
        type(some_derived_type), intent(in) :: P_a, P_b(5)
        type(bucket_type), intent(in)       :: Q
        type(bucket_type), intent(out)      :: R
        integer :: k

        call kernel(m, n, P_a%a, P_a%b, Q%a, R%a, Q%constants(1)%c)
        do k=1,5
            call kernel(m, n, P_b(k)%a, P_b(k)%b, Q%b(k), R%b(k), Q%constants(2)%c)
        end do
    end subroutine layer
end module layer_mod
    """.strip()

    # NOTE(review): the last loop in ``caller`` calls ``teardown(t_io%b(k))``
    # twice. This source is only parsed and transformed here, never compiled
    # or run, so the duplicate does not affect the test outcome.
    fcode_caller = """
subroutine caller(z)
    use bucket_mod, only: bucket_type, setup, init, teardown
    use layer_mod, only: layer
    implicit none

    integer, intent(in) :: z
    type(bucket_type) :: t_io
    type(bucket_type), allocatable :: t_in(:), t_out(:)
    integer :: m, n
    integer :: i, j, k

    m = 100
    n = 10

    call setup(t_io%a, m, n)
    call init(t_io%a, m, n)
    do k=1,5
        call setup(t_io%b(k), m, n)
        call init(t_io%b(k), m, n)
    end do

    allocate(t_in(z), t_out(z))

    do i=1,z
        call setup(t_in(i)%a, m, n)
        call setup(t_out(i)%a, m, n)

        do j=1,n
            t_in(i)%a%a(j) = real(i-1)
            t_in(i)%a%b(:, j) = real(i-1)
        end do

        do k=1,5
            call setup(t_in(i)%b(k), m, n)
            call setup(t_out(i)%b(k), m, n)

            do j=1,n
                t_in(i)%b(k)%a(j) = real(i-1)
                t_in(i)%b(k)%b(:, j) = real(i-1)
            end do
        end do
    end do

    do i=1,z
        call layer(m, n, t_io%a, t_io%b, t_in(i), t_out(i))
    end do

    do i=1,z
        call teardown(t_in(i)%a)
        call teardown(t_out(i)%a)

        do k=1,5
            call teardown(t_in(i)%b(k))
            call teardown(t_out(i)%b(k))
        end do
    end do

    deallocate(t_in)
    deallocate(t_out)

    do k=1,5
        call teardown(t_io%b(k))
        call teardown(t_io%b(k))
    end do
    call teardown(t_io%a)
end subroutine caller
    """.strip()

    header = Sourcefile.from_source(fcode_header, frontend=frontend, xmods=[tmp_path])
    bucket = Sourcefile.from_source(fcode_bucket, frontend=frontend, definitions=header.definitions, xmods=[tmp_path])
    layer = Sourcefile.from_source(
        fcode_layer, frontend=frontend,
        definitions=header.definitions + bucket.definitions, xmods=[tmp_path]
    )
    source = Sourcefile.from_source(
        fcode_caller, frontend=frontend, xmods=[tmp_path],
        definitions=header.definitions + bucket.definitions + layer.definitions
    )

    items = {
        'caller': ProcedureItem(
            name='#caller', source=source, config={'role': 'driver'}
        ),
        'layer': ProcedureItem(
            name='layer_mod#layer', source=layer, config={'role': 'kernel'}
        ),
        'setup': ProcedureItem(
            name='bucket_mod#setup', source=bucket, config={'role': 'kernel'}
        ),
        'init': ProcedureItem(
            name='bucket_mod#init', source=bucket, config={'role': 'kernel'}
        ),
        'kernel': ProcedureItem(
            name='bucket_mod#kernel', source=bucket, config={'role': 'kernel'}
        ),
        'teardown': ProcedureItem(
            name='bucket_mod#teardown', source=bucket, config={'role': 'kernel'}
        ),
    }

    call_tree = [
        ('caller', ['setup', 'init', 'layer', 'teardown']),
        ('setup', []),
        ('init', []),
        ('layer', ['kernel']),
        ('kernel', []),
        ('teardown', [])
    ]

    assert len(items['layer'].ir.imports) == 2

    graph_dic = {items[name]: [items[child] for child in successors] for name, successors in call_tree}
    graph = SGraph.from_dict(graph_dic)

    # Apply transformation in reverse call-tree order, so every callee is
    # processed before its caller
    transformation = DerivedTypeArgumentsTransformation()
    for name, _ in reversed(call_tree):
        item = items[name]
        transformation.apply(item.ir, role=item.role, item=item, sub_sgraph=graph.get_sub_sgraph(item))

    key = DerivedTypeArgumentsTransformation._key

    # Check analysis result in kernel
    assert key in items['kernel'].trafo_data
    assert items['kernel'].trafo_data[key]['expansion_map'] == {
        'q': ('q%a', 'q%b'),
        'r': ('r%a', 'r%b'),
    }
    assert items['kernel'].trafo_data[key]['orig_argnames'] == (
        'm', 'n', 'p_a', 'p_b', 'q', 'r', 'c'
    )

    # Check analysis result in layer
    assert key in items['layer'].trafo_data
    assert items['layer'].trafo_data[key]['expansion_map'] == {
        'p_a': ('p_a%a', 'p_a%b'),
        'q': ('q%a%a', 'q%a%b', 'q%b', 'q%constants'),
        'r': ('r%a%a', 'r%a%b', 'r%b')
    }
    assert items['layer'].trafo_data[key]['orig_argnames'] == (
        'm', 'n', 'p_a', 'p_b', 'q', 'r'
    )

    # Check arguments of setup
    assert items['setup'].ir.arguments == (
        't_a(:)', 't_b(:, :)', 'm', 'n'
    )

    # Check arguments of init
    assert items['init'].ir.arguments == (
        't_a(:)', 't_b(:, :)', 'm', 'n'
    )

    # Check arguments of teardown
    assert items['teardown'].ir.arguments == (
        't_a(:)', 't_b(:, :)'
    )

    # Check arguments of kernel
    assert items['kernel'].ir.arguments == (
        'm', 'n', 'P_a(n)', 'P_b(m, n)', 'Q_a(:)', 'Q_b(:, :)', 'R_a(:)', 'R_b(:, :)', 'c'
    )

    # Check call arguments in layer
    calls = FindNodes(CallStatement).visit(items['layer'].ir.ir)
    assert len(calls) == 2

    assert calls[0].arguments == (
        'm', 'n', 'p_a_a', 'p_a_b', 'q_a_a', 'q_a_b', 'r_a_a', 'r_a_b', 'q_constants(1)%c'
    )
    assert calls[1].arguments == (
        'm', 'n', 'p_b(k)%a', 'p_b(k)%b', 'q_b(k)%a', 'q_b(k)%b', 'r_b(k)%a', 'r_b(k)%b', 'q_constants(2)%c'
    )

    # Check arguments of layer
    assert items['layer'].ir.arguments == (
        'm', 'n', 'p_a_a(:)', 'p_a_b(:, :)', 'p_b(5)', 'q_a_a(:)', 'q_a_b(:, :)', 'q_b(:)', 'q_constants(:)',
        'r_a_a(:)', 'r_a_b(:, :)', 'r_b(:)'
    )

    # Check imports
    assert 'constants_type' in items['layer'].ir.imported_symbols
    if frontend != OMNI:
        # OMNI inlines parameters
        assert 'jprb' in items['layer'].ir.imported_symbols
        assert 'jprb' in items['setup'].ir.imported_symbols
        assert 'jprb' in items['teardown'].ir.imported_symbols

    # No additional imports added for init and kernel
    assert len(items['init'].ir.imports) == 1
    assert len(items['kernel'].ir.imports) == 1

    # Cached property updated?
    assert len(items['layer'].ir.imports) == 3

    # Check call arguments in caller
    for call in FindNodes(CallStatement).visit(items['caller'].ir.body):
        if call.name in ('setup', 'init'):
            assert len(call.arguments) == 4
        elif call.name == 'teardown':
            assert len(call.arguments) == 2
        elif call.name == 'layer':
            assert call.arguments == (
                'm', 'n', 't_io%a%a', 't_io%a%b', 't_io%b',
                't_in(i)%a%a', 't_in(i)%a%b', 't_in(i)%b', 't_in(i)%constants',
                't_out(i)%a%a', 't_out(i)%a%b', 't_out(i)%b',
            )
        else:
            # An unexpected call is a genuine failure; ``pytest.xfail`` here
            # would silently report it as an expected failure instead
            pytest.fail(f'Unknown call name {call.name}')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_typebound_proc(frontend, tmp_path):
    """
    Derived-type argument expansion for procedures bound to a derived type.

    ``kernel_a``, ``kernel`` (bound as ``kernel_b_c``) and ``reduce`` are
    expanded as kernels and the inline call to ``reduce`` in ``driver`` is
    updated. Afterwards every binding of ``some_derived_type`` must have lost
    its PASS attribute, which is rendered as ``NOPASS`` by fgen.
    """
    # The dimension pragmas are only inserted for non-OMNI frontends
    dim_n = '!$loki dimension(n)' if frontend is not OMNI else ''
    dim_mn = '!$loki dimension(m, n)' if frontend is not OMNI else ''
    fcode = f"""
module transform_derived_type_arguments_mod

    implicit none

    type some_derived_type
{dim_n}
        real, allocatable :: a(:)
{dim_mn}
        real, allocatable :: b(:,:)
{dim_mn}
        real, allocatable :: c(:,:)
    contains
        procedure, pass :: kernel_a
        procedure :: kernel_b_c => kernel
        procedure, pass(this) :: reduce
    end type some_derived_type

contains

    subroutine kernel_a(this, out, n)
        class(some_derived_type), intent(inout) :: this
        real, allocatable, intent(inout)        :: out(:)
        integer                , intent(in)     :: n
        integer :: j

        do j=1,n
            out(j) = this%a(j) + 1.
        end do
    end subroutine kernel_a

    subroutine kernel(this, other, m, n)
        class(some_derived_type), intent(in)   :: this
        type(some_derived_type), intent(inout) :: other
        integer                , intent(in)    :: m, n
        integer :: j, k

        do j=1,n
            do k=1,m
                other%b(k, j) = 1.e3 - this%b(k, j) - this%c(k, j)
            end do
        end do
    end subroutine kernel

    function reduce(start, this) result(val)
        real, intent(in) :: start
        class(some_derived_type), intent(in) :: this
        real :: val
        val = start + sum(this%a + sum(this%b + this%c, 1))
    end function reduce
end module transform_derived_type_arguments_mod
    """.strip()

    fcode_driver = """
subroutine driver(some, result)
    use transform_derived_type_arguments_mod, only: some_derived_type, reduce
    implicit none
    type(some_derived_type), intent(in) :: some
    real, intent(inout) :: result
    result = reduce(result, some)
end subroutine driver
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    mod_prefix = 'transform_derived_type_arguments_mod#'
    kernel_a_item = ProcedureItem(name=f'{mod_prefix}kernel_a', source=source)
    kernel_item = ProcedureItem(name=f'{mod_prefix}kernel', source=source)
    reduce_item = ProcedureItem(name=f'{mod_prefix}reduce', source=source)
    source_driver = Sourcefile.from_source(
        fcode_driver, frontend=frontend, definitions=source.definitions, xmods=[tmp_path]
    )
    driver_item = ProcedureItem(name='#driver', source=source_driver)

    # Binding names in the order they are declared in the type's CONTAINS section
    binding_names = ('kernel_a', 'kernel_b_c', 'reduce')

    def check_binding_declarations(typedef):
        """The procedure declarations of the type keep their count and order."""
        proc_decls = [
            decl for decl in typedef.declarations if isinstance(decl, ProcedureDeclaration)
        ]
        assert len(proc_decls) == len(binding_names)
        for decl, expected_name in zip(proc_decls, binding_names):
            assert decl.symbols[0] == expected_name

    # Procedure bindings before the transformation
    typedef = source['some_derived_type']
    assert typedef.variable_map['kernel_a'].type.pass_attr is True
    assert typedef.variable_map['kernel_b_c'].type.pass_attr in (None, True)
    assert typedef.variable_map['reduce'].type.pass_attr == 'this'
    check_binding_declarations(typedef)

    sgraph = SGraph.from_dict({driver_item: [reduce_item]})

    # Expand the kernels first, then update the call site in the driver
    trafo = DerivedTypeArgumentsTransformation(key='some_key')
    for routine_name, item in (
        ('kernel_a', kernel_a_item), ('kernel', kernel_item), ('reduce', reduce_item)
    ):
        source[routine_name].apply(trafo, role='kernel', item=item)
    source_driver['driver'].apply(trafo, role='driver', item=driver_item, sub_sgraph=sgraph)

    # Analysis results are stored per item under the custom key
    expected_trafo_data = (
        (kernel_a_item, {'this': ('this%a',)}, ('this', 'out', 'n')),
        (kernel_item, {'this': ('this%b', 'this%c'), 'other': ('other%b',)}, ('this', 'other', 'm', 'n')),
        (reduce_item, {'this': ('this%a', 'this%b', 'this%c')}, ('start', 'this')),
    )
    for item, _, _ in expected_trafo_data:
        assert 'some_key' in item.trafo_data
    for item, expansion_map, orig_argnames in expected_trafo_data:
        item_data = item.trafo_data['some_key']
        assert item_data['expansion_map'] == expansion_map
        assert item_data['orig_argnames'] == orig_argnames

    # Expanded dummy argument lists
    assert kernel_a_item.ir.arguments == ('this_a(:)', 'out(:)', 'n')
    assert kernel_item.ir.arguments == ('this_b(:, :)', 'this_c(:, :)', 'other_b(:, :)', 'm', 'n')
    assert reduce_item.ir.arguments == ('start', 'this_a(:)', 'this_b(:, :)', 'this_c(:, :)')

    # The inline call in the driver passes the individual components instead
    driver_calls = list(FindInlineCalls().visit(driver_item.ir.body))
    assert len(driver_calls) == 1
    assert driver_calls[0].parameters == ('result', 'some%a', 'some%b', 'some%c')

    # Every binding is now NOPASS and the declaration order is unchanged
    typedef = source['some_derived_type']
    for name in binding_names:
        assert typedef.variable_map[name].type.pass_attr is False
    check_binding_declarations(typedef)

    # Check output of fgen
    assert source.to_fortran().count(' NOPASS ') == 3


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_import_rename(frontend, tmp_path):
    """
    Expansion across a call made through a renamed import.

    The caller (itself named ``some_routine``) calls the module routine via
    ``routine => some_routine``. Both routines must be expanded, and the call
    must pass the expanded dummy ``t_a``.
    """
    fcode_mod = """
module some_mod
    implicit none
    type some_type
        integer, allocatable :: a(:)
    end type some_type
contains
    subroutine some_routine(t)
        type(some_type), intent(inout) :: t
        t%a = 1.
    end subroutine some_routine
end module some_mod
    """.strip()
    fcode_caller = """
subroutine some_routine(t)
    use some_mod, only: some_type, routine => some_routine
    type(some_type), intent(inout) :: t
    call routine(t)
end subroutine some_routine
    """.strip()

    mod_source = Sourcefile.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    caller_source = Sourcefile.from_source(
        fcode_caller, frontend=frontend, definitions=mod_source.definitions, xmods=[tmp_path]
    )

    callee_item = ProcedureItem(name='some_mod#some_routine', source=mod_source)
    caller_item = ProcedureItem(name='#some_routine', source=caller_source)
    sgraph = SGraph.from_dict({caller_item: [callee_item]})

    trafo = DerivedTypeArgumentsTransformation()
    mod_source['some_routine'].apply(trafo, item=callee_item, role='kernel')
    caller_source['some_routine'].apply(trafo, item=caller_item, role='kernel', sub_sgraph=sgraph)

    # Both routines expand the same single component
    for item in (caller_item, callee_item):
        assert item.trafo_data[trafo._key]['expansion_map'] == {'t': ('t%a',)}
    for item in (caller_item, callee_item):
        assert item.ir.arguments == ('t_a(:)',)

    # The renamed call forwards the caller's own expanded dummy
    calls = FindNodes(CallStatement).visit(caller_item.ir.body)
    assert len(calls) == 1
    assert calls[0].arguments == ('t_a',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_optional_named_arg(frontend, tmp_path):
    """
    Expanded arguments must be inserted correctly into call sites that mix
    positional arguments, keyword arguments and omitted optional arguments.
    """
    fcode = """
module some_mod
    implicit none
    type some_type
        integer, allocatable :: arr(:)
    end type some_type
contains
    subroutine callee(t, val, opt1, opt2)
        type(some_type), intent(inout) :: t
        integer, intent(in) :: val
        integer, intent(in), optional :: opt1
        integer, intent(in), optional :: opt2

        t%arr(:) = val

        if (present(opt1)) then
            t%arr(:) = t%arr(:) + opt1
        endif
        if (present(opt2)) then
            t%arr(:) = t%arr(:) + opt2
        endif
    end subroutine callee

    subroutine caller(t)
        type(some_type), intent(inout) :: t
        call callee(t, 1, opt2=2)
        call callee(t, 1, 1)
        call callee(opt1=1, val=1, t=t, opt2=2)
    end subroutine caller
end module some_mod
    """.strip()
    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    callee_item = ProcedureItem(name='some_mod#callee', source=source)
    caller_item = ProcedureItem(name='some_mod#caller', source=source)
    sgraph = SGraph.from_dict({caller_item: [callee_item]})

    trafo = DerivedTypeArgumentsTransformation()
    source['callee'].apply(trafo, item=callee_item, role='kernel')
    source['caller'].apply(trafo, item=caller_item, role='driver', sub_sgraph=sgraph)

    key = trafo._key
    # No expansion data is recorded for the driver-role caller
    assert not caller_item.trafo_data[key]
    assert callee_item.trafo_data[key]['expansion_map'] == {'t': ('t%arr',)}

    calls = FindNodes(CallStatement).visit(caller_item.ir.body)
    assert len(calls) == 3
    mixed_call, positional_call, keyword_call = calls

    # Positional arguments plus one optional passed by keyword
    assert mixed_call.arguments == ('t%arr', '1')
    assert mixed_call.kwarguments == (('opt2', '2'),)

    # Purely positional call
    assert positional_call.arguments == ('t%arr', '1', '1')
    assert not positional_call.kwarguments

    # Purely keyword call: the keyword ``t`` becomes the expanded dummy ``t_arr``
    assert not keyword_call.arguments
    assert keyword_call.kwarguments == (('opt1', '1'), ('val', '1'), ('t_arr', 't%arr'), ('opt2', '2'))


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_recursive(frontend, tmp_path):
    """
    Expansion of derived-type arguments for recursive procedures.

    The self-calls of the recursive subroutine ``callee`` and the recursive
    function ``plus`` must pass ``t_arr``, while the driver ``caller`` passes
    the component ``t%arr``. Positional and keyword arguments are preserved.
    """
    fcode = """
module some_mod
    implicit none
    type some_type
        integer, allocatable :: arr(:)
    end type some_type
contains
    recursive subroutine callee(t, val, opt1, opt2, recurse)
        type(some_type), intent(inout) :: t
        integer, intent(in) :: val
        integer, intent(in), optional :: opt1
        integer, intent(in), optional :: opt2
        logical, intent(in), optional :: recurse

        if (present(recurse)) then
            if (recurse) then
                call callee(t, val, opt1, opt2, recurse=.false.)
            endif
        endif

        t%arr(:) = val

        if (present(opt1)) then
            t%arr(:) = t%arr(:) + opt1
        endif
        if (present(opt2)) then
            t%arr(:) = t%arr(:) + opt2
        endif
    end subroutine callee

    recursive function plus(t, val, idx, stop_recurse) result(retval)
        type(some_type), intent(in) :: t
        integer, intent(in) :: val, idx
        logical, intent(in), optional :: stop_recurse
        integer :: retval

        if (present(stop_recurse)) then
            if (stop_recurse) then
                retval = t%arr(idx)
                return
            end if
        endif

        if (val == 2) then
            retval = plus(t, 1, idx)
        elseif (val < 2) then
            retval = plus(t, 0, idx, stop_recurse=.true.)
        else
            retval = plus(t, val-1, idx)
        endif

        retval = retval + 1

    end function plus

    subroutine caller(t)
        type(some_type), intent(inout) :: t
        call callee(t, 1, opt2=2)
        call callee(t, 1, 1)
        t%arr(1) = plus(t, 32, 1)
    end subroutine caller
end module some_mod
    """.strip()
    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    callee_item = ProcedureItem(name='some_mod#callee', source=source)
    caller_item = ProcedureItem(name='some_mod#caller', source=source)
    plus_item = ProcedureItem(name='some_mod#plus', source=source)
    sgraph = SGraph.from_dict({caller_item: [callee_item, plus_item]})

    trafo = DerivedTypeArgumentsTransformation()
    source['callee'].apply(trafo, item=callee_item, role='kernel')
    source['plus'].apply(trafo, item=plus_item, role='kernel')
    source['caller'].apply(trafo, item=caller_item, role='driver', sub_sgraph=sgraph)

    key = trafo._key
    assert not caller_item.trafo_data[key]
    for item in (callee_item, plus_item):
        assert item.trafo_data[key]['expansion_map'] == {'t': ('t%arr',)}

    # Call statements in the driver pass the component
    caller_calls = FindNodes(CallStatement).visit(caller_item.ir.body)
    assert len(caller_calls) == 2
    assert caller_calls[0].arguments == ('t%arr', '1')
    assert caller_calls[0].kwarguments == (('opt2', '2'),)
    assert caller_calls[1].arguments == ('t%arr', '1', '1')
    assert not caller_calls[1].kwarguments

    # ...and so does the inline function call
    caller_inline_calls = list(FindInlineCalls().visit(caller_item.ir.body))
    assert len(caller_inline_calls) == 1
    assert caller_inline_calls[0].parameters == ('t%arr', '32', '1')
    assert not caller_inline_calls[0].kw_parameters

    # Recursive subroutine: expanded dummy keeps its intent, self-call forwards it
    assert callee_item.ir.arguments == ('t_arr(:)', 'val', 'opt1', 'opt2', 'recurse')
    assert callee_item.ir.arguments[0].type.intent == 'inout'
    self_calls = FindNodes(CallStatement).visit(callee_item.ir.body)
    assert len(self_calls) == 1
    assert self_calls[0].arguments == ('t_arr', 'val', 'opt1', 'opt2')
    assert self_calls[0].kwarguments == (('recurse', 'False'),)

    # Recursive function: all three self-calls forward the expanded dummy
    plus_calls = [
        call for call in FindInlineCalls().visit(plus_item.ir.body) if call.name == 'plus'
    ]
    assert len(plus_calls) == 3
    for call in plus_calls:
        if not call.kwarguments:
            assert call.parameters in [
                ('t_arr', '1', 'idx'), ('t_arr', 'val - 1', 'idx')
            ]
            continue
        assert call.parameters == ('t_arr', '0', 'idx')
        assert call.kwarguments == (('stop_recurse', 'True'),)

    assert plus_item.ir.arguments == ('t_arr(:)', 'val', 'idx', 'stop_recurse')
    assert plus_item.ir.arguments[0].type.intent == 'in'


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_renamed_calls(frontend, tmp_path):
    """
    Two kernels share the name ``sub`` in different modules, and one of them is
    imported under an alias. Each call in the driver must be matched to the
    correct kernel and pass only the component that kernel uses.
    """
    fcode_header = """
module header_mod
    implicit none
    type some_type
        integer, allocatable :: some(:)
        integer, allocatable :: other(:)
    end type some_type
end module header_mod
    """.strip()
    fcode_some = """
module some_mod
    implicit none
contains
    subroutine sub(t)
        use header_mod, only: some_type
        type(some_type), intent(inout) :: t
        t%some(:) = 1
    end subroutine sub
end module some_mod
    """.strip()
    fcode_other = """
module other_mod
    implicit none
contains
    subroutine sub(t)
        use header_mod, only: some_type
        type(some_type), intent(inout) :: t
        t%other(:) = 2
    end subroutine sub
end module other_mod
    """.strip()
    fcode_caller = """
subroutine caller(t)
    use header_mod, only: some_type
    use some_mod, only: some_sub => sub
    use other_mod, only: sub
    implicit none
    type(some_type), intent(inout) :: t
    call some_sub(t)
    call sub(t)
end subroutine caller
    """.strip()

    source_header = Sourcefile.from_source(fcode_header, frontend=frontend, xmods=[tmp_path])
    header_defs = source_header.definitions
    source_some = Sourcefile.from_source(
        fcode_some, frontend=frontend, definitions=header_defs, xmods=[tmp_path]
    )
    source_other = Sourcefile.from_source(
        fcode_other, frontend=frontend, definitions=header_defs, xmods=[tmp_path]
    )
    source_caller = Sourcefile.from_source(
        fcode_caller, frontend=frontend, xmods=[tmp_path],
        definitions=header_defs + source_some.definitions + source_other.definitions
    )

    some_sub_item = ProcedureItem(name='some_mod#sub', source=source_some)
    other_sub_item = ProcedureItem(name='other_mod#sub', source=source_other)
    caller_item = ProcedureItem(name='#caller', source=source_caller)
    sgraph = SGraph.from_dict({caller_item: [some_sub_item, other_sub_item]})

    trafo = DerivedTypeArgumentsTransformation()
    for kernel_source, kernel_item in ((source_some, some_sub_item), (source_other, other_sub_item)):
        kernel_source['sub'].apply(trafo, item=kernel_item, role='kernel')
    source_caller['caller'].apply(trafo, item=caller_item, role='driver', sub_sgraph=sgraph)

    # Each homonymous kernel only expands the member it actually uses
    assert some_sub_item.ir.arguments == ('t_some(:)',)
    assert other_sub_item.ir.arguments == ('t_other(:)',)

    # Each call site is resolved to the right kernel despite the renaming
    calls = FindNodes(CallStatement).visit(caller_item.ir.body)
    assert len(calls) == 2
    assert calls[0].arguments == ('t%some',)
    assert calls[1].arguments == ('t%other',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_associate_intent(frontend, tmp_path):
    """
    The intent of an expanded dummy argument must come from the derived-type
    argument. This must also hold when the component was originally accessed
    through an ASSOCIATE name that has no intent of its own.
    """
    fcode = """
module some_mod
    implicit none
    type some_type
        real, allocatable :: arr(:)
    end type some_type
contains
    subroutine some_routine(t)
        type(some_type), intent(inout) :: t
        associate(arr=>t%arr)
            arr(:) = arr(:) + 1
        end associate
    end subroutine some_routine
end module some_mod
    """.strip()
    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    def collect_variables():
        """Return the body variables of the routine and a case-insensitive name map."""
        found = FindVariables().visit(source['some_routine'].body)
        return found, CaseInsensitiveDict((var.name, var) for var in found)

    # Before: the associate name ``arr`` has no intent
    variables, by_name = collect_variables()
    assert variables == {'arr(:)', 'arr', 't%arr', 't'}
    assert by_name['t'].type.intent == 'inout'
    assert by_name['arr'].type.intent is None

    # After resolving associates only the component reference remains
    do_resolve_associates(source['some_routine'])
    variables, by_name = collect_variables()
    assert variables == {'t', 't%arr(:)'}
    assert by_name['t'].type.intent == 'inout'
    assert by_name['t%arr'].type.intent is None

    # After expansion the new dummy carries the intent of ``t``
    source['some_routine'].apply(DerivedTypeArgumentsTransformation(), role='kernel')
    variables, by_name = collect_variables()
    assert variables == {'t_arr(:)'}
    assert by_name['t_arr'].type.intent == 'inout'


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_non_array(frontend, tmp_path):
    """
    Check which kinds of derived-type arguments get expanded: a type with
    only scalar members, a type with an array member, and a type that nests
    another derived type.
    """
    fcode = """
module some_mod
    implicit none
    type scalar_type
        integer :: i
    end type scalar_type
    type array_type
        integer, allocatable :: a(:)
    end type array_type
    type nested_type
        type(scalar_type) :: s
    end type nested_type
contains
    subroutine kernel(s, a, n)
        type(scalar_type), intent(inout) :: s
        type(array_type), intent(inout) :: a
        type(nested_type), intent(inout) :: n
        s%i = 1
        a%a(:) = 2
        n%s%i = 3
    end subroutine kernel
end module some_mod
    """.strip()
    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    source['kernel'].apply(DerivedTypeArgumentsTransformation(), role='kernel')

    # ``s`` (only scalar members) is left untouched, whereas ``a`` (array member)
    # and ``n`` (nested derived-type member) are expanded down to their leaves
    assert source['kernel'].arguments == ('s', 'a_a(:)', 'n_s_i')


@pytest.mark.parametrize('duplicate', [False,True])
@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_typebound_procedure_calls(tmp_path, frontend, config, duplicate):
    """
    Type-bound procedure calls are replaced by direct calls to the bound
    procedures, passing the object as the first argument.

    With ``duplicate=True`` the bound procedures are additionally duplicated
    (suffixed ``_``) and the type bindings are re-pointed to the duplicates.
    The duplicated routines themselves keep their original type-bound calls.
    """
    fcode1 = """
module typebound_procedure_calls_mod
    implicit none

    type my_type
        integer :: val
    contains
        procedure :: reset
        procedure :: add => add_my_type
    end type my_type

    type other_type
        type(my_type) :: arr(3)
    contains
        procedure :: add => add_other_type
        procedure :: total_sum
    end type other_type

contains

    subroutine reset(this)
        class(my_type), intent(inout) :: this
        this%val = 0
    end subroutine reset

    subroutine add_my_type(this, val)
        class(my_type), intent(inout) :: this
        integer, intent(in) :: val
        this%val = this%val + val
    end subroutine add_my_type

    subroutine add_other_type(this, other)
        class(other_type) :: this
        type(other_type) :: other
        integer :: i
        do i=1,3
            call this%arr(i)%add(other%arr(i)%val)
        end do
    end subroutine add_other_type

    function total_sum(this) result(result)
        class(other_type), intent(in) :: this
        integer :: result
        integer :: i
        result = 0
        do i=1,3
            result = result + this%arr(i)%val
        end do
    end function total_sum

end module typebound_procedure_calls_mod
    """.strip()

    fcode2 = """
module other_typebound_procedure_calls_mod
    use typebound_procedure_calls_mod, only: other_type
    use function_mod, only: some_type
    implicit none

    type third_type
        type(other_type) :: stuff(2)
        type(some_type) :: some
    contains
        procedure :: init
        procedure :: print => print_content
    end type third_type

contains

    subroutine init(this)
        class(third_type), intent(inout) :: this
        integer :: i, j
        do i=1,2
            do j=1,3
                call this%stuff(i)%arr(j)%reset()
                call this%stuff(i)%arr(j)%add(i+j)
            end do
        end do
    end subroutine init

    subroutine print_content(this)
        class(third_type), intent(inout) :: this
        integer :: val
        call this%stuff(1)%add(this%stuff(2))
        val = this%stuff(1)%total_sum()
        print *, val
    end subroutine print_content
end module other_typebound_procedure_calls_mod
    """.strip()

    fcode3 = """
module function_mod
    implicit none
    type some_type
    contains
        procedure :: some_func
    end type some_type
contains
    function some_func(this)
        class(some_type) :: this
        integer some_func
        some_func = 1
    end function some_func
end module function_mod
    """.strip()

    fcode4 = """
subroutine driver
    use other_typebound_procedure_calls_mod, only: third_type
    implicit none
    type(third_type) :: data
    integer :: mysum

    call data%init()
    call data%stuff(1)%arr(1)%add(1)
    mysum = data%stuff(1)%total_sum() + data%stuff(2)%total_sum()
    associate (some => data%some)
        mysum = mysum + some%some_func()
    end associate
    call data%print
end subroutine driver
    """.strip()

    (tmp_path/'typebound_procedure_calls_mod.F90').write_text(fcode1)
    (tmp_path/'other_typebound_procedure_calls_mod.F90').write_text(fcode2)
    (tmp_path/'function_mod.F90').write_text(fcode3)
    (tmp_path/'driver.F90').write_text(fcode4)

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=['driver'], frontend=frontend, xmods=[tmp_path]
    )

    transformation = TypeboundProcedureCallTransformation(duplicate_typebound_kernels=duplicate)
    scheduler.process(transformation=transformation)

    # Driver: type-bound call statements become direct calls ...
    driver = scheduler['#driver'].ir
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 3
    assert calls[0].name == 'init'
    assert calls[0].arguments == ('data',)
    assert calls[1].name == 'add_my_type'
    assert calls[1].arguments == ('data%stuff(1)%arr(1)', '1')
    assert calls[2].name == 'print_content'
    assert calls[2].arguments == ('data',)

    # ... and so do the type-bound function calls (including via ASSOCIATE)
    calls = FindInlineCalls().visit(driver.body)
    assert len(calls) == 3
    assert {str(call).lower() for call in calls} == {
        'total_sum(data%stuff(1))', 'total_sum(data%stuff(2))', 'some_func(some)'
    }

    # The now directly-called procedures are imported into the driver
    assert 'init' in driver.imported_symbols
    assert 'add_my_type' in driver.imported_symbols
    assert 'print_content' in driver.imported_symbols
    assert 'total_sum' in driver.imported_symbols

    # Type-bound calls inside the bound procedures themselves are resolved too
    add_other_type = scheduler['typebound_procedure_calls_mod#add_other_type'].ir
    calls = FindNodes(CallStatement).visit(add_other_type.body)
    assert len(calls) == 1
    assert calls[0].name == 'add_my_type'
    assert calls[0].arguments == ('this%arr(i)', 'other%arr(i)%val')

    init = scheduler['other_typebound_procedure_calls_mod#init'].ir
    calls = FindNodes(CallStatement).visit(init.body)
    assert len(calls) == 2
    assert calls[0].name == 'reset'
    assert calls[0].arguments == ('this%stuff(i)%arr(j)',)
    assert calls[1].name == 'add_my_type'
    assert calls[1].arguments == ('this%stuff(i)%arr(j)', 'i + j')

    print_content = scheduler['other_typebound_procedure_calls_mod#print_content'].ir
    calls = FindNodes(CallStatement).visit(print_content.body)
    assert len(calls) == 1
    assert calls[0].name == 'add_other_type'
    assert calls[0].arguments == ('this%stuff(1)', 'this%stuff(2)')

    calls = list(FindInlineCalls().visit(print_content.body))
    assert len(calls) == 1
    assert str(calls[0]).lower() == 'total_sum(this%stuff(1))'

    if duplicate:
        mod = scheduler['typebound_procedure_calls_mod#add_other_type'].ir.parent

        expected_routines = [
            'reset', 'add_my_type', 'add_other_type', 'total_sum',
            'add_other_type_', 'total_sum_', 'reset_', 'add_my_type_',
        ]
        # Every original routine is kept and each one gets exactly one duplicate.
        # The duplication order is not pinned here, hence the order-insensitive
        # comparison; a plain subset check would not detect missing duplicates.
        assert sorted(r.name.lower() for r in mod.subroutines) == sorted(expected_routines)

        # The type bindings now point to the duplicated routines
        my_type = mod['my_type']
        assert my_type.variable_map['reset'].type.bind_names == ('reset_',)
        assert my_type.variable_map['add'].type.bind_names == ('add_my_type_',)
        other_type = mod['other_type']
        assert other_type.variable_map['add'].type.bind_names == ('add_other_type_',)
        assert other_type.variable_map['total_sum'].type.bind_names == ('total_sum_',)

        other_mod = scheduler['other_typebound_procedure_calls_mod#init'].ir.parent

        assert [r.name.lower() for r in other_mod.subroutines] == [
            'init', 'print_content', 'init_', 'print_content_'
        ]

        third_type = other_mod['third_type']
        assert third_type.variable_map['init'].type.bind_names == ('init_',)
        assert third_type.variable_map['print'].type.bind_names == ('print_content_',)

        # The duplicated routines retain the original type-bound calls
        assert [
            str(call.name) for call in FindNodes(CallStatement).visit(other_mod['init_'].ir)
        ] == ['this%stuff(i)%arr(j)%reset', 'this%stuff(i)%arr(j)%add']
loki-ecmwf-0.3.6/loki/transformations/tests/test_routine_signatures.py0000664000175000017500000001410615167130205026636 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Subroutine
from loki.frontend import available_frontends
from loki.ir import FindNodes, CallStatement
from loki.transformations.routine_signatures import RemoveDuplicateArgs

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('pass_as_kwarg', (True, False))
@pytest.mark.parametrize('recurse_to_kernels', (True, False))
@pytest.mark.parametrize('rename_common', (True, False))
def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recurse_to_kernels, rename_common):
    """
    Test removal of duplicate arguments with ``RemoveDuplicateArgs``.

    The driver passes the same array section ``var(:,:,1,ibl)`` as both ``var1``
    and ``var2`` to ``kernel``, either positionally or by keyword (``pass_as_kwarg``).
    The transformation must drop ``var2`` from the call and from the kernel
    signature. It keeps a single dummy named ``var1``, or ``var`` with
    ``rename_common``. With ``recurse_to_kernels`` the same is applied to the
    nested ``compute`` routine, which then keeps only ``b_var``.
    """
    fcode_driver = f"""
subroutine driver(nlon,nlev,nb,var)
  use kernel_mod, only: kernel
  implicit none
  integer, intent(in) :: nlon,nlev,nb
  real, intent(inout) :: var(nlon,nlev,5,nb)
  integer :: ibl
  integer :: offset
  integer :: some_val
  integer :: loop_start, loop_end
  loop_start = 2
  loop_end = nb
  some_val = 0
  offset = 1
  !$omp test
  do ibl=loop_start, loop_end
    call kernel(nlon,nlev, &
      & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),&
      & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),&
      & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),&
      & {'icend=' if pass_as_kwarg else ''}offset,&
      & {'lstart=' if pass_as_kwarg else ''}loop_start,&
      & {'lend=' if pass_as_kwarg else ''}loop_end,&
      & {'kend=' if pass_as_kwarg else ''}nlev)
    call kernel(nlon,nlev, &
      & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),&
      & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),&
      & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),&
      & {'icend=' if pass_as_kwarg else ''}offset,&
      & {'lstart=' if pass_as_kwarg else ''}loop_start,&
      & {'lend=' if pass_as_kwarg else ''}loop_end,&
      & {'kend=' if pass_as_kwarg else ''}nlev)
  enddo
end subroutine driver
"""

    fcode_kernel = """
module kernel_mod
implicit none
contains
subroutine kernel(nlon,nlev,var1,var2,another_var,icend,lstart,lend,kend)
  use compute_mod, only: compute
  implicit none
  integer, intent(in) :: nlon,nlev,icend,lstart,lend,kend
  real, intent(inout) :: var1(nlon,nlev)
  real, intent(inout) :: var2(nlon,nlev)
  real, intent(inout) :: another_var(nlon,nlev,4)
  integer :: jk, jl, jt
  var1(:,:) = 0.
  do jk = 1,kend
    do jl = 1, nlon
      var1(jl, jk) = 0.
      var2(jl, jk) = 1.0
      do jt= 1,4
        another_var(jl, jk, jt) = 0.0
      end do
    end do
  end do
  call compute(nlon,nlev,var1, var2)
  call compute(nlon,nlev,var1, var2)
end subroutine kernel
end module kernel_mod
"""

    fcode_nested_kernel = """
module compute_mod
implicit none
contains
subroutine compute(nlon,nlev,b_var,a_var)
  implicit none
  integer, intent(in) :: nlon,nlev
  real, intent(inout) :: b_var(nlon,nlev)
  real, intent(inout) :: a_var(nlon,nlev)
  real :: VAR ! create name clash on purpose (if rename_common)
  b_var(:,:) = 0.
  a_var(:,:) = 1.0
end subroutine compute
end module compute_mod
"""

    nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path])
    kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path])

    transformation = RemoveDuplicateArgs(recurse_to_kernels=recurse_to_kernels, rename_common=rename_common)
    transformation.apply(driver, role='driver', targets=('kernel',))
    transformation.apply(kernel_mod['kernel'], role='kernel', targets=('compute',))
    transformation.apply(nested_kernel_mod['compute'], role='kernel')

    # driver
    kernel_var_name = 'var' if rename_common else 'var1'
    kernel_calls = FindNodes(CallStatement).visit(driver.body)
    # Guard against the per-call checks below passing vacuously
    assert len(kernel_calls) == 2
    for kernel_call in kernel_calls:
        if pass_as_kwarg:
            assert (kernel_var_name, 'var(:, :, 1, ibl)') in kernel_call.kwarguments
            assert ('var2', 'var(:, :, 1, ibl)') not in kernel_call.kwarguments
            # The previously duplicated actual argument is now passed exactly once
            assert sum(1 for _, arg in kernel_call.kwarguments if arg == 'var(:, :, 1, ibl)') == 1
            arg1 = kernel_call.kwarguments[0][1]
            arg2 = kernel_call.kwarguments[1][1]
        else:
            assert 'var(:, :, 1, ibl)' in kernel_call.arguments
            # The previously duplicated actual argument is now passed exactly once
            # (the driver never references ``var2``, so testing for its absence
            # in the positional case would always succeed)
            assert sum(1 for arg in kernel_call.arguments if arg == 'var(:, :, 1, ibl)') == 1
            arg1 = kernel_call.arguments[2]
            arg2 = kernel_call.arguments[3]
        assert arg1.dimensions == (':', ':', '1', 'ibl')
        assert arg2.dimensions == (':', ':', '2:5', 'ibl')
    # kernel
    kernel_vars = kernel_mod['kernel'].variable_map
    kernel_args = kernel_mod['kernel']._dummies
    assert kernel_var_name in kernel_args
    assert 'var2' not in kernel_args
    assert 'var2' not in kernel_vars
    assert kernel_vars[kernel_var_name].shape == ('nlon', 'nlev')
    assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4)
    compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body)
    assert len(compute_calls) == 2
    for compute_call in compute_calls:
        assert kernel_var_name in compute_call.arguments
        assert 'var2' not in compute_call.arguments
    # nested_kernel
    nested_kernel = nested_kernel_mod['compute']
    nested_kernel_vars = nested_kernel.variable_map
    nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments]
    # it's always 'b_var' as a rename would clash with the already "used" variable "var"
    nested_kernel_var_name = 'b_var'
    if recurse_to_kernels:
        assert nested_kernel_var_name in nested_kernel_args
        assert 'a_var' not in nested_kernel_args
        assert nested_kernel_var_name in nested_kernel_vars
        assert 'a_var' not in nested_kernel_vars
    else:
        assert 'b_var' in nested_kernel_args
        assert 'a_var' in nested_kernel_args
        assert 'b_var' in nested_kernel_vars
        assert 'a_var' in nested_kernel_vars
loki-ecmwf-0.3.6/loki/transformations/tests/sources/0000775000175000017500000000000015167130205022755 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/tests/sources/projDerivedTypes/0000775000175000017500000000000015167130205026257 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/tests/sources/projDerivedTypes/some_derived_type.F900000664000175000017500000000025515167130205032247 0ustar  alastairalastairmodule some_derived_type_mod

implicit none

    ! Derived type with two scalar REAL components, a and b.
    type some_derived_type
        real :: a
        real :: b
    end type some_derived_type

end module some_derived_type_mod
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projDerivedTypes/driver_mod.F900000664000175000017500000000077015167130205030675 0ustar  alastairalastairmodule driver_mod

use some_derived_type_mod, only: some_derived_type
use kernel_mod, only: kernel
implicit none

contains
  subroutine driver(z)
        integer, intent(in) :: z
        type(some_derived_type) :: t_io
        type(some_derived_type) :: t_in, t_out
        integer :: m, n
        integer :: i, j

        m = 100
        n = 10

        t_in%a = real(m-1)
        t_in%b = real(n-1)

        call kernel(m, n, t_io%a, t_io%b, t_in, t_out)

  end subroutine driver
end module driver_mod
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projDerivedTypes/kernel_mod.F900000664000175000017500000000073315167130205030661 0ustar  alastairalastairmodule kernel_mod

use some_derived_type_mod, only: some_derived_type
implicit none

contains

  subroutine kernel(m, n, P_a, P_b, Q, R)
        integer                , intent(in)    :: m, n
        real, intent(inout)                    :: P_a, P_b
        type(some_derived_type), intent(in)    :: Q
        type(some_derived_type), intent(out)   :: R
        integer :: j, k

        R%a = P_a + Q%a
        R%b = P_b - Q%b
  end subroutine kernel

end module kernel_mod
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projArgShape/0000775000175000017500000000000015167130205025342 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/tests/sources/projArgShape/kernel_a_mod.F900000664000175000017500000000050415167130205030240 0ustar  alastairalastairMODULE KERNEL_A_MOD
USE KERNEL_A1_MOD, ONLY: KERNEL_A1
IMPLICIT NONE
CONTAINS
  SUBROUTINE kernel_a(a, b, c)
    USE VAR_MODULE_MOD, only: n
    REAL, INTENT(INOUT)   :: a(:)
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

    CALL kernel_a1(b, c)
  END SUBROUTINE kernel_a

END MODULE KERNEL_A_MOD
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projArgShape/kernel_b_mod.F900000664000175000017500000000044015167130205030240 0ustar  alastairalastairMODULE KERNEL_B_MOD
USE VAR_MODULE_MOD, only: n
IMPLICIT NONE
CONTAINS

  SUBROUTINE kernel_b(b, c)
    ! USE VAR_MODULE_MOD, only: n
    ! Second-level kernel call
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

  END SUBROUTINE kernel_b
END MODULE KERNEL_B_MOD
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projArgShape/kernel_a1_mod.F900000664000175000017500000000034715167130205030326 0ustar  alastairalastairMODULE KERNEL_A1_MOD
IMPLICIT NONE
CONTAINS

  SUBROUTINE kernel_a1(b, c)
    ! Second-level kernel call
    REAL, INTENT(INOUT)   :: b(:,:)
    REAL, INTENT(INOUT)   :: c(:,:)

  END SUBROUTINE kernel_a1

END MODULE KERNEL_A1_MOD
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projArgShape/driver_mod.F900000664000175000017500000000071415167130205027756 0ustar  alastairalastairMODULE DRIVER_MOD
USE KERNEL_A_MOD, ONLY: KERNEL_A
USE KERNEL_B_MOD, ONLY: KERNEL_B
IMPLICIT NONE
CONTAINS
  SUBROUTINE driver(nlon, nlev, a, b, c)
    INTEGER, INTENT(IN)   :: nlon, nlev  ! Dimension sizes
    INTEGER, PARAMETER    :: n = 5
    REAL, INTENT(INOUT)   :: a(nlon)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,n)

    call kernel_a(a, b, c)

    call kernel_b(b, c)
  END SUBROUTINE driver

END MODULE DRIVER_MOD
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projArgShape/var_module_mod.F900000664000175000017500000000011715167130205030615 0ustar  alastairalastairMODULE VAR_MODULE_MOD
INTEGER, PARAMETER    :: n = 5
END MODULE VAR_MODULE_MOD
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projSccCuf/0000775000175000017500000000000015167130205025016 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/tests/sources/projSccCuf/module/0000775000175000017500000000000015167130205026303 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/tests/sources/projSccCuf/module/kernel.f900000664000175000017500000000324715167130205030111 0ustar  alastairalastairMODULE KERNEL_MOD
    IMPLICIT NONE
    CONTAINS
    SUBROUTINE kernel(start, iend, nlon, nz, q, t, z)
        INTEGER, INTENT(IN) :: start, iend, nlon, nz
        REAL, INTENT(INOUT) :: t(nlon,nz)
        REAL, INTENT(INOUT) :: q(nlon,nz)
        REAL, INTENT(INOUT) :: z(nlon,nz)
        REAL    :: local_z(nlon, nz)
        INTEGER :: jl, jk
        REAL :: c

        c = SOME_FUNC(A=5.345)
        DO jk = 2, nz
          DO jl = start, iend
            call ELEMENTAL_DEVICE(z(jl, jk))
          END DO
        END DO

        call DEVICE(nlon, nz, 2, start, iend, z)

        c = 5.345
        DO jk = 2, nz
          DO jl = start, iend
            t(jl, jk) = c * jk
            q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
          END DO
        END DO

        DO jk = 2, nz
          DO jl = start, iend
            local_z(jl, jk) = 0.0
            z(jl, jk) = local_z(jl, jk)
          END DO
        END DO

    END SUBROUTINE kernel

    PURE ELEMENTAL SUBROUTINE ELEMENTAL_DEVICE(x) ! elemental
      REAL, INTENT(INOUT) :: x
      x = 0.0
    END SUBROUTINE ELEMENTAL_DEVICE

    SUBROUTINE DEVICE(nlon, nz, jk_start, start, iend, x)
        INTEGER, INTENT(IN) :: jk_start, start, iend, nlon, nz
        REAL, INTENT(INOUT) :: x(nlon, nz)
        REAL    :: local_x(nlon, nz)
        INTEGER :: jk, jl
        DO jk = jk_start, nz
            DO jl = start, iend
                local_x(jl, jk) = 0.0
                x(jl, jk) = local_x(jl, jk)
            END DO
        END DO
    END SUBROUTINE DEVICE

    FUNCTION SOME_FUNC(A)
        REAL, INTENT(IN) :: A
        REAL :: SOME_FUNC
        !$loki routine seq
        SOME_FUNC = A
    END FUNCTION SOME_FUNC

END MODULE KERNEL_MOD
loki-ecmwf-0.3.6/loki/transformations/tests/sources/projSccCuf/module/driver.f900000664000175000017500000000170415167130205030120 0ustar  alastairalastair
MODULE driver_mod
    USE KERNEL_MOD, ONLY: KERNEL
    IMPLICIT NONE
CONTAINS
    SUBROUTINE driver(nlon, nz, nb, tot, q, t, z)
        INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
        INTEGER, INTENT(IN)   :: tot
        REAL, INTENT(INOUT)   :: t(nlon,nz,nb)
        REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
        REAL, INTENT(INOUT)   :: z(nlon,nz+1,nb)
        INTEGER :: b, start, iend, ibl, icend

        start = 1
        iend = tot
        do b=1,iend,nlon
          ibl = (b-1)/nlon+1
          icend = MIN(nlon,tot-b+1)
          call kernel(start, icend, nlon, nz, q(:,:,b), t(:,:,b), z(:,:,b))
        end do

       do b=1,iend,nlon
          ibl = (b-1)/nlon+1
          icend = MIN(nlon,tot-b+1)
          call kernel(start, icend, nlon, nz, q(:,:,b), t(:,:,b), z(:,:,b))
          call kernel(start, icend, nlon, nz, q(:,:,b), t(:,:,b), z(:,:,b))
       end do

    END SUBROUTINE driver
END MODULE driver_mod
loki-ecmwf-0.3.6/loki/transformations/tests/test_cloudsc2_tl_ad.py0000664000175000017500000001222015167130205025561 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import os
import io
import resource
from subprocess import CalledProcessError
from pathlib import Path
import pandas as pd
import pytest

from loki.frontend import FP
from loki.logging import warning
from loki.tools import (
    execute, write_env_launch_script, local_loki_setup, local_loki_cleanup
)

pytestmark = pytest.mark.skipif('CLOUDSC2_DIR' not in os.environ, reason='CLOUDSC2_DIR not set')


@pytest.fixture(scope='module', name='here')
def fixture_here():
    return Path(os.environ['CLOUDSC2_DIR'])


@pytest.fixture(scope='module', name='local_loki_bundle')
def fixture_local_loki_bundle(here):
    """Call setup utilities for injecting ourselves into the CLOUDSC bundle"""
    lokidir, target, backup = local_loki_setup(here)
    yield lokidir
    local_loki_cleanup(target, backup)


@pytest.fixture(scope='module', name='bundle_create')
def fixture_bundle_create(here, local_loki_bundle):
    """Inject ourselves into the CLOUDSC bundle"""
    env = os.environ.copy()
    env['CLOUDSC_BUNDLE_LOKI_DIR'] = local_loki_bundle

    # Run ecbundle to fetch dependencies
    execute(
        ['./cloudsc-bundle', 'create'], cwd=here, silent=False, env=env
    )


@pytest.mark.usefixtures('bundle_create')
@pytest.mark.parametrize('frontend', [FP])
def test_cloudsc2_tl_ad(here, frontend):
    build_cmd = [
        './cloudsc-bundle', 'build', '--retry-verbose', '--clean',
        '--with-loki', '--loki-frontend=' + str(frontend), '--without-loki-install',
    ]

    if 'CLOUDSC2_ARCH' in os.environ:
        build_cmd += [f"--arch={os.environ['CLOUDSC2_ARCH']}"]
    else:
        # Build without OpenACC support as this makes problems
        # with older versions of GNU
        build_cmd += ['--cmake=ENABLE_ACC=OFF']

    execute(build_cmd, cwd=here, silent=False)

    # Raise stack limit
    resource.setrlimit(resource.RLIMIT_STACK, (resource.RLIM_INFINITY, resource.RLIM_INFINITY))
    env = os.environ.copy()
    env.update({'OMP_STACKSIZE': '2G', 'NVCOMPILER_ACC_CUDA_HEAPSIZE': '2G'})

    # Run the produced binaries
    nl_binaries = [
        ('dwarf-cloudsc2-nl-loki-idem', '2', '16384', '32'),
        ('dwarf-cloudsc2-nl-loki-scc', '1', '16384', '32'),
        ('dwarf-cloudsc2-nl-loki-scc-hoist', '1', '16384', '32'),
    ]
    tl_binaries = [
        ('dwarf-cloudsc2-tl-loki-idem', '1', '1024', '32'),
        ('dwarf-cloudsc2-tl-loki-scc', '1', '1024', '32'),
        ('dwarf-cloudsc2-tl-loki-scc-hoist', '1', '1024', '32'),
    ]
    ad_binaries = [
        ('dwarf-cloudsc2-ad-loki-idem',),
        ('dwarf-cloudsc2-ad-loki-scc',),
        ('dwarf-cloudsc2-ad-loki-scc-hoist',),
    ]

    failures, warnings = {}, {}

    for binary, *args in nl_binaries:
        # Write a script to source env.sh and launch the binary
        script = write_env_launch_script(here, binary, args)

        # Run the script and verify error norms
        try:
            output = execute([str(script)], cwd=here/'build', capture_output=True, silent=False, env=env)
            results = pd.read_fwf(io.StringIO(output.stdout.decode()), index_col='Variable')
            no_errors = results['AbsMaxErr'].astype('float') == 0
            if not no_errors.all(axis=None):
                only_small_errors = results['MaxRelErr-%'].astype('float') < 1e-12
                if not only_small_errors.all(axis=None):
                    failures[binary] = results
                else:
                    warnings[binary] = results
        except CalledProcessError as err:
            failures[binary] = err.stderr.decode()

    for binary, *args in tl_binaries:
        # Write a script to source env.sh and launch the binary
        script = write_env_launch_script(here, binary, args)

        # Run the script and verify error norms
        try:
            output = execute([str(script)], cwd=here/'build', capture_output=True, silent=False, env=env)
            if 'TEST PASSED' not in output.stdout.decode():
                failures[binary] = output.stdout.decode()
        except CalledProcessError as err:
            failures[binary] = err.stderr.decode()

    for binary, *args in ad_binaries:
        # Write a script to source env.sh and launch the binary
        script = write_env_launch_script(here, binary, args)

        # Run the script and verify error norms
        try:
            output = execute([str(script)], cwd=here/'build', capture_output=True, silent=False, env=env)
            if 'TEST OK' not in output.stdout.decode():
                failures[binary] = output.stdout.decode()
        except CalledProcessError as err:
            failures[binary] = err.stderr.decode()

    if warnings:
        msg = '\n'.join([f'{binary}:\n{results}' for binary, results in warnings.items()])
        warning(msg)

    if failures:
        msg = '\n'.join([f'{binary}:\n{results}' for binary, results in failures.items()])
        pytest.fail(msg)
loki-ecmwf-0.3.6/loki/transformations/tests/test_block_index_inject.py0000664000175000017500000010030615167130205026520 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from shutil import rmtree
import pytest

from loki import (
    Dimension, gettempdir, Scheduler, OMNI, FindNodes, Assignment, FindVariables, CallStatement, Subroutine,
    Item, available_frontends, Module, ir, get_pragma_parameters
)
from loki.batch import TransformationError
from loki.transformations import (
        BlockViewToFieldViewTransformation, InjectBlockIndexTransformation,
        LowerBlockIndexTransformation, LowerBlockLoopTransformation
)
from loki.expression import symbols as sym
from loki.types import DerivedType

@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'end'),
                     aliases=('nproma',), bounds_aliases=('bnds%start', 'bnds%end'))


@pytest.fixture(scope='module', name='blocking')
def fixture_blocking():
    return Dimension(name='blocking', size='nb', index='ibl', index_aliases=('bnds%kbl', 'jkglo'))


@pytest.fixture(scope='function', name='config')
def fixture_config():
    """
    Default configuration dict with basic options.
    """
    return {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True,
            'disable': ['*%init', '*%final', 'abor1'],
        },
    }


@pytest.fixture(scope='module', name='blockview_to_fieldview_code', params=[True, False])
def fixture_blockview_to_fieldview_code(request):
    fcode = {
        #-------------
        'variable_mod': (
        #-------------
"""
module variable_mod
  implicit none

  type variable_3d
      real, pointer :: p(:,:) => null()
      real, pointer :: p_field(:,:,:) => null()
  end type variable_3d

  type variable_3d_ptr
      integer :: comp
      type(variable_3d), pointer :: ptr => null()
  end type variable_3d_ptr

end module variable_mod
"""
        ).strip(),
        #-------------
        'model_type_mod': (
        #-------------
"""
module model_type_mod
  implicit none

  type model_type
     real, allocatable :: some_rdonly_var(:,:)
  end type model_type

end module model_type_mod
"""
        ).strip(),
        #--------------------
        'field_variables_mod': (
        #--------------------
"""
module field_variables_mod
  use variable_mod, only: variable_3d, variable_3d_ptr
  implicit none

  type field_variables
      type(variable_3d_ptr), allocatable :: gfl_ptr_g(:)
      type(variable_3d_ptr), pointer :: gfl_ptr(:) => null()
      type(variable_3d) :: var
  end type field_variables

end module field_variables_mod
"""
        ).strip(),
        #-------------------
        'container_type_mod': (
        #-------------------
"""
module container_type_mod
  implicit none

  type container_3d_var
    real, pointer :: p(:,:) => null()
    real, pointer :: p_field(:,:,:) => null()
  end type container_3d_var

  type container_type
    type(container_3d_var), allocatable :: vars(:)
  end type container_type

end module container_type_mod
"""
        ).strip(),
        #--------------
        'dims_type_mod': (
        #--------------
"""
module dims_type_mod
   type dims_type
      integer :: start, end, kbl, nb
   end type dims_type
end module dims_type_mod
"""
        ).strip(),
        #-------
        'driver': (
        #-------
f"""
subroutine driver(data, ydvars, container, nlon, nlev, ydmodel, {'start, end, nb' if request.param else 'bnds'})
   use field_array_module, only: field_2rb_array, field_3rb_array
   use container_type_mod, only: container_type
   use field_variables_mod, only: field_variables
   use model_type_mod, only: model_type
   {'use dims_type_mod, only: dims_type' if not request.param else ''}
   implicit none

   #include "kernel.intfb.h"

   real, intent(inout) :: data(:,:,:)
   integer, intent(in) :: nlon, nlev
   type(field_variables), intent(inout) :: ydvars
   type(container_type), intent(inout) :: container
   type(model_type), intent(inout) :: ydmodel
   {'integer, intent(in) :: start, end, nb' if request.param else 'type(dims_type), intent(in) :: bnds'}

   integer :: ibl, jl
   type(field_2rb_array) :: yla_other
   type(field_3rb_array) :: yla_data

   call yla_data%init(data)

   do ibl=1,{'nb' if request.param else 'bnds%nb'}
      {'bnds%kbl = ibl' if not request.param else ''}
      {'do jl = start,end' if request.param else 'do jl = bnds%start,bnds%end'}
         yla_data%p(jl,:) = ydmodel%some_rdonly_var(jl,ibl)
      enddo
      call kernel(nlon, nlev, {'start, end, ibl' if request.param else 'bnds'}, ydvars, container, yla_data, yla_other)
   enddo

   !$loki driver-loop
   do ibl=1,{'nb' if request.param else 'bnds%nb'}
      {'do jl = start,end' if request.param else 'do jl = bnds%start,bnds%end'}
         yla_data%p(jl,:) = 1.
      enddo
   enddo

   call yla_data%final()

end subroutine driver
"""
        ).strip(),
        #-------
        'kernel': (
        #-------
f"""
subroutine kernel(nlon, nlev, {'start, end, ibl' if request.param else 'bnds'}, ydvars, container, yda_data, yda_other)
   use field_array_module, only: field_2rb_array, field_3rb_array
   use container_type_mod, only: container_type
   use field_variables_mod, only: field_variables
   {'use dims_type_mod, only: dims_type' if not request.param else ''}
   implicit none

#include "another_kernel.intfb.h"
#include "abor1.intfb.h"

   integer, intent(in) :: nlon, nlev
   type(field_variables), intent(inout) :: ydvars
   type(container_type), intent(inout) :: container
   {'integer, intent(in) :: start, end, ibl' if request.param else 'type(dims_type), intent(in) :: bnds'}
   type(field_3rb_array), intent(inout) :: yda_data
   type(field_2rb_array), intent(in) :: yda_other

   integer :: jl, jfld
   {'associate(start=>bnds%start, end=>bnds%end, ibl=>bnds%kbl)' if not request.param else ''}

   if(nlon < 0) call abor1('kernel')

   ydvars%var%p_field(:,:) = 0. !... this should only get the block-index
   ydvars%var%p_field(:,:,ibl) = 0. !... this should be untouched

   yda_data%p(start:end,:) = 1
   ydvars%var%p(start:end,:) = 1

   do jfld=1,size(ydvars%gfl_ptr)
      do jl=start,end
         ydvars%gfl_ptr(jfld)%ptr%p(jl,:) = yda_data%p(jl,:)
         container%vars(ydvars%gfl_ptr(jfld)%comp)%p(jl,:) = 0.
      enddo
   enddo

   call another_kernel(nlon, nlev, data=yda_data%p, other=yda_other%p)

   {'end associate' if not request.param else ''}
end subroutine kernel
"""
        ).strip(),
        #-------
        'another_kernel': (
        #-------
"""
subroutine another_kernel(nproma, nlev, data, other)
   implicit none
   !... not a sequential routine but still labelling it as one to test the
   !... bail-out mechanism
   !$loki routine seq
   integer, intent(in) :: nproma, nlev
   real, intent(inout) :: data(nproma, nlev)
   real, intent(in) :: other(nproma, nlev)
end subroutine another_kernel
"""
        ).strip(),
        #-------
        'empty_kernel': (
        #-------
"""
subroutine empty_kernel()
   implicit none
end subroutine empty_kernel
"""
        ).strip()
    }

    workdir = gettempdir()/'test_blockview_to_fieldview'
    if workdir.exists():
        rmtree(workdir)
    workdir.mkdir()
    for name, code in fcode.items():
        (workdir/f'{name}.F90').write_text(code)

    yield workdir, request.param

    rmtree(workdir)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI,
                         'OMNI fails to import undefined module.')]))
@pytest.mark.parametrize('force_inject_arrays', [True, False])
def test_blockview_to_fieldview_pipeline(horizontal, blocking, config, frontend, blockview_to_fieldview_code,
                                         force_inject_arrays, tmp_path):

    config['routines'] = {
        'driver': {
            'role': 'driver',
            'exclude_var_names': ['ydmodel']
        },
        'empty_kernel': {'role': 'kernel'}
    }
    if force_inject_arrays:
        config['routines'].update({'kernel': {'role': 'kernel', 'force_inject_arrays': ['yda_other%p_field']}})

    scheduler = Scheduler(
        paths=(blockview_to_fieldview_code[0],), config=config, seed_routines='driver', frontend=frontend,
        xmods=[tmp_path]
    )
    scheduler.process(BlockViewToFieldViewTransformation(horizontal, global_gfl_ptr=True))
    scheduler.process(InjectBlockIndexTransformation(blocking))

    kernel = scheduler['#kernel'].ir
    aliased_bounds = not blockview_to_fieldview_code[1]
    ibl_expr = blocking.index
    if aliased_bounds:
        ibl_expr = blocking.indices[1]

    assigns = FindNodes(Assignment).visit(kernel.body)

    # check that access pointers for arrays without horizontal index in dimensions were not updated
    assert assigns[0].lhs == f'ydvars%var%p_field(:,:,{ibl_expr})'
    assert assigns[1].lhs == f'ydvars%var%p_field(:,:,{ibl_expr})'

    # check that vector notation was resolved correctly
    assert assigns[2].lhs == f'yda_data%p_field(jl, :, {ibl_expr})'
    assert assigns[3].lhs == f'ydvars%var%p_field(jl, :, {ibl_expr})'

    # check thread-local ydvars%gfl_ptr was replaced with its global equivalent
    gfl_ptr_vars = {v for v in FindVariables().visit(kernel.body) if 'ydvars%gfl_ptr' in v.name.lower()}
    gfl_ptr_g_vars = {v for v in FindVariables().visit(kernel.body) if 'ydvars%gfl_ptr_g' in v.name.lower()}
    assert gfl_ptr_g_vars
    assert not gfl_ptr_g_vars - gfl_ptr_vars

    assert assigns[4].lhs == f'ydvars%gfl_ptr_g(jfld)%ptr%p_field(jl,:,{ibl_expr})'
    assert assigns[4].rhs == f'yda_data%p_field(jl,:,{ibl_expr})'
    assert assigns[5].lhs == f'container%vars(ydvars%gfl_ptr_g(jfld)%comp)%p_field(jl,:,{ibl_expr})'

    # check callstatement was updated correctly
    calls = FindNodes(CallStatement).visit(kernel.body)
    assert f'yda_data%p_field(:,:,{ibl_expr})' in calls[1].arg_map.values()

    # check force injection of block index in sequence associated args (yay!)
    if force_inject_arrays:
        assert f'yda_other%p_field(:,{ibl_expr})' in calls[1].arg_map.values()
    else:
        assert 'yda_other%p_field' in calls[1].arg_map.values()

    # check code in driver loop was transformed correctly
    driver = scheduler['#driver'].ir
    loops = FindNodes(ir.Loop, greedy=True).visit(driver.body)
    assert len(loops) == 2
    assigns = FindNodes(Assignment).visit(loops[0].body)
    assign_loc = 1 if aliased_bounds else 0
    assert assigns[assign_loc].lhs == 'yla_data%p_field(jl,:,ibl)'

    # check block-index was not injected in explicitly marked variables
    assert assigns[assign_loc].rhs == 'ydmodel%some_rdonly_var(jl,ibl)'

    # now check if loop without target calls has been transformed correctly
    assigns = FindNodes(Assignment).visit(loops[1].body)
    assert assigns[0].lhs == 'yla_data%p_field(jl,:,ibl)'
    assert assigns[0].rhs == '1.'


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI,
                         'OMNI fails to import undefined module.')]))
@pytest.mark.parametrize('global_gfl_ptr', [False, True])
def test_blockview_to_fieldview_only(horizontal, blocking, config, frontend, blockview_to_fieldview_code,
                                     global_gfl_ptr, tmp_path):

    config['routines'] = {
        'driver': {
            'role': 'driver',
            'exclude_var_names': ['ydmodel']
        },
    }

    scheduler = Scheduler(
        paths=(blockview_to_fieldview_code[0],), config=config, seed_routines='driver', frontend=frontend,
        xmods=[tmp_path]
    )
    scheduler.process(BlockViewToFieldViewTransformation(horizontal, global_gfl_ptr=global_gfl_ptr))

    kernel = scheduler['#kernel'].ir
    aliased_bounds = not blockview_to_fieldview_code[1]
    ibl_expr = blocking.index
    if aliased_bounds:
        ibl_expr = blocking.indices[1]

    assigns = FindNodes(Assignment).visit(kernel.body)

    # check that access pointers for arrays without horizontal index in dimensions were not updated
    assert assigns[0].lhs == 'ydvars%var%p_field(:,:)'
    assert assigns[1].lhs == f'ydvars%var%p_field(:,:,{ibl_expr})'

    # check that vector notation was resolved correctly
    assert assigns[2].lhs == 'yda_data%p_field(jl, :)'
    assert assigns[3].lhs == 'ydvars%var%p_field(jl, :)'

    # check thread-local ydvars%gfl_ptr was replaced with its global equivalent
    if global_gfl_ptr:
        gfl_ptr_vars = {v for v in FindVariables().visit(kernel.body) if 'ydvars%gfl_ptr' in v.name.lower()}
        gfl_ptr_g_vars = {v for v in FindVariables().visit(kernel.body) if 'ydvars%gfl_ptr_g' in v.name.lower()}
        assert gfl_ptr_g_vars
        assert not gfl_ptr_g_vars - gfl_ptr_vars
    else:
        assert not {v for v in FindVariables().visit(kernel.body) if 'ydvars%gfl_ptr_g' in v.name.lower()}

    assert assigns[4].rhs == 'yda_data%p_field(jl,:)'
    if global_gfl_ptr:
        assert assigns[4].lhs == 'ydvars%gfl_ptr_g(jfld)%ptr%p_field(jl,:)'
        assert assigns[5].lhs == 'container%vars(ydvars%gfl_ptr_g(jfld)%comp)%p_field(jl,:)'
    else:
        assert assigns[4].lhs == 'ydvars%gfl_ptr(jfld)%ptr%p_field(jl,:)'
        assert assigns[5].lhs == 'container%vars(ydvars%gfl_ptr(jfld)%comp)%p_field(jl,:)'

    # check callstatement was updated correctly
    calls = FindNodes(CallStatement).visit(kernel.body)
    assert 'yda_data%p_field' in calls[1].arg_map.values()

    # check that the dummy definition for field_3rb_array also contains the field pointer
    driver = scheduler['#driver'].ir
    yla_data = driver.variable_map['yla_data']
    _typedef = yla_data.type.dtype.typedef
    field_ptr = [v for v in _typedef.variables if v.name == 'F_P']
    assert field_ptr
    assert isinstance(field_ptr[0].type.dtype, DerivedType)
    assert field_ptr[0].type.dtype.name.lower() == 'field_3rb'
    assert field_ptr[0].type.pointer
    assert field_ptr[0].type.polymorphic


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI,
                         'OMNI correctly complains about rank mismatch in assignment.')]))
def test_simple_blockindex_inject(blocking, frontend):
    fcode = """
subroutine kernel(nlon,nlev,nb,var)
  implicit none

  interface
    subroutine compute(nlon,nlev,var)
      implicit none
      integer, intent(in) :: nlon,nlev
      real, intent(inout) :: var(nlon,nlev)
    end subroutine compute
  end interface

  integer, intent(in) :: nlon,nlev,nb
  real, intent(inout) :: var(nlon,nlev,4,nb) !... this dummy arg was potentially promoted by a previous transformation

  integer :: ibl

  do ibl=1,nb !... this loop was potentially lowered by a previous transformation
     var(:,:,:) = 0.
     call compute(nlon,nlev,var(:,:,1))
  enddo

end subroutine kernel
"""

    kernel = Subroutine.from_source(fcode, frontend=frontend)
    InjectBlockIndexTransformation(blocking).apply(kernel, role='kernel', targets=('compute',))

    assigns = FindNodes(Assignment).visit(kernel.body)
    assert assigns[0].lhs == 'var(:,:,:,ibl)'

    calls = FindNodes(CallStatement).visit(kernel.body)
    assert 'var(:,:,1,ibl)' in calls[0].arguments


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI,
                         'OMNI complains about undefined type.')]))
def test_blockview_to_fieldview_exception(frontend, horizontal):
    fcode = """
subroutine kernel(nlon,nlev,start,end,var)
  implicit none

  interface
    subroutine compute(nlon,nlev,var)
      implicit none
      integer, intent(in) :: nlon,nlev
      real, intent(inout) :: var(nlon,nlev)
    end subroutine compute
  end interface

  integer, intent(in) :: nlon,nlev,start,end
  type(wrapped_field) :: var

  call compute(nlon,nlev,var%p)

end subroutine kernel
"""

    kernel = Subroutine.from_source(fcode, frontend=frontend)
    item = Item(name='#kernel', source=kernel)
    item.trafo_data['BlockViewToFieldViewTransformation'] = {'definitions': []}
    with pytest.raises(TransformationError):
        BlockViewToFieldViewTransformation(horizontal).apply(kernel, item=item, role='kernel',
                                           targets=('compute',))

    with pytest.raises(TransformationError):
        BlockViewToFieldViewTransformation(horizontal).apply(kernel, role='kernel',
                                           targets=('compute',))


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('block_dim_arg', (False, True))
@pytest.mark.parametrize('recurse_to_kernels', (False, True))
def test_simple_lower_loop(blocking, frontend, block_dim_arg, recurse_to_kernels, tmp_path):
    """
    Test the block-loop lowering pipeline on a single driver loop calling one kernel.

    The three transformations are applied in sequence to the driver, the kernel and
    the nested ``compute`` routine, and the IR is checked after each step:

    1. ``LowerBlockIndexTransformation``: the block size/index are passed to the
       kernel call (positionally when the call already passes them, i.e.
       ``block_dim_arg``, otherwise as keyword arguments) and the kernel's array
       arguments gain the block size in their dimensions/shape (for the nested
       ``compute`` routine only when ``recurse_to_kernels``).
    2. ``InjectBlockIndexTransformation``: accesses to those arrays in the kernel
       body that have explicit dimensions are indexed by the block index.
    3. ``LowerBlockLoopTransformation``: the block loop is removed from the driver,
       leaving a ``loki removed_loop`` pragma on the call that records the loop
       bounds, and a block loop appears inside the kernel, where the block index is
       a local variable rather than an argument.

    ``blocking`` is presumably a fixture providing the block dimension
    (``size``/``index``) - confirm against the fixture definition.
    """

    # Driver with a block loop over ``ibl``; ``ibl``/``nb`` are only passed to the
    # kernel explicitly when ``block_dim_arg`` is set
    fcode_driver = f"""
subroutine driver(nlon,nlev,nb,var)
  use kernel_mod, only: kernel
  implicit none
  integer, intent(in) :: nlon,nlev,nb
  real, intent(inout) :: var(nlon,nlev,nb)
  real :: some_var(nlon,nlev,nb)
  integer :: ibl
  integer :: offset
  integer :: some_val
  integer :: loop_start, loop_end
  loop_start = 2
  loop_end = nb
  some_val = 0
  offset = 1
  !$omp test
  do ibl=loop_start, loop_end
    call kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl),offset, loop_start, &
    &           loop_end{', ibl, nb' if block_dim_arg else ''})
  enddo
end subroutine driver
"""

    fcode_kernel = f"""
module kernel_mod
implicit none
contains
subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend{', ibl, nb' if block_dim_arg else ''})
  use compute_mod, only: compute
  implicit none
  integer, intent(in) :: nlon,nlev,icend,lstart,lend
  real, intent(inout) :: var(nlon,nlev)
  real, intent(inout) :: another_var(nlon, nlev)
  {'integer, intent(in) :: ibl' if block_dim_arg else ''}
  {'integer, intent(in) :: nb' if block_dim_arg else ''}
  integer :: jk, jl
  var(:,:) = 0.
  do jk = 1,nlev
    do jl = 1, nlon
      var(jl, jk) = 0.
    end do
  end do
  call compute(nlon,nlev,var)
  call compute(nlon,nlev,another_var)
end subroutine kernel
end module kernel_mod
"""

    fcode_nested_kernel = """
module compute_mod
implicit none
contains
subroutine compute(nlon,nlev,var)
  implicit none
  integer, intent(in) :: nlon,nlev
  real, intent(inout) :: var(nlon,nlev)
  var(:,:) = 0.
end subroutine compute
end module compute_mod
"""

    # Parse callees first so each caller can be given its callee's definitions
    nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path])
    kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path])

    # lower block index (dimension/shape) as prerequisite for 'InjectBlockIndexTransformation'
    targets = ('kernel', 'compute')
    LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(driver,
            role='driver', targets=targets)
    LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(kernel_mod['kernel'],
            role='kernel', targets=targets)
    LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(nested_kernel_mod['compute'],
            role='kernel')

    # Step 1: block size/index are passed to the kernel (positionally if already present, else as kwargs)
    kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0]
    if block_dim_arg:
        assert blocking.size in kernel_call.arguments
        assert blocking.index in kernel_call.arguments
    else:
        assert blocking.size in [kwarg[0] for kwarg in kernel_call.kwarguments]
        assert blocking.index in [kwarg[0] for kwarg in kernel_call.kwarguments]
    assert blocking.size in kernel_mod['kernel'].arguments
    assert blocking.index in kernel_mod['kernel'].arguments

    # Step 1: kernel array arguments gain the block dimension; the nested kernel only when recursing
    kernel_array_args = [arg for arg in kernel_mod['kernel'].arguments if isinstance(arg, sym.Array)]
    nested_kernel_array_args = [arg for arg in nested_kernel_mod['compute'].arguments if isinstance(arg, sym.Array)]
    for array in kernel_array_args:
        assert blocking.size in array.dimensions
        assert blocking.size in array.shape
    if recurse_to_kernels:
        for array in nested_kernel_array_args:
            assert blocking.size in array.dimensions
            assert blocking.size in array.shape
    else:
        for array in nested_kernel_array_args:
            assert blocking.size not in array.dimensions
            assert blocking.size not in array.shape

    # Before injection, array accesses in the kernel body are not yet indexed by the block index
    arrays = [var for var in FindVariables().visit(kernel_mod['kernel'].body) if isinstance(var, sym.Array)]
    for array in arrays:
        if array.name.lower() in [arg.name.lower() for arg in kernel_mod['kernel'].arguments]:
            assert blocking.size in array.shape
            assert blocking.index not in array.dimensions

    InjectBlockIndexTransformation(blocking).apply(driver, role='driver', targets=targets)
    InjectBlockIndexTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=targets)
    InjectBlockIndexTransformation(blocking).apply(nested_kernel_mod['compute'], role='kernel')

    # Step 2: dimensioned accesses to argument arrays are now indexed by the block index
    arrays = [var for var in FindVariables().visit(kernel_mod['kernel'].body) if isinstance(var, sym.Array)]
    for array in arrays:
        if array.name.lower() in [arg.name.lower() for arg in kernel_mod['kernel'].arguments]:
            assert blocking.size in array.shape
            assert not array.dimensions or blocking.index in array.dimensions

    # The block loop is still in the driver at this point
    driver_loops = FindNodes(ir.Loop).visit(driver.body)
    kernel_loops = FindNodes(ir.Loop).visit(kernel_mod['kernel'].body)
    assert any(loop.variable == blocking.index for loop in driver_loops)
    assert not any(loop.variable == blocking.index for loop in kernel_loops)

    LowerBlockLoopTransformation(blocking).apply(driver, role='driver', targets=targets)
    LowerBlockLoopTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=targets)
    LowerBlockLoopTransformation(blocking).apply(nested_kernel_mod['compute'], role='kernel')

    # Step 3: the removed driver loop is recorded in a pragma attached to the call
    driver_calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert driver_calls[0].pragma[0].keyword.lower() == 'loki'
    assert 'removed_loop' in driver_calls[0].pragma[0].content.lower()
    parameters = get_pragma_parameters(driver_calls[0].pragma, starts_with='removed_loop')
    assert parameters == {'var': 'ibl', 'lower': 'loop_start', 'upper': 'loop_end', 'step': '1'}
    # Step 3: the block loop has moved from the driver into the kernel
    driver_loops = FindNodes(ir.Loop).visit(driver.body)
    kernel_loops = FindNodes(ir.Loop).visit(kernel_mod['kernel'].body)
    assert not any(loop.variable == blocking.index for loop in driver_loops)
    assert any(loop.variable == blocking.index for loop in kernel_loops)
    # Step 3: only the block size is still passed; the block index becomes a kernel-local variable
    kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0]
    if block_dim_arg:
        assert blocking.size in kernel_call.arguments
        assert blocking.index not in kernel_call.arguments
    else:
        assert blocking.size in [kwarg[0] for kwarg in kernel_call.kwarguments]
        assert blocking.index not in [kwarg[0] for kwarg in kernel_call.kwarguments]
    assert blocking.size in kernel_mod['kernel'].arguments
    assert blocking.index not in kernel_mod['kernel'].arguments
    assert blocking.index in kernel_mod['kernel'].variable_map


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('recurse_to_kernels', (False, True))
@pytest.mark.parametrize('targets', (('kernel', 'another_kernel', 'compute'), ('kernel', 'compute')))
def test_lower_loop(blocking, frontend, recurse_to_kernels, targets, tmp_path):
    """
    Test the block-loop lowering pipeline on a driver with several block loops.

    The driver has three loops over ``jkglo`` that each compute ``ibl``. The first
    calls ``kernel`` and ``another_kernel``, the second calls only ``kernel``, and the
    third contains no calls. ``another_kernel`` is only treated as a target when it is
    listed in ``targets``, and the nested ``compute`` routine only receives the block
    dimension when ``recurse_to_kernels`` is set.

    After ``LowerBlockLoopTransformation``:

    - every call carrying a pragma has a ``loki removed_loop`` pragma that records
      the ``jkglo`` loop bounds;
    - exactly one block loop remains in the driver;
    - the target kernels contain a block loop;
    - no block index is passed to, or declared as an argument of, the kernels.

    ``blocking`` is presumably a fixture providing the block dimension
    (``size``/``index``/``indices``) - confirm against the fixture definition.
    """

    fcode_driver = """
subroutine driver(nlon,nlev,nb,var)
  use kernel_mod, only: kernel
  use another_kernel_mod, only: another_kernel
  implicit none
  integer, intent(in) :: nlon,nlev,nb
  real, intent(inout) :: var(nlon,nlev,nb)
  real :: some_var(nlon,nlev,nb)
  integer :: jkglo, ibl, status, jk, jl
  do jkglo=1,nb,nlev
    ibl = (jkglo-1)/(nlev+1)
    call kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl))
    call another_kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl))
    status = 1
  enddo
  do jkglo=1,nb,nlev
    ibl = (jkglo-1)/(nlev+1)
    call kernel(nlon,nlev,var(:,:,ibl), some_var(:,:,ibl))
  enddo
  do jkglo=1,nb,nlev
    ibl = (jkglo-1)/(nlev+1)
    do jk = 1,nlev
     do jl = 1, nlon
      some_var(jl, jk, jkglo) = 0.
    end do
   end do
  enddo
end subroutine driver
"""

    fcode_kernel = """
module kernel_mod
implicit none
contains
subroutine kernel(nlon,nlev,var,another_var)
  use compute_mod, only: compute
  implicit none
  integer, intent(in) :: nlon,nlev
  real, intent(inout) :: var(nlon,nlev)
  real, intent(inout) :: another_var(nlon, nlev)
  var(:,:) = 0.
  call compute(nlon,nlev,var)
  call compute(nlon,nlev,another_var)
end subroutine kernel
end module kernel_mod
"""

    fcode_another_kernel = """
module another_kernel_mod
implicit none
contains
subroutine another_kernel(nlon,nlev,var,another_var)
  implicit none
  integer, intent(in) :: nlon,nlev
  real, intent(inout) :: var(nlon,nlev)
  real, intent(inout) :: another_var(nlon, nlev)
  integer :: jk, jl
  var(:,:) = 0.
  do jk = 1,nlev
    do jl = 1, nlon
      var(jl, jk) = 0.
      another_var(jl, jk) = 0.
    end do
  end do
end subroutine another_kernel
end module another_kernel_mod
"""

    fcode_nested_kernel = """
module compute_mod
implicit none
contains
subroutine compute(nlon,nlev,var)
  implicit none
  integer, intent(in) :: nlon,nlev
  real, intent(inout) :: var(nlon,nlev)
  integer :: jk, jl
  do jk = 1,nlev
    do jl = 1, nlon
      var(jl, jk) = 0.
    end do
  end do
end subroutine compute
end module compute_mod
"""

    # Parse callees first so each caller can be given its callees' definitions
    nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path])
    kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path])
    another_kernel_mod = Module.from_source(fcode_another_kernel, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=(kernel_mod, another_kernel_mod),
            xmods=[tmp_path])

    # Step 1: lower the block index/size into the call tree
    LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(driver,
            role='driver', targets=targets)
    LowerBlockIndexTransformation(blocking, recurse_to_kernels=recurse_to_kernels).apply(kernel_mod['kernel'],
            role='kernel', targets=targets)
    if 'another_kernel' in targets:
        LowerBlockIndexTransformation(blocking,
                recurse_to_kernels=recurse_to_kernels).apply(another_kernel_mod['another_kernel'],
                role='kernel', targets=targets)
    LowerBlockIndexTransformation(blocking,
            recurse_to_kernels=recurse_to_kernels).apply(nested_kernel_mod['compute'],
            role='kernel')

    # Block size/index are passed to every targeted kernel call as keyword arguments
    kernel_calls = [call for call in FindNodes(ir.CallStatement).visit(driver.body)
            if str(call.name).lower() in targets]
    for kernel_call in kernel_calls:
        assert blocking.size in [kwarg[0] for kwarg in kernel_call.kwarguments]
        assert blocking.index in [kwarg[0] for kwarg in kernel_call.kwarguments]
    assert blocking.size in kernel_mod['kernel'].arguments
    assert blocking.index in kernel_mod['kernel'].arguments
    if 'another_kernel' in targets:
        assert blocking.size in another_kernel_mod['another_kernel'].arguments
        assert blocking.index in another_kernel_mod['another_kernel'].arguments

    kernel_array_args = [arg for arg in kernel_mod['kernel'].arguments if isinstance(arg, sym.Array)]
    another_kernel_array_args = [arg for arg in another_kernel_mod['another_kernel'].arguments
            if isinstance(arg, sym.Array)]
    nested_kernel_array_args = [arg for arg in nested_kernel_mod['compute'].arguments if isinstance(arg, sym.Array)]
    # Copy, so that the ``+=`` below does not also extend ``kernel_array_args`` through aliasing
    test_array_args = list(kernel_array_args)
    test_array_args += another_kernel_array_args if 'another_kernel' in targets else []
    test_array_args += nested_kernel_array_args if recurse_to_kernels else []
    # Array arguments of all transformed routines gain the block dimension
    for array in test_array_args:
        assert blocking.size in array.dimensions
        assert blocking.size in array.shape
    if not recurse_to_kernels:
        for array in nested_kernel_array_args:
            assert blocking.size not in array.dimensions
            assert blocking.size not in array.shape

    # Before injection, array accesses are not yet indexed by the block index
    arrays = [var for var in FindVariables().visit(kernel_mod['kernel'].body) if isinstance(var, sym.Array)]
    arrays += [var for var in FindVariables().visit(another_kernel_mod['another_kernel'].body)
            if isinstance(var, sym.Array)] if 'another_kernel' in targets else []
    arrays += [var for var in FindVariables().visit(nested_kernel_mod['compute'].body)
            if isinstance(var, sym.Array)] if recurse_to_kernels else []
    for array in arrays:
        if array.name.lower() in [arg.name.lower() for arg in test_array_args]:
            assert blocking.size in array.shape
            assert blocking.index not in array.dimensions

    # Step 2: inject the block index into array accesses
    InjectBlockIndexTransformation(blocking).apply(driver, role='driver', targets=targets)
    InjectBlockIndexTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=targets)
    if 'another_kernel' in targets:
        InjectBlockIndexTransformation(blocking).apply(another_kernel_mod['another_kernel'],
                role='kernel', targets=targets)
    InjectBlockIndexTransformation(blocking).apply(nested_kernel_mod['compute'], role='kernel')

    arrays = [var for var in FindVariables().visit(kernel_mod['kernel'].body) if isinstance(var, sym.Array)]
    arrays += [var for var in FindVariables().visit(another_kernel_mod['another_kernel'].body)
            if isinstance(var, sym.Array)] if 'another_kernel' in targets else []
    arrays += [var for var in FindVariables().visit(nested_kernel_mod['compute'].body)
            if isinstance(var, sym.Array)] if recurse_to_kernels else []
    for array in arrays:
        if array.name.lower() in [arg.name.lower() for arg in test_array_args]:
            assert blocking.size in array.shape
            assert not array.dimensions or blocking.index in array.dimensions

    # Block loops are still in the driver at this point
    driver_loops = FindNodes(ir.Loop).visit(driver.body)
    kernel_loops = FindNodes(ir.Loop).visit(kernel_mod['kernel'].body)
    another_kernel_loops = FindNodes(ir.Loop).visit(another_kernel_mod['another_kernel'].body)
    assert any(loop.variable in blocking.indices for loop in driver_loops)
    assert not any(loop.variable in blocking.indices for loop in kernel_loops)
    if 'another_kernel' in targets:
        assert not any(loop.variable in blocking.indices for loop in another_kernel_loops)

    # Step 3: move the block loops from the driver into the kernels
    LowerBlockLoopTransformation(blocking).apply(driver, role='driver', targets=targets)
    LowerBlockLoopTransformation(blocking).apply(kernel_mod['kernel'], role='kernel', targets=targets)
    if 'another_kernel' in targets:
        LowerBlockLoopTransformation(blocking).apply(another_kernel_mod['another_kernel'],
                role='kernel', targets=targets)
    LowerBlockLoopTransformation(blocking).apply(nested_kernel_mod['compute'], role='kernel')

    # Each lowered call records the removed ``jkglo`` loop in a pragma
    driver_calls = [call for call in FindNodes(ir.CallStatement).visit(driver.body) if call.pragma is not None]
    if 'another_kernel' in targets:
        assert len(driver_calls) == 3
    else:
        assert len(driver_calls) == 2
    for driver_call in driver_calls:
        assert driver_call.pragma[0].keyword.lower() == 'loki'
        assert 'removed_loop' in driver_call.pragma[0].content.lower()
        parameters = get_pragma_parameters(driver_call.pragma, starts_with='removed_loop')
        assert parameters == {'var': 'jkglo', 'lower': '1', 'upper': 'nb', 'step': 'nlev'}
    driver_loops = FindNodes(ir.Loop).visit(driver.body)
    kernel_loops = FindNodes(ir.Loop).visit(kernel_mod['kernel'].body)
    another_kernel_loops = FindNodes(ir.Loop).visit(another_kernel_mod['another_kernel'].body)
    assert len([loop for loop in driver_loops if loop.variable in blocking.indices]) == 1
    assert any(loop.variable in blocking.indices for loop in kernel_loops)
    if 'another_kernel' in targets:
        assert any(loop.variable in blocking.indices for loop in another_kernel_loops)
    # Only the block size is still passed; block indices are kernel-local variables
    kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0]
    assert blocking.size in [kwarg[0] for kwarg in kernel_call.kwarguments]
    assert blocking.index not in [kwarg[0] for kwarg in kernel_call.kwarguments]
    for index in blocking.indices:
        assert index not in [kwarg[0] for kwarg in kernel_call.kwarguments]
        assert index not in kernel_mod['kernel'].arguments
        if 'another_kernel' in targets:
            assert index not in another_kernel_mod['another_kernel'].arguments
    assert blocking.size in kernel_mod['kernel'].arguments
    assert blocking.index not in kernel_mod['kernel'].arguments
    assert blocking.index in kernel_mod['kernel'].variable_map
loki-ecmwf-0.3.6/loki/transformations/field_api.py0000664000175000017500000004622515167130205022427 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
A set of utility classes for dealing with FIELD API boilerplate in
parallel kernels and offload regions.
"""

from enum import Enum
from itertools import chain

from loki.expression import symbols as sym
from loki.ir import nodes as ir
from loki.types import Scope


# Public names of this module (``FieldAPITransferDirection`` is not exported)
__all__ = [
    'FieldAPITransferType',
    'FieldPointerMap',
    'get_field_type',
    'field_get_device_data',
    'field_get_host_data',
    'field_sync_device',
    'field_sync_host',
    'field_create_device_data',
    'field_delete_device_data',
    'field_wait_for_async_queue',
    'FieldAPIAccessorType',
]


class FieldAPITransferType(Enum):
    """
    Access mode requested from a FIELD API data-access or synchronisation method.

    Each member's value is the suffix that the call generators in this module
    append to the FIELD API method name (e.g. ``GET_DEVICE_DATA_RDONLY``).
    """

    READ_ONLY = 'RDONLY'
    READ_WRITE = 'RDWR'
    WRITE_ONLY = 'WRONLY'
    FORCE = 'FORCE'

    @property
    def suffix(self):
        """Method-name suffix for this transfer type (identical to the member value)."""
        method_suffix = self.value
        return method_suffix


class FieldAPITransferDirection(Enum):
    """
    Direction of a FIELD API data access or synchronisation.

    The member value names the destination memory space and is inserted into the
    generated method name (e.g. ``GET_HOST_DATA_...`` or ``SYNC_DEVICE_...``).
    """

    DEVICE_TO_HOST = 'HOST'
    HOST_TO_DEVICE = 'DEVICE'

    @property
    def suffix(self):
        """Method-name fragment for this direction (identical to the member value)."""
        direction_name = self.value
        return direction_name


class FieldAPIAccessorType(Enum):
    """
    Flavour of FIELD API data-access call to generate.

    ``str()`` of a member returns its value, which is used as the prefix of the
    generated data-access procedure name.
    """

    TYPE_BOUND = 'GET'
    """
    Create FIELD_API data access calls using the native type-bound methods e.g.

    .. code-block:: fortran

        CALL FIELD%GET_HOST/DEVICE_DATA_...()
    """

    GENERIC = 'SGET'
    """
    Create FIELD_API data access calls using a generic interface that
    takes a FIELD as an argument e.g.

    .. code-block:: fortran

        CALL SGET_HOST/DEVICE_DATA_...(..., FIELD)

    This mode offers additional safety for uninitialised and zero-sized fields.
    """

    def __str__(self):
        """Return the procedure-name prefix, i.e. the member value."""
        return f'{self.value}'


class FieldPointerMap:
    """
    Helper class to map FIELD API pointers to intents and access descriptors.

    This utility is used to store arrays passed to target kernel calls
    and easily access corresponding device pointers added by the transformation.

    Dimensions are stripped from all array symbols and duplicates are removed
    *before* intents are reconciled, so that each array ends up in exactly one
    of :attr:`inargs`, :attr:`inoutargs` or :attr:`outargs`. An array given in
    both ``inargs`` and ``outargs`` is promoted to ``inoutargs``. The caller's
    sequences are never modified.

    Parameters
    ----------
    inargs : iterable of :any:`Array`
        Arrays accessed read-only by the target call.
    inoutargs : iterable of :any:`Array`
        Arrays accessed read-write by the target call.
    outargs : iterable of :any:`Array`
        Arrays written by the target call.
    scope : :any:`Scope`
        Scope used for the generated :any:`CallStatement` nodes.
    ptr_prefix : str, optional
        Name prefix for the generated contiguous data pointers.
    """

    def __init__(self, inargs, inoutargs, outargs, scope, ptr_prefix='loki_ptr_'):
        # Strip dimensions and de-duplicate *before* reconciling intents. Comparing
        # symbols that still carry dimensions would treat e.g. ``a(:,1)`` and
        # ``a(:,2)`` as different arrays. Building fresh tuples also guarantees that
        # a caller-supplied list is never extended in place (``list += ...`` mutates).
        inargs = self._unique_without_dimensions(inargs)
        inoutargs = self._unique_without_dimensions(inoutargs)
        outargs = self._unique_without_dimensions(outargs)

        # Arrays that are both read and written are promoted to inout
        promoted = tuple(v for v in inargs if v in outargs)
        inoutargs = tuple(dict.fromkeys(inoutargs + promoted))

        # Each array ends up in exactly one of the in/inout/out tuples
        self.inargs = tuple(a for a in inargs if a not in inoutargs)
        self.inoutargs = inoutargs
        self.outargs = tuple(a for a in outargs if a not in inoutargs)

        self.scope = scope

        self.ptr_prefix = ptr_prefix

    @staticmethod
    def _unique_without_dimensions(args):
        """Return ``args`` as a tuple of dimension-less symbols, duplicates removed, order kept."""
        return tuple(dict.fromkeys(a.clone(dimensions=None) for a in args))

    def dataptr_from_array(self, a: sym.Array):
        """
        Returns a contiguous pointer :any:`Variable` with types matching the array :data:`a`.

        The pointer is named ``<ptr_prefix><name>`` (with ``%`` replaced by ``_`` for
        derived-type members). It has one more deferred-shape dimension than ``a``,
        presumably for the block dimension of the FIELD - confirm with the FIELD API layout.
        """
        shape = (sym.RangeIndex((None, None)),) * (len(a.shape)+1)
        dataptr_type = a.type.clone(pointer=True, contiguous=True, shape=shape, intent=None)
        base_name = a.name if a.parent is None else '_'.join(a.name.split('%'))
        return sym.Variable(name=self.ptr_prefix + base_name, type=dataptr_type, dimensions=shape)

    @staticmethod
    def field_ptr_from_view(field_view):
        """
        Returns a symbol for the pointer to the corresponding Field object.

        For a view ``parent%name`` this is the member ``parent%F_name``; the view is
        assumed to be a derived-type member (``field_view.parent`` must not be ``None``).
        """
        type_chain = field_view.name.split('%')
        field_type_name = 'F_' + type_chain[-1]
        return field_view.parent.get_derived_type_member(field_type_name)

    @property
    def args(self):
        """ A tuple of all argument symbols, concatenating in/inout/out arguments """
        return tuple(chain(*(self.inargs, self.inoutargs, self.outargs)))

    @property
    def dataptrs(self):
        """ Create a list of contiguous data pointer symbols """
        return tuple(dict.fromkeys(self.dataptr_from_array(a) for a in self.args))

    @property
    def host_to_device_calls(self):
        """
        Returns a tuple of :any:`CallStatement` for host-to-device transfers on fields.

        ``inargs`` use ``READ_ONLY``; ``inoutargs`` and ``outargs`` both use ``READ_WRITE``.
        """
        READ_ONLY, READ_WRITE = FieldAPITransferType.READ_ONLY, FieldAPITransferType.READ_WRITE

        host_to_device = tuple(field_get_device_data(
            self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_ONLY, scope=self.scope
        ) for arg in self.inargs)
        host_to_device += tuple(field_get_device_data(
            self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope
        ) for arg in self.inoutargs)
        # NOTE(review): outargs use READ_WRITE rather than WRITE_ONLY - presumably deliberate; confirm
        host_to_device += tuple(field_get_device_data(
            self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope
        ) for arg in self.outargs)

        return tuple(dict.fromkeys(host_to_device))

    @property
    def sync_host_calls(self):
        """
        Returns a tuple of :any:`CallStatement` for host-synchronization transfers on fields.

        Only ``inoutargs`` and ``outargs`` are synchronised, using ``READ_WRITE``.
        """
        READ_WRITE = FieldAPITransferType.READ_WRITE

        sync_host = tuple(
            field_sync_host(self.field_ptr_from_view(arg), transfer_type=READ_WRITE, scope=self.scope)
            for arg in self.inoutargs
        )
        sync_host += tuple(
            field_sync_host(self.field_ptr_from_view(arg), transfer_type=READ_WRITE, scope=self.scope)
            for arg in self.outargs
        )
        return tuple(dict.fromkeys(sync_host))

    def host_to_device_force_calls(self, queue=None, blk_bounds=None, offset=None):
        """
        Returns a tuple of :any:`CallStatement` for host-to-device force transfers on fields.

        All arguments (in, inout and out) use the ``FORCE`` transfer type; ``queue``,
        ``blk_bounds`` and ``offset`` are forwarded to each call when not ``None``.
        """
        FORCE = FieldAPITransferType.FORCE
        host_to_device = tuple(field_get_device_data(
            self.field_ptr_from_view(arg), self.dataptr_from_array(arg), transfer_type=FORCE,
            scope=self.scope, queue=queue, blk_bounds=blk_bounds, offset=offset)
                               for arg in chain(self.inargs, self.inoutargs, self.outargs))
        return tuple(dict.fromkeys(host_to_device))

    def sync_host_force_calls(self, queue=None, blk_bounds=None, offset=None):
        """
        Returns a tuple of :any:`CallStatement` for host-synchronization transfers on fields.

        Only ``inoutargs`` and ``outargs`` are synchronised, using the ``FORCE`` transfer type.
        """
        FORCE = FieldAPITransferType.FORCE

        sync_host = tuple(field_sync_host(
            self.field_ptr_from_view(arg), transfer_type=FORCE, scope=self.scope,
            queue=queue, blk_bounds=blk_bounds, offset=offset)
                          for arg in chain(self.inoutargs, self.outargs))
        return tuple(dict.fromkeys(sync_host))


def get_field_type(a: sym.Array) -> sym.DerivedType:
    """
    Returns the corresponding FIELD API type for an array.

    The type name is ``field_<rank><kk>``. ``<rank>`` is ``len(a.shape)`` and
    ``<kk>`` is the third and fourth characters of the kind name, lower-cased.
    For example, a rank-2 array of kind ``JPRB`` maps to ``field_2rb``.

    This function is IFS specific and assumes that the array is declared with one
    of the IFS kind specifiers (e.g. ``KIND=JPRB``).

    Parameters
    ----------
    a : :any:`Array`
        Array whose declared kind and shape determine the FIELD type.

    Returns
    -------
    :any:`DerivedType`
        The corresponding FIELD API type symbol.

    Raises
    ------
    AssertionError
        If the kind name is not a recognised IFS kind specifier (only when
        assertions are enabled).
    """
    ifs_kind_names = (
        'jprb', 'jpit', 'jpis', 'jpim', 'jpib', 'jpia',
        'jprt', 'jprs', 'jprm', 'jprd', 'jplm',
    )
    kind_name = a.type.kind.name

    assert kind_name.lower() in ifs_kind_names, (
        f'Error array type kind is: "{kind_name}" which is not a valid IFS type specifier'
    )
    kind_suffix = kind_name[2:4].lower()
    return sym.DerivedType(name=f'field_{len(a.shape)}{kind_suffix}')


def _field_get_data(field_ptr, dev_ptr, transfer_type: FieldAPITransferType,
                    transfer_direction: FieldAPITransferDirection,
                    scope: Scope, queue=None, blk_bounds=None, offset=None,
                    accessor_type: FieldAPIAccessorType=FieldAPIAccessorType.TYPE_BOUND):
    """
    Internal helper that builds a FIELD API ``<ACCESSOR>_<HOST|DEVICE>_DATA_<TYPE>`` call.

    .. note::
        Not meant to be called from outside `field_api.py`; use
        :any:`field_get_device_data` or :any:`field_get_host_data` instead.

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field whose data is accessed.
    dev_ptr: :any:`Array`
        Data pointer array; passed (without dimensions) as the first call argument.
    transfer_type: :any:`FieldAPITransferType`
        Selects the ``_RDONLY``/``_RDWR``/``_WRONLY``/``_FORCE`` method variant.
    transfer_direction: :any:`FieldAPITransferDirection`
        Selects the ``HOST`` or ``DEVICE`` method variant.
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    queue: integer
        ``QUEUE`` optional argument (``FORCE`` transfers only)
    blk_bounds: integer dimension(2) array
        ``BLK_BOUNDS`` optional argument (``FORCE`` transfers only)
    offset: integer
        ``OFFSET`` optional argument
    accessor_type: :any:`FieldAPIAccessorType`
        ``TYPE_BOUND`` generates ``CALL FIELD%GET_...(PTR, ...)``; any other value
        generates ``CALL <ACCESSOR>_...(PTR, FIELD, ...)``.

    Returns
    -------
    :any:`CallStatement`

    Raises
    ------
    TypeError
        If ``transfer_type``/``transfer_direction`` have the wrong type, or if
        ``WRITE_ONLY`` is requested for a device-to-host transfer.
    ValueError
        If ``queue`` or ``blk_bounds`` is given for a non-``FORCE`` transfer.
    """
    # Argument validation
    if not isinstance(transfer_type, FieldAPITransferType):
        raise TypeError(
            f"transfer_type must be of type FieldAPITransferType, but is of type {type(transfer_type)}"
        )
    if not isinstance(transfer_direction, FieldAPITransferDirection):
        raise TypeError(
            f"transfer_direction must be of type FieldAPITransferDirection, but is of type {type(transfer_direction)}"
        )

    uses_force_only_args = queue is not None or blk_bounds is not None
    if transfer_type is not FieldAPITransferType.FORCE and uses_force_only_args:
        raise ValueError("Only force copy methods can have non-None type queue or blk_bounds")

    is_host_write_only = (transfer_type is FieldAPITransferType.WRITE_ONLY
                          and transfer_direction is FieldAPITransferDirection.DEVICE_TO_HOST)
    if is_host_write_only:
        raise TypeError("incorrect transfer_type (WRITE_ONLY) for Field-API get method")

    procedure_name = f'{accessor_type}_{transfer_direction.suffix}_DATA_{transfer_type.suffix}'

    # Only forward the optional keyword arguments that were actually supplied
    optional_args = (('queue', queue), ('blk_bounds', blk_bounds), ('offset', offset))
    kwargs = tuple((key, val) for key, val in optional_args if val is not None) or None

    if accessor_type != FieldAPIAccessorType.TYPE_BOUND:
        # Generic accessor: CALL SGET_..._DATA_...(PTR, FIELD, ...)
        generic_proc = sym.ProcedureSymbol(procedure_name, scope=scope)
        call_args = (dev_ptr.clone(dimensions=None), field_ptr)
        return ir.CallStatement(name=generic_proc, arguments=call_args, kwarguments=kwargs)

    # Type-bound accessor: CALL FIELD%GET_..._DATA_...(PTR, ...)
    bound_proc = sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope)
    call_args = (dev_ptr.clone(dimensions=None),)
    return ir.CallStatement(name=bound_proc, arguments=call_args, kwarguments=kwargs)


def field_get_device_data(field_ptr, dev_ptr, transfer_type: FieldAPITransferType, scope: Scope,
                          queue=None, blk_bounds=None, offset=None,
                          accessor_type: FieldAPIAccessorType=FieldAPIAccessorType.TYPE_BOUND):
    """
    Utility function to generate a :any:`CallStatement` corresponding to a Field API
    ``GET_DEVICE_DATA`` call (host-to-device direction).

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field to call ``GET_DEVICE_DATA`` from.
    dev_ptr: :any:`Array`
        Device pointer array
    transfer_type: :any:`FieldAPITransferType`
        Field API transfer type to determine which ``GET_DEVICE_DATA`` method to call.
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    queue: integer
        ``QUEUE`` optional argument
    blk_bounds: integer dimension(2) array
        ``BLK_BOUNDS`` optional argument
    offset: integer
        ``OFFSET`` optional argument
    accessor_type: :any:`FieldAPIAccessorType`
        Type of accessor to be used, e.g., 'get_' type bound method or 'sget'
    """
    direction = FieldAPITransferDirection.HOST_TO_DEVICE
    return _field_get_data(
        field_ptr, dev_ptr, transfer_type, direction, scope,
        queue=queue, blk_bounds=blk_bounds, offset=offset,
        accessor_type=accessor_type,
    )


def field_get_host_data(field_ptr, dev_ptr, transfer_type: FieldAPITransferType, scope: Scope,
                        queue=None, blk_bounds=None, offset=None,
                        accessor_type: FieldAPIAccessorType=FieldAPIAccessorType.TYPE_BOUND):
    """
    Utility function to generate a :any:`CallStatement` corresponding to a Field API
    ``GET_HOST_DATA`` call (device-to-host direction).

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field to call ``GET_HOST_DATA`` from.
    dev_ptr: :any:`Array`
        Data pointer array passed as the first argument of the generated call.
    transfer_type: :any:`FieldAPITransferType`
        Field API transfer type to determine which ``GET_HOST_DATA`` method to call
        (``WRITE_ONLY`` is rejected with a ``TypeError``).
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    queue: integer
        ``QUEUE`` optional argument
    blk_bounds: integer dimension(2) array
        ``BLK_BOUNDS`` optional argument
    offset: integer
        ``OFFSET`` optional argument
    accessor_type: :any:`FieldAPIAccessorType`
        Type of accessor to be used, e.g., 'get_' type bound method or 'sget'
    """
    direction = FieldAPITransferDirection.DEVICE_TO_HOST
    return _field_get_data(
        field_ptr, dev_ptr, transfer_type, direction, scope,
        queue=queue, blk_bounds=blk_bounds, offset=offset,
        accessor_type=accessor_type,
    )


def _field_sync(field_ptr, transfer_type: FieldAPITransferType,
                transfer_direction: FieldAPITransferDirection,
                scope: Scope, queue=None, blk_bounds=None, offset=None):
    """
    Internal helper that builds a type-bound FIELD API ``SYNC_<HOST|DEVICE>_<TYPE>`` call.

    .. note::
        Not meant to be called from outside `field_api.py`; use
        :any:`field_sync_host` or :any:`field_sync_device` instead.

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field whose ``SYNC_...`` method is called.
    transfer_type: :any:`FieldAPITransferType`
        Selects the ``_RDONLY``/``_RDWR``/``_WRONLY``/``_FORCE`` method variant.
    transfer_direction: :any:`FieldAPITransferDirection`
        Selects ``SYNC_HOST`` or ``SYNC_DEVICE``.
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    queue: integer
        ``QUEUE`` optional argument (``FORCE`` transfers only)
    blk_bounds: integer dimension(2) array
        ``BLK_BOUNDS`` optional argument (``FORCE`` transfers only)
    offset: integer
        ``OFFSET`` optional argument

    Returns
    -------
    :any:`CallStatement`

    Raises
    ------
    TypeError
        If ``transfer_type``/``transfer_direction`` have the wrong type, or if
        ``WRITE_ONLY`` is requested for a device-to-host sync.
    ValueError
        If ``queue`` or ``blk_bounds`` is given for a non-``FORCE`` transfer.
    """
    if not isinstance(transfer_type, FieldAPITransferType):
        raise TypeError(
            f"transfer_type must be of type FieldAPITransferType, but is of type {type(transfer_type)}"
        )
    if not isinstance(transfer_direction, FieldAPITransferDirection):
        raise TypeError(
            f"transfer_direction must be of type FieldAPITransferDirection, but is of type {type(transfer_direction)}"
        )

    uses_force_only_args = queue is not None or blk_bounds is not None
    if transfer_type is not FieldAPITransferType.FORCE and uses_force_only_args:
        raise ValueError("Only force copy methods can have non-None type queue or blk_bounds")

    is_host_write_only = (transfer_type is FieldAPITransferType.WRITE_ONLY
                          and transfer_direction is FieldAPITransferDirection.DEVICE_TO_HOST)
    if is_host_write_only:
        raise TypeError("incorrect transfer_type for Field-API sync method")

    procedure_name = f'SYNC_{transfer_direction.suffix}_{transfer_type.suffix}'

    # Only forward the optional keyword arguments that were actually supplied
    optional_args = (('queue', queue), ('blk_bounds', blk_bounds), ('offset', offset))
    kwargs = tuple((key, val) for key, val in optional_args if val is not None) or None

    sync_proc = sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope)
    return ir.CallStatement(name=sync_proc, kwarguments=kwargs)


def field_sync_device(field_ptr, transfer_type: FieldAPITransferType, scope: Scope,
                      queue=None, blk_bounds=None, offset=None):
    """
    Utility function to generate a :any:`CallStatement` corresponding to a Field API
    ``SYNC_DEVICE`` call (host-to-device direction).

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field to call ``SYNC_DEVICE`` from.
    transfer_type: :any:`FieldAPITransferType`
        Field API transfer type to determine which ``SYNC_DEVICE`` method to call.
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    queue: integer
        ``QUEUE`` optional argument
    blk_bounds: integer dimension(2) array
        ``BLK_BOUNDS`` optional argument
    offset: integer
        ``OFFSET`` optional argument
    """
    direction = FieldAPITransferDirection.HOST_TO_DEVICE
    return _field_sync(
        field_ptr, transfer_type, direction, scope,
        queue=queue, blk_bounds=blk_bounds, offset=offset,
    )


def field_sync_host(field_ptr, transfer_type: FieldAPITransferType, scope: Scope,
                    queue=None, blk_bounds=None, offset=None):
    """
    Build the :any:`CallStatement` for a Field API ``SYNC_HOST`` call on a field.

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field on which ``SYNC_HOST`` is invoked.
    transfer_type: :any:`FieldAPITransferType`
        Selects which ``SYNC_HOST`` variant is called.
    scope: :any:`Scope`
        Scope in which the generated :any:`CallStatement` lives.
    queue: integer, optional
        Value for the optional ``QUEUE`` argument.
    blk_bounds: integer dimension(2) array, optional
        Value for the optional ``BLK_BOUNDS`` argument.
    offset: integer, optional
        Value for the optional ``OFFSET`` argument.

    Raises
    ------
    ValueError
        If ``queue`` or ``blk_bounds`` is given for a non-``FORCE`` transfer.
    TypeError
        If ``transfer_type`` is ``WRITE_ONLY``, which has no device-to-host variant.
    """
    direction = FieldAPITransferDirection.DEVICE_TO_HOST
    return _field_sync(
        field_ptr, transfer_type, direction, scope,
        queue=queue, blk_bounds=blk_bounds, offset=offset
    )


def field_create_device_data(field_ptr, scope: Scope, blk_bounds=None):
    """
    Utility function to generate a :any:`CallStatement` corresponding to a Field API
    ``CREATE_DEVICE_DATA`` call.

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field to call ``CREATE_DEVICE_DATA`` from.
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    blk_bounds: integer dimension(2) array, optional
        ``BLK_BOUNDS`` optional argument; omitted from the call when ``None``.
    """
    # Compare against ``None`` explicitly, like ``_field_sync`` does, instead of
    # relying on the truthiness of an expression node.
    kwargs = (('blk_bounds', blk_bounds),) if blk_bounds is not None else None
    return ir.CallStatement(name=sym.ProcedureSymbol('CREATE_DEVICE_DATA', parent=field_ptr, scope=scope),
                            kwarguments=kwargs)


def field_delete_device_data(field_ptr, scope: Scope):
    """
    Utility function to generate a :any:`CallStatement` corresponding to a Field API
    ``DELETE_DEVICE_DATA`` call.

    Parameters
    ----------
    field_ptr: pointer to field object
        Pointer to the field to call ``DELETE_DEVICE_DATA`` from.
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    """

    procedure_name = 'DELETE_DEVICE_DATA'
    return ir.CallStatement(name=sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope), arguments=())


def field_wait_for_async_queue(queue, scope: Scope):
    """
    Utility function to generate a :any:`CallStatement` to the Field API
    ``WAIT_FOR_ASYNC_QUEUE`` procedure.

    The procedure symbol is created without a parent, so the generated call is a
    free procedure call rather than a type-bound method call on a field.

    Parameters
    ----------
    queue: integer
        Queue value, passed as the only positional argument of the call.
    scope: :any:`Scope`
        Scope of the created :any:`CallStatement`
    """
    wait_routine = sym.ProcedureSymbol('WAIT_FOR_ASYNC_QUEUE', scope=scope)
    return ir.CallStatement(name=wait_routine, arguments=(queue,))
loki-ecmwf-0.3.6/loki/transformations/sanitise/0000775000175000017500000000000015167130205021747 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/sanitise/__init__.py0000664000175000017500000000716415167130205024070 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Sub-package with assorted utility :any:`Transformation` classes to
harmonize the look-and-feel of input source code.
"""
from functools import partial

from loki.batch import Transformation, Pipeline

from loki.transformations.sanitise.associates import * # noqa
from loki.transformations.sanitise.sequence_associations import * # noqa
from loki.transformations.sanitise.substitute import * # noqa


# The documentation string is placed *after* the assignment so that Sphinx
# autodoc attaches it to ``SanitisePipeline`` as an attribute docstring; a bare
# string placed before the assignment is a no-op expression that is not picked up.
SanitisePipeline = partial(
    Pipeline, classes=(
        SubstituteExpressionTransformation,
        AssociatesTransformation,
        SequenceAssociationTransformation,
    )
)
"""
:any:`Pipeline` class that provides combined access to the features
provided by the following :any:`Transformation` classes, in sequence:

1. :any:`SubstituteExpressionTransformation` - String-based generic
   expression substitution.
2. :any:`AssociatesTransformation` - Full or partial resolution of
   nested :any:`Associate` nodes, including optional merging of
   independent association pairs.
3. :any:`SequenceAssociationTransformation` - Resolves sequence
   association patterns in the call signature of :any:`CallStatement`
   nodes.

Parameters
----------
substitute_expressions : bool
    Flag to trigger or suppress expression substitution
expression_map : dict of str to str
    A string-to-string map detailing the substitutions to apply.
substitute_spec : bool
    Flag to trigger or suppress expression substitution in specs.
substitute_body : bool
    Flag to trigger or suppress expression substitution in bodies.
resolve_associates : bool, default: True
    Enable full or partial resolution of :any:`Associate` scopes.
merge_associates : bool, default: False
    Enable merging :any:`Associate` to the outermost possible
    scope in nested associate blocks.
start_depth : int, optional
    Starting depth for partial resolution of :any:`Associate`
    after merging.
max_parents : int, optional
    Maximum number of parent symbols for valid selector to have
    when merging :any:`Associate` nodes.
resolve_sequence_associations : bool
    Flag to trigger or suppress resolution of sequence associations
"""


class SanitiseTransformation(Transformation):
    """
    Batch-processing :any:`Transformation` that applies a selectable set of
    code sanitisation steps to every subroutine handed to it, e.g. by the
    :any:`Scheduler`.

    Parameters
    ----------
    resolve_associate_mappings : bool
        Resolve ASSOCIATE mappings in body of processed subroutines; default: True.
    resolve_sequence_association : bool
        Replace scalars that are passed to array arguments with array
        ranges; default: False.
    """

    def __init__(self, resolve_associate_mappings=True, resolve_sequence_association=False):
        self.resolve_associate_mappings = resolve_associate_mappings
        self.resolve_sequence_association = resolve_sequence_association

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the enabled sanitisation steps to ``routine`` in place.

        Associates are resolved first, so that they do not interfere with
        the subsequent detection of subroutine calls for sequence
        association resolution.
        """
        # Ordered (flag, step) pairs: associate resolution must run before
        # the scalar-to-array-range rewrite of call arguments.
        steps = (
            (self.resolve_associate_mappings, do_resolve_associates),
            (self.resolve_sequence_association, do_resolve_sequence_association),
        )
        for enabled, step in steps:
            if enabled:
                step(routine)
loki-ecmwf-0.3.6/loki/transformations/sanitise/tests/0000775000175000017500000000000015167130205023111 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/sanitise/tests/__init__.py0000664000175000017500000000057015167130205025224 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/sanitise/tests/test_sequence_associations.py0000664000175000017500000000677015167130205031123 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Subroutine
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes

from loki.transformations.sanitise import (
    SequenceAssociationTransformation, do_resolve_sequence_association
)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_trafo', [False, True])
def test_resolve_sequence_assocaition_scalar_notation(tmp_path, frontend, use_trafo):
    """
    Check that array elements passed in scalar notation to an array dummy are
    rewritten as ranges up to the upper bound of the actual array's first
    dimension. Both the transformation class and the standalone utility are tested.
    """
    fcode = """
module mod_a
  implicit none

  type type_b
    integer :: c
    integer :: d
  end type type_b

  type type_a
    type(type_b) :: b
  end type type_a

contains

  subroutine main()
    type(type_a) :: a
    integer :: k, m, n

    real :: array(10,10)
    real :: another(1:8, 2)

    ! Test array with scalar dimension
    call sub_x(array(1, 1), 1)
    call sub_x(array(2, 2), 2)
    call sub_x(array(m, 1), k)
    call sub_x(array(m-1, 1), k-1)
    call sub_x(array(a%b%c, 1), a%b%d)

    ! Test array with range dimension
    call sub_x(another(1, 1), 1)
    call sub_x(another(2, 2), 2)
    call sub_x(another(m, 1), k)

  contains

    subroutine sub_x(array, k)
      integer, intent(in) :: k
      real, intent(in)    :: array(k:n)

    end subroutine sub_x

  end subroutine main

end module mod_a
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['main']

    if not use_trafo:
        do_resolve_sequence_association(routine)
    else:
        trafo = SequenceAssociationTransformation(resolve_sequence_associations=True)
        trafo.apply(routine)

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert all(c.name == 'sub_x' for c in calls)

    expected_arguments = (
        # Array declared with scalar extents
        ('array(1:10, 1)', 1),
        ('array(2:10, 2)', 2),
        ('array(m:10, 1)', 'k'),
        ('array(m - 1:10, 1)', 'k - 1'),
        ('array(a%b%c:10, 1)', 'a%b%d'),
        # Array declared with explicit lower:upper bounds
        ('another(1:8, 1)', 1),
        ('another(2:8, 2)', 2),
        ('another(m:8, 1)', 'k'),
    )
    for idx, expected in enumerate(expected_arguments):
        assert calls[idx].arguments == expected


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not like not knowing shapes!')]
))
def test_resolve_sequence_assocaition_missing_shape(frontend):
    """
    Check that when the actual argument's shape is unknown, sequence
    association resolution falls back to an open range ``(:)`` for the
    leading dimension, and leaves whole-array references untouched.
    """
    fcode = """
subroutine test_resolve_seq_assoc_no_shape(a, n)
  use ricks_module, only: my_type
  implicit none

  type(my_type), intent(inout) :: a
  integer, intent(in) :: n

  ! Test array with no known shape
  call sub_x(a%a(1, 1), 1)
  call sub_x(a%b(1), 1)
  call sub_x(a%c, 1)

contains

  subroutine sub_x(array, k)
    integer, intent(in) :: k
    real, intent(in)    :: array(k:n)

  end subroutine sub_x
end subroutine test_resolve_seq_assoc_no_shape
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    do_resolve_sequence_association(routine)

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert all(c.name == 'sub_x' for c in calls)

    expected_arguments = (
        ('a%a(:, 1)', 1),
        ('a%b(:)', 1),
        ('a%c', 1),
    )
    for idx, expected in enumerate(expected_arguments):
        assert calls[idx].arguments == expected
loki-ecmwf-0.3.6/loki/transformations/sanitise/tests/test_sanitise.py0000664000175000017500000001061515167130205026344 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes

from loki.transformations.sanitise import (
    SanitiseTransformation, SanitisePipeline
)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('resolve_associate', [True, False])
@pytest.mark.parametrize('resolve_sequence', [True, False])
def test_transformation_sanitise(frontend, resolve_associate, resolve_sequence, tmp_path):
    """
    Test that the selective dispatch of the sanitisations works.

    Both the enabled and the disabled setting of each sanitisation step are
    verified. The expected values are parenthesised: without the parentheses,
    ``assert x == A if cond else B`` asserts the (always truthy) literal ``B``
    when ``cond`` is False, and the disabled case would never be checked.
    """

    fcode = """
module test_transformation_sanitise_mod
  implicit none

  type rick
    real :: scalar
  end type rick
contains

  subroutine test_transformation_sanitise(a, dave)
    real, intent(inout) :: a(3)
    type(rick), intent(inout) :: dave

    associate(scalar => dave%scalar)
      scalar = a(1) + a(2)

      call vadd(a(1), 2.0, 3)
    end associate

  contains
    subroutine vadd(x, y, n)
      real, intent(inout) :: x(n)
      real, intent(inout) :: y
      integer, intent(in) :: n

      x = x + 2.0
    end subroutine vadd
  end subroutine test_transformation_sanitise
end module test_transformation_sanitise_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_transformation_sanitise']

    assoc = FindNodes(ir.Associate).visit(routine.body)
    assert len(assoc) == 1
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    assert calls[0].arguments[0] == 'a(1)'

    trafo = SanitiseTransformation(
        resolve_associate_mappings=resolve_associate,
        resolve_sequence_association=resolve_sequence,
    )
    trafo.apply(routine)

    assoc = FindNodes(ir.Associate).visit(routine.body)
    assert len(assoc) == (0 if resolve_associate else 1)

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    assert calls[0].arguments[0] == ('a(1:3)' if resolve_sequence else 'a(1)')


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('substitute_expressions', [True, False])
@pytest.mark.parametrize('resolve_associates', [True, False])
@pytest.mark.parametrize('resolve_sequence_associations', [True, False])
def test_sanitise_pipeline(
        tmp_path, frontend, substitute_expressions,
        resolve_associates, resolve_sequence_associations
):
    """
    Test the agglomerated :any:`SanitisePipeline` with different settings.

    The expected values are computed before asserting. A bare
    ``assert x == A if cond else B`` would assert the truthy literal ``B``
    when ``cond`` is False, so the disabled cases would never be checked.
    """
    fcode = """
module test_sanitise_pipeline_mod
  implicit none

  type rick
    real :: scalar
  end type rick
contains

  subroutine test_pipeline_sanitise(n, a, dave)
    integer, intent(in) :: n
    real, intent(inout) :: a(n+1)
    type(rick), intent(inout) :: dave

    associate(scalar => dave%scalar)
      scalar = a(1) + a(2)

      call vadd(a(1), 2.0, n+1)
    end associate

  contains
    subroutine vadd(x, y, n)
      real, intent(inout) :: x(n)
      real, intent(inout) :: y
      integer, intent(in) :: n

      x = x + 2.0
    end subroutine vadd
  end subroutine test_pipeline_sanitise
end module test_sanitise_pipeline_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_pipeline_sanitise']

    assoc = FindNodes(ir.Associate).visit(routine.body)
    assert len(assoc) == 1
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1

    pipeline = SanitisePipeline(
        substitute_expressions=substitute_expressions,
        expression_map={'n + 1': 'n'},
        resolve_associates=resolve_associates,
        resolve_sequence_associations=resolve_sequence_associations
    )
    pipeline.apply(routine)

    assoc = FindNodes(ir.Associate).visit(routine.body)
    assert len(assoc) == (0 if resolve_associates else 1)

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    if resolve_sequence_associations:
        # The range extent follows the (possibly substituted) declared shape of ``a``
        expected_arg = 'a(1:n)' if substitute_expressions else 'a(1:n+1)'
    else:
        expected_arg = 'a(1)'
    assert calls[0].arguments[0] == expected_arg
loki-ecmwf-0.3.6/loki/transformations/sanitise/tests/test_associates.py0000664000175000017500000005010315167130205026657 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import BasicType, Subroutine
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes

from loki.transformations.sanitise import (
    do_resolve_associates, do_merge_associates,
    ResolveAssociatesTransformer, AssociatesTransformation
)


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_simple(frontend):
    """
    Check that a single associate block is removed and its associated name
    is replaced by the selector, rescoped to the routine.
    """
    fcode = """
subroutine transform_associates_simple
  use some_module, only: some_obj
  implicit none

  real :: local_var

  associate (a => some_obj%a)
    local_var = a(:)
  end associate
end subroutine transform_associates_simple
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    associates = FindNodes(ir.Associate).visit(routine.body)
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(associates) == 1
    assert len(assignments) == 1
    rhs = assignments[0].rhs
    assert rhs == 'a(:)' and 'some_obj' not in rhs
    assert rhs.type.dtype == BasicType.DEFERRED

    # Resolve the association and re-inspect the IR
    do_resolve_associates(routine)

    associates = FindNodes(ir.Associate).visit(routine.body)
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert not associates
    assert len(assignments) == 1
    rhs = assignments[0].rhs
    assert rhs == 'some_obj%a(:)'
    assert rhs.parent == 'some_obj'
    assert rhs.type.dtype == BasicType.DEFERRED
    assert rhs.scope == routine
    assert rhs.parent.scope == routine


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_nested(frontend):
    """
    Check that a chain of three nested associate blocks collapses into a
    single, fully qualified derived-type component reference.
    """
    fcode = """
subroutine transform_associates_nested
  use some_module, only: some_obj
  implicit none

  real :: rick

  associate (never => some_obj%never)
    associate (gonna => never%gonna)
      associate (a => gonna%give%you%up)
        rick = a
      end associate
    end associate
  end associate
end subroutine transform_associates_nested
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    associates = FindNodes(ir.Associate).visit(routine.body)
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(associates) == 3
    assert len(assignments) == 1
    lhs, rhs = assignments[0].lhs, assignments[0].rhs
    assert lhs == 'rick' and rhs == 'a'
    assert rhs.type.dtype == BasicType.DEFERRED

    # Resolve all three levels of association in one go
    do_resolve_associates(routine)

    associates = FindNodes(ir.Associate).visit(routine.body)
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert not associates
    assert len(assignments) == 1
    assert assignments[0].rhs == 'some_obj%never%gonna%give%you%up'


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_ambiguous(frontend):
    """
    Check that component accesses through an associated name (``nested_obj%ndim``)
    resolve against that name's selector. They must not be confused with other
    associate names that share the component name (``ndim``, ``ndims``).
    """
    fcode = """
subroutine transform_associates_ambiguous
  use some_module, only: some_obj
  use other_module, only: other_obj
  implicit none

    associate(ndim => some_obj%ndim, ndims => some_obj%ndims, nested_obj => other_obj%nested)
      nested_obj%ndim = 2
      nested_obj%ndims(:) = 5
      call some_func(nested_obj%ndim)
    end associate
end subroutine transform_associates_ambiguous
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _collect():
        # Gather the node types inspected before and after the transformation
        return (
            FindNodes(ir.Associate).visit(routine.body),
            FindNodes(ir.Assignment).visit(routine.body),
            FindNodes(ir.CallStatement).visit(routine.body),
        )

    associates, assigns, calls = _collect()
    assert len(associates) == 1
    assert len(assigns) == 2
    assert len(calls) == 1
    assert assigns[0].lhs == 'nested_obj%ndim' and assigns[0].rhs == '2'
    assert assigns[1].lhs == 'nested_obj%ndims(:)' and assigns[1].rhs == '5'
    assert calls[0].arguments[0] == 'nested_obj%ndim'

    do_resolve_associates(routine)

    associates, assigns, calls = _collect()
    assert not associates
    assert len(assigns) == 2
    assert len(calls) == 1
    assert assigns[0].lhs == 'other_obj%nested%ndim' and assigns[0].rhs == '2'
    assert assigns[1].lhs == 'other_obj%nested%ndims(:)' and assigns[1].rhs == '5'
    assert calls[0].arguments[0] == 'other_obj%nested%ndim'


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_array_call(frontend):
    """
    Check resolution when a component of an associated array element is passed
    as a keyword argument to a call. Also check that associated names used in
    an ``ALLOCATE`` extent are resolved.
    """
    fcode = """
subroutine transform_associates_simple
  use some_module, only: some_obj
  implicit none

  integer :: i
  real :: local_var
  real, allocatable :: local_arr(:)

  associate (some_array => some_obj%some_array, a => some_obj%a)
    allocate(local_arr(a%n))

    do i=1, 5
      call another_routine(i, n=some_array(i)%n)
    end do
  end associate
end subroutine transform_associates_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    n_kwarg = calls[0].kwarguments[0][1]
    assert n_kwarg == 'some_array(i)%n'
    assert n_kwarg.type.dtype == BasicType.DEFERRED
    assert routine.variable_map['local_arr'].type.shape == (':',)
    alloc_dims = FindNodes(ir.Allocation).visit(routine.body)[0].variables[0].dimensions
    assert alloc_dims == ('a%n',)

    do_resolve_associates(routine)

    assert not FindNodes(ir.Associate).visit(routine.body)
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    n_kwarg = calls[0].kwarguments[0][1]
    assert n_kwarg == 'some_obj%some_array(i)%n'
    assert n_kwarg.scope == routine
    assert n_kwarg.type.dtype == BasicType.DEFERRED

    # The declared deferred shape is untouched, but the allocation extent is resolved
    assert routine.variable_map['local_arr'].type.shape == (':',)
    alloc_dims = FindNodes(ir.Allocation).visit(routine.body)[0].variables[0].dimensions
    assert alloc_dims == ('some_obj%a%n',)


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_array_slices(frontend):
    """
    Check that associates bound to array slices are resolved by merging the
    subscripts of each use into the free (``:``) dimensions of the selector.
    """
    fcode = """
subroutine transform_associates_slices(arr2d, arr3d)
  use some_module, only: some_obj, another_routine
  implicit none
  real, intent(inout) :: arr2d(:,:), arr3d(:,:,:)
  integer :: i, j
  integer, parameter :: idx_a = 2
  integer, parameter :: idx_c = 3

  associate (a => arr2d(:, 1), b=>arr2d(:, idx_a), &
           & c => arr3d(:,:,idx_c), idx => some_obj%idx)
    b(:) = 42.0
    do i=1, 5
      a(i) = b(i+2)
      call another_routine(i, a(2:4), b)
      do j=1, 7
        c(i, j) = c(i, j) + b(j)
        c(i, idx) = c(i, idx) + 42.0
      end do
    end do
  end associate
end subroutine transform_associates_slices
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 4
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    assert calls[0].arguments[1] == 'a(2:4)'
    assert calls[0].arguments[2] == 'b'

    do_resolve_associates(routine)

    assert not FindNodes(ir.Associate).visit(routine.body)
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 4

    expected_lhs = (
        'arr2d(:, idx_a)',
        'arr2d(i, 1)',
        'arr3d(i, j, idx_c)',
        'arr3d(i, some_obj%idx, idx_c)',
    )
    expected_rhs = {
        1: 'arr2d(i+2, idx_a)',
        2: 'arr3d(i, j, idx_c) + arr2d(j, idx_a)',
        3: 'arr3d(i, some_obj%idx, idx_c) + 42.0',
    }
    for idx, lhs in enumerate(expected_lhs):
        assert assigns[idx].lhs == lhs
    for idx, rhs in expected_rhs.items():
        assert assigns[idx].rhs == rhs

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    assert calls[0].arguments[1] == 'arr2d(2:4, 1)'
    assert calls[0].arguments[2] == 'arr2d(:, idx_a)'


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_nested_conditional(frontend):
    """
    Check associate resolution when the associate block sits inside a
    conditional branch surrounded by comments and contains its own conditional.
    """
    fcode = """
subroutine transform_associates_nested_conditional
    use some_module, only: some_obj, some_flag
    implicit none

    real :: local_var

    if (some_flag) then
        local_var = 0.
    else
        ! Other nodes before the associate
        ! This one, too

        ! And this one
        associate (a => some_obj%a)
            local_var = a
            ! And a conditional which may inject a tuple nesting in the IR
            if (local_var > 10.) then
                local_var = 10.
            end if
        end associate
        ! And nodes after it

        ! like this
    end if
end subroutine transform_associates_nested_conditional
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    associates = FindNodes(ir.Associate).visit(routine.body)
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(conditionals) == 2
    assert len(associates) == 1
    assert len(assignments) == 3
    rhs = assignments[1].rhs
    assert rhs == 'a' and 'some_obj' not in rhs
    assert rhs.type.dtype == BasicType.DEFERRED

    do_resolve_associates(routine)

    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    associates = FindNodes(ir.Associate).visit(routine.body)
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(conditionals) == 2
    assert not associates
    assert len(assignments) == 3
    rhs = assignments[1].rhs
    assert rhs == 'some_obj%a'
    assert rhs.parent == 'some_obj'
    assert rhs.type.dtype == BasicType.DEFERRED
    assert rhs.scope == routine
    assert rhs.parent.scope == routine


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_partial_body(frontend):
    """
    Check that :any:`ResolveAssociatesTransformer` applied to a sub-tree (a loop)
    of an associate's body only resolves associated names inside that sub-tree.
    """
    fcode = """
subroutine transform_associates_partial
  use some_module, only: some_obj
  implicit none

  integer :: i
  real :: local_var

  associate (a=>some_obj%a, b=>some_obj%b)
    local_var = a(1)

    do i=1, some_obj%n
      a(i) = a(i) + 1.
      b(i) = b(i) + 1.
    end do
  end associate
end subroutine transform_associates_partial
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1

    # Only the loop is handed to the transformer, so the assignment before it
    # must keep the associated name
    ResolveAssociatesTransformer(inplace=True).visit(loops[0])

    assert len(FindNodes(ir.Loop).visit(routine.body)) == 1
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 3
    expected = (
        ('local_var', 'a(1)'),
        ('some_obj%a(i)', 'some_obj%a(i) + 1.'),
        ('some_obj%b(i)', 'some_obj%b(i) + 1.'),
    )
    for assign, (lhs, rhs) in zip(assigns, expected):
        assert assign.lhs == lhs
        assert assign.rhs == rhs


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_start_depth(frontend):
    """
    Test partial resolution of nested associates via ``start_depth``.

    With ``start_depth=1`` the outermost associate block is preserved. Only
    the names of the inner block (``c``, ``d``) are resolved, in terms of the
    outer associated names (``a``, ``b``).
    """
    fcode = """
subroutine transform_associates_partial
  use some_module, only: some_obj
  implicit none

  integer :: i
  real :: local_var

  associate (a=>some_obj%a, b=>some_obj%b)
  associate (c=>a%b, d=>b%d)
    local_var = a(1)

    do i=1, some_obj%n
      c(i) = c(i) + 1.
      d(i) = d(i) + 1.
    end do
  end associate
  end associate
end subroutine transform_associates_partial
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1

    # Resolve all except the outermost associate block
    do_resolve_associates(routine, start_depth=1)

    # Check that only the inner associate's names (c, d) have been resolved;
    # the outer names (a, b) are still used as-is
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 1
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 3
    assert assigns[0].lhs == 'local_var'
    assert assigns[0].rhs == 'a(1)'
    assert assigns[1].lhs == 'a%b(i)'
    assert assigns[1].rhs == 'a%b(i) + 1.'
    assert assigns[2].lhs == 'b%d(i)'
    assert assigns[2].rhs == 'b%d(i) + 1.'


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_merge_associates_nested(frontend):
    """
    Check that ``do_merge_associates`` hoists independent association pairs
    to the outermost possible block, respecting ``max_parents``, and rescopes
    body symbols to the block that now defines them.
    """
    fcode = """
subroutine merge_associates_simple(base)
  use some_module, only: some_type
  implicit none

  type(some_type), intent(inout) :: base
  integer :: i
  real :: local_var

  associate(a => base%a)
  associate(b => base%other%symbol)
  associate(d => base%other%symbol%really%deep, &
   &        a => base%a, c => a%more)
    do i=1, 5
      call another_routine(i, n=b(c)%n)

      d(i) = 42.0
    end do
  end associate
  end associate
  end associate
end subroutine merge_associates_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    assocs = FindNodes(ir.Associate).visit(routine.body)
    assert len(assocs) == 3
    assert [len(assoc.associations) for assoc in assocs] == [1, 1, 3]

    # Move associate mappings outwards
    do_merge_associates(routine, max_parents=2)

    assocs = FindNodes(ir.Associate).visit(routine.body)
    assert len(assocs) == 3
    assert [len(assoc.associations) for assoc in assocs] == [2, 1, 1]
    assert assocs[0].associations[0] == ('base%a', 'a')
    assert assocs[0].associations[1] == ('base%other%symbol', 'b')
    assert assocs[1].associations[0] == ('a%more', 'c')
    assert assocs[2].associations[0] == ('base%other%symbol%really%deep', 'd')

    # ``b(c)%n`` must now be scoped to the outermost associate, which defines ``b``
    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    assert call.kwarguments[0][1].scope == assocs[0]


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
@pytest.mark.parametrize('merge', [False, True])
@pytest.mark.parametrize('resolve', [False, True])
def test_associates_transformation(frontend, merge, resolve):
    """
    Test :any:`AssociatesTransformation` for all combinations of
    association merging and partial resolution below depth one.
    """
    fcode = """
subroutine merge_associates_simple(base)
  use some_module, only: some_type
  implicit none

  type(some_type), intent(inout) :: base
  integer :: i
  real :: local_var

  associate(a => base%a)
  associate(b => base%b)
  associate(c => a%c)
  associate(d => c%d)
    do i=1, 5
      call another_routine(b(i), c%n)

      d(i) = 42.0
    end do
  end associate
  end associate
  end associate
  end associate
end subroutine merge_associates_simple
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    trafo = AssociatesTransformation(
        resolve_associates=resolve, merge_associates=merge, start_depth=1
    )
    trafo.apply(routine)

    assocs = FindNodes(ir.Associate).visit(routine.body)
    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    assign = FindNodes(ir.Assignment).visit(routine.body)[0]

    if resolve:
        # Everything below the outermost block (start_depth=1) is resolved
        assert len(assocs) == 1
        assert call.arguments[1] == 'a%c%n'
        assert assign.lhs == 'a%c%d(i)'

        if merge:
            # ``b`` was merged into the outermost block and thus survives
            assert assocs[0].associations == (('base%a', 'a'), ('base%b', 'b'))
            assert call.arguments[0] == 'b(i)'
        else:
            assert assocs[0].associations == (('base%a', 'a'),)
            assert call.arguments[0] == 'base%b(i)'
    else:
        # Without resolution all four blocks and the body symbols remain
        assert len(assocs) == 4
        assert call.arguments[0] == 'b(i)'
        assert call.arguments[1] == 'c%n'
        assert assign.lhs == 'd(i)'

        if merge:
            assert assocs[0].associations == (('base%a', 'a'), ('base%b', 'b'))
            assert assocs[1].associations == (('a%c', 'c'), )
            assert assocs[2].associations == ()
            assert assocs[3].associations == (('c%d', 'd'), )
        else:
            assert all(len(a.associations) == 1 for a in assocs)


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
@pytest.mark.parametrize('depth', [0, 1, 2])
def test_resolve_associates_stmt_func(frontend, depth):
    """
    Test that statement function calls, whether represented as
    :any:`ProcedureSymbol` or :any:`DeferredTypeSymbol`, get their scope
    updated during (partial) associate resolution.
    """
    fcode = """
subroutine test_associates_stmt_func(ydcst, a, b)
  use yomcst, only: tcst
  implicit none
  type(tcst), intent(in) :: ydcst
  real(kind=8), intent(inout) :: a, b
#include "some_stmt.func.h"
  real(kind=8) :: not_an_array
  not_an_array ( x, y ) =  x * y

associate(d=>b)
associate(c=>a)
associate(RTT=>YDCST%RTT)
  a = not_an_array(RTT, 1.0) + a
  b = some_stmt_func(RTT, 1.0) + b
end associate
end associate
end associate
end subroutine test_associates_stmt_func
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    associates = FindNodes(ir.Associate).visit(routine.body)
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(associates) == 3
    assert len(assigns) == 2

    # Initially, both calls are scoped to the innermost associate
    for assign in assigns:
        func_call = assign.rhs.children[0]
        assert isinstance(func_call, sym.InlineCall)
        assert func_call.function.scope == associates[2]

    do_resolve_associates(routine, start_depth=depth)

    associates = FindNodes(ir.Associate).visit(routine.body)
    assert len(associates) == depth

    # The new defining scope: the last surviving associate, or the routine
    outer_scope = associates[depth-1] if depth else routine

    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 2
    expected_rhs = (
        'not_an_array(YDCST%RTT, 1.0) + a',
        'some_stmt_func(YDCST%RTT, 1.0) + b',
    )
    for assign, rhs in zip(assigns, expected_rhs):
        assert assign.rhs == rhs
        func_call = assign.rhs.children[0]
        assert isinstance(func_call, sym.InlineCall)
        assert func_call.function.scope == outer_scope

    # Trigger a full clone, which would fail if scopes are missing
    routine.clone()
loki-ecmwf-0.3.6/loki/transformations/sanitise/associates.py0000664000175000017500000003003015167130205024453 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


"""
A small selection of utility transformations that resolve certain code
constructs to unify code structure and make reasoning about Fortran
code easier.
"""

from loki.batch import Transformation
from loki.expression import symbols as sym,  LokiIdentityMapper
from loki.ir import nodes as ir, Transformer, NestedTransformer
from loki.logging import warning
from loki.tools import dict_override
from loki.types import SymbolTable


__all__ = [
    'AssociatesTransformation', 'do_resolve_associates',
    'ResolveAssociatesTransformer', 'do_merge_associates'
]


class AssociatesTransformation(Transformation):
    """
    :any:`Transformation` that applies code sanitisation steps specific
    to :any:`Associate` nodes.

    Two steps are available; when both are enabled they run in this order:

    1. Merging of nested :any:`Associate` scopes, which moves independent
       association pairs to the outermost possible scope, optionally
       restricted by a maximum number of ``max_parents`` symbols.
    2. Full or partial resolution of :any:`Associate` nodes, which
       replaces ``identifier`` symbols in the node's body with the
       corresponding ``selector``.

    Parameters
    ----------
    resolve_associates : bool, default: True
        Enable full or partial resolution of :any:`Associate` scopes.
    merge_associates : bool, default: False
        Enable merging of :any:`Associate` nodes into the outermost
        possible scope in nested associate blocks.
    start_depth : int, optional
        Starting depth for partial resolution of :any:`Associate`
        nodes after merging.
    max_parents : int, optional
        Maximum number of parent symbols for valid selector to have
        when merging :any:`Associate` nodes.
    """

    def __init__(
            self, resolve_associates=True, merge_associates=False,
            start_depth=0, max_parents=None
    ):
        self.merge_associates = merge_associates
        self.max_parents = max_parents

        self.resolve_associates = resolve_associates
        self.start_depth = start_depth

    def transform_subroutine(self, routine, **kwargs):
        # Merging runs first, so that only the remaining associates
        # are subject to resolution from ``start_depth`` onwards
        if self.merge_associates:
            do_merge_associates(routine, max_parents=self.max_parents)

        if not self.resolve_associates:
            return

        do_resolve_associates(routine, start_depth=self.start_depth)


def do_resolve_associates(routine, start_depth=0):
    """
    Replace :any:`Associate` nodes in the body of ``routine`` by their
    bodies, substituting associated symbols with their selectors.

    When ``start_depth`` is given, only :any:`Associate` nodes nested
    deeper than that depth are resolved; outer ones are kept.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine for which to resolve all associate blocks.
    start_depth : int, optional
        Starting depth for partial resolution of :any:`Associate`
    """
    routine.body = ResolveAssociatesTransformer(start_depth=start_depth).visit(routine.body)

    # Resolution can change the parent of a symbol and with it the
    # symbol's type-defining scope, so re-attach scopes everywhere.
    routine.rescope_symbols()


class ResolveAssociateMapper(LokiIdentityMapper):
    """
    Expression mapper that will resolve symbol associations due to
    :any:`Associate` scopes.

    The mapper will inspect the associated scope of each symbol
    and replace it with the inverse of the associate mapping.

    Parameters
    ----------
    start_depth : int, optional
        Associate nesting depth (the outermost :any:`Associate` has
        depth 1) up to which symbols are left untouched; only symbols
        defined by more deeply nested :any:`Associate` nodes are resolved.
    """

    def __init__(self, *args, start_depth=0, **kwargs):
        self.start_depth = start_depth
        super().__init__(*args, **kwargs)

    @staticmethod
    def _match_range_indices(expressions, indices):
        """
        Map :data:`indices` to free ranges in :data:`expressions`

        If the number of :any:`RangeIndex` entries in ``expressions``
        equals the number of ``indices``, the ranges are replaced in order
        by the given indices. Otherwise ``expressions`` is returned as-is.
        A warning is emitted if any range has a lower bound other than
        ``None`` or ``1``, since bound shifts are not applied.
        """
        assert isinstance(expressions, tuple)
        assert isinstance(indices, tuple)

        free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex))
        if any(s.lower not in (None, 1) for s in free_symbols):
            warning('WARNING: Bounds shifts through association is currently not supported')

        if len(free_symbols) == len(indices):
            # If the provided indices are enough to bind free symbols,
            # we match them in sequence.
            it = iter(indices)
            return tuple(
                next(it) if isinstance(e, sym.RangeIndex) else e
                for e in expressions
            )

        return expressions

    def map_scalar(self, expr, *args, **kwargs):
        """
        Resolve a symbol defined by an :any:`Associate` scope deeper than
        ``start_depth`` to its selector expression, or re-scope it to the
        scope that remains after the inner associates are removed.
        """
        # Skip unscoped expressions
        # NOTE(review): if ``expr`` dispatches back to this method, this
        # recursion would not terminate — confirm unscoped symbols never reach here
        if not hasattr(expr, 'scope'):
            return self.rec(expr, *args, **kwargs)

        # Stop if scope is not an associate
        if not isinstance(expr.scope, ir.Associate):
            return expr

        scope = expr.scope

        # Determine the depth of the symbol-defining associate
        # (number of enclosing associates + 1, so the outermost is 1)
        depth = len(tuple(
            p for p in scope.parents if isinstance(p, ir.Associate)
        )) + 1
        if depth <= self.start_depth:
            return expr

        # Recurse on parent first and propagate scope changes
        # (the resolved parent's scope becomes this symbol's scope)
        parent = self.rec(expr.parent, *args, **kwargs)
        if parent != expr.parent:
            expr = expr.clone(parent=parent, scope=parent.scope)

        # Find a match in the given inverse map if the expr has no parent
        #  which is a prerequisite for a possible replacement and avoids
        #  false replacement because of some ambiguity
        if expr.parent is None and expr.basename in scope.inverse_map:
            expr_candidate = scope.inverse_map[expr.basename]
            # The selector may itself contain associated symbols
            return self.rec(expr_candidate, *args, **kwargs)

        # Update the scope, as any inner associates will be removed.
        # For this we count backwards the nested scopes, the tail of
        # which will be the (innermost) associates. This presumably
        # assumes ``scope.parents`` is ordered outermost-first, so that the
        # selected entry is the associate at depth ``start_depth`` (or the
        # routine for ``start_depth == 0``) — TODO confirm ordering.
        new_scope = scope.parents[::-1][depth-self.start_depth-1]
        return expr.clone(scope=new_scope)

    def map_array(self, expr, *args, **kwargs):
        """
        Partially resolve dimension indices and handle shape

        Existing dimensions and the declared shape are mapped first. If
        the symbol resolves to an array whose selector has dimensions,
        free ranges in the selector are bound to the original indices
        via :meth:`_match_range_indices`.
        """

        # Recurse over existing array dimensions
        expr_dims = self.rec(expr.dimensions, *args, **kwargs)

        # Recurse over the type's shape
        _type = expr.type
        if expr.type.shape:
            new_shape = self.rec(expr.type.shape, *args, **kwargs)
            _type = expr.type.clone(shape=new_shape)

        # Stop if scope is not an associate
        if not isinstance(expr.scope, ir.Associate):
            return expr.clone(dimensions=expr_dims, type=_type)

        new = self.map_scalar(expr, *args, **kwargs)

        # Recurse over array dimensions
        if isinstance(new, sym.Array) and new.dimensions:
            # Resolve unbound range symbols from existing indices
            new_dims = self.rec(new.dimensions, *args, **kwargs)
            new_dims = self._match_range_indices(new_dims, expr_dims)
        else:
            new_dims = expr_dims

        return new.clone(dimensions=new_dims, type=_type)

    # All other scoped symbol kinds share the scalar resolution logic
    map_variable_symbol = map_scalar
    map_deferred_type_symbol = map_scalar
    map_procedure_symbol = map_scalar


class ResolveAssociatesTransformer(Transformer):
    """
    :any:`Transformer` that removes :any:`Associate` nodes from IR trees.

    Every :any:`Associate` node is replaced by its own body. In that body,
    each ``identifier`` symbol is substituted by the ``selector``
    expression it is bound to in ``associations``.

    The :any:`Transformer` may also be applied to partial bodies of
    :any:`Associate` nodes.

    When ``start_depth`` is given, resolution is partial: nodes nested at
    a depth of at most ``start_depth`` are kept, only inner ones are resolved.

    Parameters
    ----------
    start_depth : int, optional
        Starting depth for partial resolution of :any:`Associate`
    """
    # pylint: disable=unused-argument

    def __init__(self, start_depth=0, **kwargs):
        self.start_depth = start_depth
        super().__init__(**kwargs)

    def visit_Expression(self, o, **kwargs):
        mapper = ResolveAssociateMapper(start_depth=self.start_depth)
        return mapper(o)

    def visit_Associate(self, o, **kwargs):
        """
        Replace an :any:`Associate` node by its transformed body, or keep
        the node around that body if it lies within ``start_depth``
        """
        # Nesting depth of this node; an unset depth counts as 1
        this_depth = kwargs.get('depth', 1)

        # Resolve nested associate blocks first, one level deeper
        with dict_override(kwargs, {'depth': this_depth + 1}):
            new_body = self.visit(o.body, **kwargs)

        keep_node = this_depth <= self.start_depth
        return o.clone(body=new_body) if keep_node else new_body

    def visit_CallStatement(self, o, **kwargs):
        # Keyword arguments are (keyword, value) pairs; only values are visited
        new_arguments = self.visit(o.arguments, **kwargs)
        new_kwarguments = tuple(
            (keyword, self.visit(value, **kwargs)) for keyword, value in o.kwarguments
        )
        return o._rebuild(arguments=new_arguments, kwarguments=new_kwarguments)


def do_merge_associates(routine, max_parents=None):
    """
    Move association pairs of nested :any:`Associate` nodes within a
    :any:`Subroutine` to the outermost parent scope possible.

    See :any:`MergeAssociatesTransformer` for more details.

    Note
    ----
    This method can be combined with :any:`do_resolve_associates` to
    give nested ASSOCIATE blocks a more unified look-and-feel.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to merge associate blocks.
    max_parents : int, optional
        Maximum number of parent symbols for valid selector to have.
    """
    merger = MergeAssociatesTransformer(max_parents=max_parents)
    new_body = merger.visit(routine.body)
    routine.body = new_body


class MergeAssociatesTransformer(NestedTransformer):
    """
    :any:`NestedTransformer` that moves associate mappings in
    :any:`Associate` nodes up to their parent :any:`Associate`.

    An association pair stays in place if any of the following holds:

    * its selector expression is scoped to the parent :any:`Associate`,
      i.e. it depends on a symbol associated there;
    * its selector carries no scope information (e.g. a compound
      expression), so the dependency above cannot be checked;
    * its selector has more than ``max_parents`` parent symbols;
    * the parent already associates the same name with a different
      selector, since an associate-name may appear only once in a
      single ASSOCIATE statement.

    Pairs that already exist verbatim in the parent are simply removed
    from the inner node.

    Parameters
    ----------
    max_parents : int, optional
        Maximum number of parent symbols for valid selector to have.
        ``None`` disables this restriction; ``0`` only allows selectors
        without parent symbols to be moved.
    """

    def __init__(self, max_parents=None, **kwargs):
        self.max_parents = max_parents
        super().__init__(**kwargs)

    def _is_movable(self, expr, name, parent):
        """
        Check whether the association ``name => expr`` may be moved from
        its :any:`Associate` node up to ``parent``.
        """
        # Selectors without scope information (e.g. compound expressions)
        # cannot be checked for dependencies on ``parent``; keep them.
        if not hasattr(expr, 'scope'):
            return False

        # The selector depends on a symbol associated in ``parent``.
        # NOTE(review): only the selector's own scope is checked; symbols
        # used in its subscripts are not inspected here.
        if expr.scope == parent:
            return False

        # Optionally restrict by depth of symbol-parentage. ``0`` is a
        # valid limit, hence the explicit ``None`` check.
        if self.max_parents is not None and len(expr.parents) > self.max_parents:
            return False

        # An associate-name may appear only once per ASSOCIATE statement,
        # so do not move a name that ``parent`` binds to another selector.
        # Identical pairs are fine: they are de-duplicated when moving.
        name_key = str(name).lower()
        return not any(
            str(p_name).lower() == name_key and not p_expr == expr
            for p_expr, p_name in parent.associations
        )

    def visit_Associate(self, o, **kwargs):
        body = self.visit(o.body, **kwargs)

        if not o.parent or not isinstance(o.parent, ir.Associate):
            return o._rebuild(body=body, rescope_symbols=True)

        # Find all associate mappings that can be moved up
        to_move = tuple(
            (expr, name) for expr, name in o.associations
            if self._is_movable(expr, name, o.parent)
        )

        # Move up to parent, skipping pairs the parent already has ...
        parent_assoc = tuple(
            (expr, name) for expr, name in to_move
            if (expr, name) not in o.parent.associations
        )
        o.parent._update(associations=o.parent.associations + parent_assoc)

        # ... and remove from this associate node
        new_assocs = tuple(
            (expr, name) for expr, name in o.associations
            if (expr, name) not in to_move
        )
        o = o._rebuild(
            body=body, associations=new_assocs, parent=o.parent,
            rescope_symbols=True, symbol_attrs=SymbolTable()
        )
        # We rebuild the local symbol-table from scratch to ensure
        # that moved associations get the correct defining scope
        o._derive_local_symbol_types(parent_scope=o.parent)
        return o
loki-ecmwf-0.3.6/loki/transformations/sanitise/substitute.py0000664000175000017500000000404015167130205024532 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.ir import SubstituteStringExpressions


__all__ = ['SubstituteExpressionTransformation']


class SubstituteExpressionTransformation(Transformation):
    """
    :any:`Transformation` that substitutes individual expressions in
    :any:`Subroutine` objects.

    Substitutions are given as a dictionary of strings, which is parsed
    in the scope of the :any:`Subroutine` being transformed to determine
    the respective symbols.

    Parameters
    ----------
    substitute_expressions : bool
        Master switch; if ``False`` no substitution is performed at all.
    expression_map : dict of str to str
        A string-to-string map detailing the substitutions to apply.
    substitute_body : bool
        Flag to trigger or suppress expression substitution in bodies.
    substitute_spec : bool
        Flag to trigger or suppress expression substitution in specs.
    """

    def __init__(
            self, substitute_expressions=True, expression_map=None,
            substitute_body=True, substitute_spec=True
    ):
        self.substitute_expressions = substitute_expressions
        self.substitute_body = substitute_body
        self.substitute_spec = substitute_spec
        self.expression_map = expression_map or {}

    def transform_subroutine(self, routine, **kwargs):
        if not self.substitute_expressions:
            return

        # One substitution visitor, bound to this routine's scope, serves
        # both the spec and the body
        substitute = SubstituteStringExpressions(self.expression_map, scope=routine)

        if self.substitute_spec:
            routine.spec = substitute.visit(routine.spec)
        if self.substitute_body:
            routine.body = substitute.visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/sanitise/sequence_associations.py0000664000175000017500000001000515167130205026704 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.expression import Array, RangeIndex
from loki.ir import Transformer
from loki.tools import as_tuple
from loki.types import BasicType


__all__ = [
    'SequenceAssociationTransformation',
    'do_resolve_sequence_association',
    'SequenceAssociationTransformer'
]


class SequenceAssociationTransformation(Transformation):
    """
    :any:`Transformation` that rewrites array arguments passed with
    scalar (element) syntax in :any:`CallStatement` nodes, i.e. it
    resolves sequence association patterns.

    Parameters
    ----------
    resolve_sequence_associations : bool
        Whether to perform the resolution; with ``False`` the
        transformation leaves routines unchanged.
    """

    def __init__(self, resolve_sequence_associations=True):
        self.resolve_sequence_associations = resolve_sequence_associations

    def transform_subroutine(self, routine, **kwargs):  # pylint: disable=unused-argument
        if not self.resolve_sequence_associations:
            return
        do_resolve_sequence_association(routine)


def check_if_scalar_syntax(arg, dummy):
    """
    Check whether the array argument ``arg`` is passed to the array dummy
    argument ``dummy`` using scalar syntax, e.g. ``arg(1,1) -> d(m,n)``.

    This is the case when both are arrays and ``arg`` carries explicit
    dimensions, none of which is a range.

    Parameters
    ----------
    arg:   variable
    dummy: variable
    """
    if not (isinstance(arg, Array) and isinstance(dummy, Array)):
        return False
    if not arg.dimensions:
        return False
    return all(not isinstance(index, RangeIndex) for index in arg.dimensions)


def do_resolve_sequence_association(routine):
    """
    Housekeeping routine to replace scalar syntax when passing arrays
    as arguments. For example, a call like

    .. code-block::

        real :: a(m,n)

        call myroutine(a(i,j))

    where myroutine looks like

    .. code-block::

        subroutine myroutine(a)
            real :: a(5)
        end subroutine myroutine

    should be changed to

    .. code-block::

        call myroutine(a(i:m,j))

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine where calls will be changed
    """
    transformer = SequenceAssociationTransformer(inplace=True)
    routine.body = transformer.visit(routine.body)


class SequenceAssociationTransformer(Transformer):
    """
    Transformer that resolves sequence association patterns in
    :any:`CallStatement` nodes.

    An array argument passed with scalar syntax to an array dummy argument
    (e.g. ``call foo(a(i, j))``) is replaced by an explicit array section
    (e.g. ``call foo(a(i:m, j))``). Positional and keyword arguments are
    each rewritten within their own argument list.
    """

    @staticmethod
    def _resolve_argument(arg, dummy):
        """
        Return ``arg`` rewritten as an array section if it is passed to
        ``dummy`` using scalar syntax, or ``None`` if no change is needed.
        """
        if not check_if_scalar_syntax(arg, dummy):
            return None

        n_dims = len(dummy.shape)

        if not arg.shape:
            # Hack: If we don't have a shape, short-circuit here
            new_dims = [RangeIndex((None, None)) for _ in dummy.shape]
        else:
            # Open a range from the given index to the upper bound of each
            # of the leading ``n_dims`` dimensions of ``arg``
            new_dims = [
                RangeIndex((lower, s.stop if isinstance(s, RangeIndex) else s))
                for s, lower in zip(arg.shape[:n_dims], arg.dimensions[:n_dims])
            ]

        # Indices beyond the rank of the dummy argument are kept as-is
        if len(arg.dimensions) > n_dims:
            new_dims += list(arg.dimensions[n_dims:])

        return arg.clone(dimensions=as_tuple(new_dims))

    def visit_CallStatement(self, call, **kwargs):  # pylint: disable=unused-argument
        """
        Resolve sequence association patterns in arguments and return
        new :any:`CallStatement` object if any were found.
        """
        if call.procedure_type is BasicType.DEFERRED:
            return call

        # ``arg_map`` pairs dummy arguments with both positional and keyword
        # arguments. Replacements are recorded by the identity of the original
        # argument expression, so each one is written back to the list it
        # came from. Writing all of them into the positional list would
        # duplicate any keyword arguments.
        replacements = {}
        for dummy, arg in call.arg_map.items():
            new_arg = self._resolve_argument(arg, dummy)
            if new_arg is not None:
                replacements[id(arg)] = new_arg

        if not replacements:
            return call

        arguments = tuple(
            replacements.get(id(arg), arg) for arg in as_tuple(call.arguments)
        )
        kwarguments = tuple(
            (kw, replacements.get(id(arg), arg)) for kw, arg in as_tuple(call.kwarguments)
        )
        return call.clone(arguments=arguments, kwarguments=kwarguments)
loki-ecmwf-0.3.6/loki/transformations/inline/0000775000175000017500000000000015167130205021406 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/inline/__init__.py0000664000175000017500000000162415167130205023522 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Transformations sub-package that provides various forms of
source-level code inlining.

The various inline mechanisms are provided as standalone utility methods,
or via the :any:`InlineTransformation` class for batch processing.
"""

from loki.transformations.inline.constants import * # noqa
from loki.transformations.inline.functions import * # noqa
from loki.transformations.inline.mapper import * # noqa
from loki.transformations.inline.procedures import * # noqa
from loki.transformations.inline.transformation import * # noqa
loki-ecmwf-0.3.6/loki/transformations/inline/tests/0000775000175000017500000000000015167130205022550 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/inline/tests/__init__.py0000664000175000017500000000057015167130205024663 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/inline/tests/test_inline_transformation.py0000664000175000017500000005340715167130205030576 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re
from pathlib import Path
import pytest

from loki import Module, Subroutine, ProcessingStrategy
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
from loki.batch import Scheduler, SchedulerConfig, TransformationError, Pipeline

from loki.transformations.inline import InlineTransformation
from loki.transformations.build_system import FileWriteTransformation


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('pass_as_kwarg', (False, True))
def test_inline_transformation(tmp_path, frontend, pass_as_kwarg):
    """
    Test recursive inlining via :any:`InlineTransformation`, applied to
    the inner routine first and then to its caller.
    """

    fcode_module = """
module one_mod
  real(kind=8), parameter :: one = 1.0
end module one_mod
"""

    fcode_inner = """
subroutine add_one_and_two(a)
  use one_mod, only: one
  implicit none

  real(kind=8), intent(inout) :: a

  a = a + one

  a = add_two(a)

contains
  elemental function add_two(x)
    real(kind=8), intent(in) :: x
    real(kind=8) :: add_two

    add_two = x + 2.0
  end function add_two
end subroutine add_one_and_two
"""

    fcode = f"""
subroutine test_inline_pragma(a, b)
  implicit none
  real(kind=8), intent(inout) :: a(3), b(3)
  integer, parameter :: n = 3
  integer :: i
  real :: stmt_arg
  real :: some_stmt_func
  some_stmt_func ( stmt_arg ) = stmt_arg + 3.1415

#include "add_one_and_two.intfb.h"

  do i=1, n
    !$loki inline
    call add_one_and_two({'a=' if pass_as_kwarg else ''}a(i))
  end do

  do i=1, n
    !$loki inline
    call add_one_and_two({'a=' if pass_as_kwarg else ''}b(i))
  end do

  a(1) = some_stmt_func({'stmt_arg=' if pass_as_kwarg else ''}a(2))

end subroutine test_inline_pragma
"""
    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    inner = Subroutine.from_source(fcode_inner, definitions=module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend)
    routine.enrich(inner)

    trafo = InlineTransformation(
        inline_constants=True, external_only=True, inline_elementals=True,
        inline_stmt_funcs=True
    )

    # Both calls must be linked to the enriched inner routine
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 2
    assert all(c.routine == inner for c in calls)

    # Resolve the parameter and the elemental function in the inner routine
    trafo.apply(inner)

    expected_inner = (
        ('a', 'a + 1.0'),
        ('result_add_two', 'a + 2.0'),
        ('a', 'result_add_two'),
    )
    assigns = FindNodes(ir.Assignment).visit(inner.body)
    assert len(assigns) == len(expected_inner)
    for assign, (lhs, rhs) in zip(assigns, expected_inner):
        assert assign.lhs == lhs and assign.rhs == rhs

    # Inline the already-resolved inner routine and the statement function
    trafo.apply(routine)

    assert not FindNodes(ir.CallStatement).visit(routine.body)

    expected_outer = (
        ('a(i)', 'a(i) + 1.0'),
        ('result_add_two', 'a(i) + 2.0'),
        ('a(i)', 'result_add_two'),
        ('b(i)', 'b(i) + 1.0'),
        ('result_add_two', 'b(i) + 2.0'),
        ('b(i)', 'result_add_two'),
        ('a(1)', 'a(2) + 3.1415'),
    )
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == len(expected_outer)
    for assign, (lhs, rhs) in zip(assigns, expected_outer):
        assert assign.lhs == lhs and assign.rhs == rhs



@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_transformation_local_seq_assoc(frontend, tmp_path):
    """
    Test the interplay of ``resolve_sequence_association`` with inlining
    of marked calls and internal procedures in :any:`InlineTransformation`.
    """
    fcode = """
module somemod
    implicit none
    contains

    subroutine minusone_second(output, x)
        real, intent(inout) :: output
        real, intent(in) :: x(3)
        output = x(2) - 1
    end subroutine minusone_second

    subroutine plusone(output, x)
        real, intent(inout) :: output
        real, intent(in) :: x
        output = x + 1
    end subroutine plusone

    subroutine outer()
      implicit none
      real :: x(3, 3)
      real :: y
      x = 10.0

      call inner(y, x(1, 1)) ! Sequence association tmp_path for member routine.

      !$loki inline
      call plusone(y, x(3, 3)) ! Marked for inlining.

      call minusone_second(y, x(1, 3)) ! Standard call with sequence association (never processed).

      contains

      subroutine inner(output, x)
        real, intent(inout) :: output
        real, intent(in) :: x(3)

        output = x(2) + 2.0
      end subroutine inner
    end subroutine outer

end module somemod
"""

    def _setup(**flags):
        """ Parse a fresh copy of the module and create the transformation """
        module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
        trafo = InlineTransformation(
            inline_constants=True, external_only=True, inline_elementals=True, **flags
        )
        return module['outer'], trafo

    def _callnames(routine):
        """ Names of all calls remaining in the body of ``routine`` """
        return [call.name for call in FindNodes(ir.CallStatement).visit(routine.body)]

    # Nothing happens if `resolve_sequence_association=True`
    # but inlining "marked" and "internals" is disabled.
    outer, trafo = _setup(
        inline_marked=False, inline_internals=False, resolve_sequence_association=True
    )
    trafo.apply(outer)
    callnames = _callnames(outer)
    assert 'plusone' in callnames
    assert 'inner' in callnames
    assert 'minusone_second' in callnames

    # Only the marked call is processed with
    # `resolve_sequence_association=True`,
    # `inline_marked=True`,
    # `inline_internals=False`
    outer, trafo = _setup(
        inline_marked=True, inline_internals=False, resolve_sequence_association=True
    )
    trafo.apply(outer)
    callnames = _callnames(outer)
    assert 'plusone' not in callnames
    assert 'inner' in callnames
    assert 'minusone_second' in callnames

    # A crash occurs if sequence association is needed but not enabled.
    outer, trafo = _setup(
        inline_marked=True, inline_internals=True, resolve_sequence_association=False
    )
    with pytest.raises(TransformationError):
        trafo.apply(outer)

    # Sequence association is run and the internal call inlined, avoiding the crash.
    outer, trafo = _setup(
        inline_marked=False, inline_internals=True, resolve_sequence_association=True
    )
    trafo.apply(outer)
    callnames = _callnames(outer)
    assert 'plusone' in callnames
    assert 'inner' not in callnames
    assert 'minusone_second' in callnames

    # Everything is enabled.
    outer, trafo = _setup(
        inline_marked=True, inline_internals=True, resolve_sequence_association=True
    )
    trafo.apply(outer)
    callnames = _callnames(outer)
    assert 'plusone' not in callnames
    assert 'inner' not in callnames
    assert 'minusone_second' in callnames


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_transformation_local_seq_assoc_crash_marked_no_seq_assoc(frontend, tmp_path):
    """
    Inlining a ``!$loki inline``-marked call that relies on sequence association
    must raise a :any:`TransformationError` unless sequence association
    resolution is enabled; with it enabled the call is inlined completely.
    """
    fcode = """
module somemod
    implicit none
    contains

    subroutine inner(output, x)
        real, intent(inout) :: output
        real, intent(in) :: x(3)

        output = x(2) + 2.0
    end subroutine inner

    subroutine outer()
      real :: x(3, 3)
      real :: y
      x = 10.0

      !$loki inline
      call inner(y, x(1, 1)) ! Sequence association tmp_path for marked routine.
    end subroutine outer

end module somemod
"""

    def _parse_with_trafo(resolve_seq_assoc):
        # Fresh parse for every scenario so each one starts from pristine IR.
        # The module itself is returned so the caller keeps it alive while
        # working on the contained routine.
        parsed = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
        inliner = InlineTransformation(
            inline_constants=True, external_only=True, inline_elementals=True,
            inline_marked=True, inline_internals=False,
            resolve_sequence_association=resolve_seq_assoc
        )
        return parsed, inliner

    # Sequence association left unresolved: inlining the marked call must fail
    module, inliner = _parse_with_trafo(resolve_seq_assoc=False)
    outer = module['outer']
    with pytest.raises(TransformationError):
        inliner.apply(outer)

    # Sequence association resolved first: the call disappears entirely
    module, inliner = _parse_with_trafo(resolve_seq_assoc=True)
    outer = module['outer']
    inliner.apply(outer)
    assert not FindNodes(ir.CallStatement).visit(outer.body)

@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_transformation_local_seq_assoc_crash_value_err_no_source(frontend, tmp_path):
    """
    Requesting sequence association resolution together with inlining of a
    marked call must raise a :any:`TransformationError` when the source of the
    called routine is unavailable (``inner`` is never defined here, so there is
    not enough type information to resolve the association).

    Note: despite the ``value_err`` in the test name, the expected exception
    is :any:`TransformationError`.
    """
    fcode = """
module somemod
    implicit none
    contains

    subroutine outer()
      real :: x(3, 3)
      real :: y
      x = 10.0

      !$loki inline
      call inner(y, x(1, 1)) ! Sequence association tmp_path for marked routine.
    end subroutine outer

end module somemod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    trafo = InlineTransformation(
        inline_constants=True, external_only=True, inline_elementals=True,
        inline_marked=True, inline_internals=False, resolve_sequence_association=True
    )
    outer = module["outer"]
    with pytest.raises(TransformationError):
        trafo.apply(outer)


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_transformation_adjust_imports(frontend, tmp_path):
    """
    Inline two marked calls with ``adjust_imports=True``.

    Checks that the callee bodies are inlined with correctly remapped array
    indices, and that the imports of the caller are updated: the imports of the
    inlined routines are removed and the modules used by the callees are
    imported instead.
    """
    fcode_module = """
module bnds_module
  integer :: m
  integer :: n
  integer :: l
end module bnds_module
    """

    fcode_another = """
module another_module
  integer :: x
end module another_module
    """

    fcode_outer = """
subroutine test_inline_outer(a, b, f)
  use bnds_module, only: n
  use test_inline_mod, only: test_inline_inner
  use test_inline_another_mod, only: test_inline_another_inner
  implicit none

  real(kind=8), intent(inout) :: a(n), b(n), f(0:n-1)
  real(kind=8) :: c(12)

  !$loki inline
  call test_inline_another_inner()
  !$loki inline
  call test_inline_inner(a, b, c(1:4), c(5:8), c(9:12), f)
end subroutine test_inline_outer
    """

    fcode_inner = """
module test_inline_mod
  implicit none
  contains

subroutine test_inline_inner(a, b, c, d, e, f)
  use BNDS_module, only: n, m
  use another_module, only: x

  real(kind=8), intent(inout) :: a(n), b(n), f(2:n+1)
  real(kind=8), intent(out) :: c(4), d(4), e(0:3)
  real(kind=8) :: tmp(m)
  integer :: i

  tmp(1:m) = x
  do i=1, n
    a(i) = b(i) + sum(tmp)
  end do
  do i=1,4
    c(i) = 0.
    d(i) = 0.
    e(i-1) = 0.
  enddo
  c(:) = 1.
  d(1:4) = 1.
  e(0:3) = 1.
  e(:) = 2.
  do i=2, n+1
    f(i) = 2.
  end do
end subroutine test_inline_inner
end module test_inline_mod
    """

    fcode_another_inner = """
module test_inline_another_mod
  implicit none
  contains

subroutine test_inline_another_inner()
  use BNDS_module, only: n, m, l

end subroutine test_inline_another_inner
end module test_inline_another_mod
    """

    # The parse results of these two modules are discarded; presumably they are
    # parsed only for their side effects (e.g. module information made
    # available via ``xmods``) - NOTE(review): confirm
    _ = Module.from_source(fcode_another, frontend=frontend, xmods=[tmp_path])
    _ = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    inner = Module.from_source(fcode_inner, frontend=frontend, xmods=[tmp_path])
    another_inner = Module.from_source(fcode_another_inner, frontend=frontend, xmods=[tmp_path])
    outer = Subroutine.from_source(
        fcode_outer, definitions=(inner, another_inner), frontend=frontend, xmods=[tmp_path]
    )

    trafo = InlineTransformation(
        inline_elementals=False, inline_marked=True, adjust_imports=True
    )
    trafo.apply(outer)

    # Check that the inlining has happened; the dummy arguments c, d, e of the
    # callee map onto consecutive slices of the caller's ``c(12)`` and the
    # lower bound of ``f`` is shifted from 2 to 0
    expected_assignments = (
        ('tmp(1:m)', 'x'),
        ('a(i)', 'b(i) + sum(tmp)'),
        ('c(i)', '0.'),
        ('c(4 + i)', '0.'),
        ('c(8 + i)', '0.'),
        ('c(1:4)', '1.'),
        ('c(5:8)', '1.'),
        ('c(9:12)', '1.'),
        ('c(9:12)', '2.'),
        ('f(-2 + i)', '2.'),
    )
    assign = FindNodes(ir.Assignment).visit(outer.body)
    assert len(assign) == len(expected_assignments)
    for node, (lhs, rhs) in zip(assign, expected_assignments):
        assert node.lhs == lhs
        assert node.rhs == rhs

    # Now check that the right modules have been moved,
    # and the import of the call has been removed
    imports = FindNodes(ir.Import).visit(outer.spec)
    assert len(imports) == 2
    assert imports[0].module == 'another_module'
    assert imports[0].symbols == ('x',)
    assert imports[1].module == 'bnds_module'
    assert all(name in imports[1].symbols for name in ('l', 'm', 'n'))


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_transformation_intermediate(tmp_path, frontend):
    """
    Inline the marked call to ``intermediate`` in ``outermost`` through the
    scheduler and verify that the dependency graph is rewired: ``intermediate``
    disappears and ``outermost`` depends directly on ``innermost``.
    """
    sources = {
        'outermost_mod': """
module outermost_mod
implicit none
contains
subroutine outermost()
use intermediate_mod, only: intermediate

!$loki inline
call intermediate()

end subroutine outermost
end module outermost_mod
""",
        'intermediate_mod': """
module intermediate_mod
implicit none
contains
subroutine intermediate()
use innermost_mod, only: innermost

call innermost()

end subroutine intermediate
end module intermediate_mod
""",
        'innermost_mod': """
module innermost_mod
implicit none
contains
subroutine innermost()

end subroutine innermost
end module innermost_mod
""",
    }
    for modname, source in sources.items():
        (tmp_path/f'{modname}.F90').write_text(source)

    config = SchedulerConfig.from_dict({
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True
        },
        'routines': {
            'outermost': {'role': 'kernel'}
        }
    })
    scheduler = Scheduler(
        paths=[tmp_path], config=config, frontend=frontend, xmods=[tmp_path]
    )

    def successors_of(name):
        return scheduler.sgraph.successors(scheduler[name])

    # Before inlining the graph is a chain: outermost -> intermediate -> innermost
    assert len(scheduler.items) == 3
    outer_successors = successors_of('outermost_mod#outermost')
    assert len(outer_successors) == 1
    assert scheduler['intermediate_mod#intermediate'] in outer_successors
    middle_successors = successors_of('intermediate_mod#intermediate')
    assert len(middle_successors) == 1
    assert scheduler['innermost_mod#innermost'] in middle_successors

    scheduler.process(transformation=InlineTransformation())

    # After inlining only outermost -> innermost remains
    assert len(scheduler.items) == 2
    outer_successors = successors_of('outermost_mod#outermost')
    assert len(outer_successors) == 1
    assert scheduler['innermost_mod#innermost'] in outer_successors

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('full_parse', (True, False))
def test_inline_transformation_plan(tmp_path, frontend, full_parse):
    """
    Run :any:`InlineTransformation` in plan mode and check the bookkeeping:
    the dependencies each item gains/loses through inlining, which items carry
    ``plan_data``, and the source lists written to the CMake plan file.
    """
    sources = {
        'outermost_mod': """
module outermost_mod
implicit none
contains
subroutine outermost()
use intermediate_1_mod, only: intermediate1 => intermediate_1
use intermediate_2_mod, intermediate2 => intermediate_2

!$loki inline
call intermediate1()

!$loki inline
call intermediate2()

end subroutine outermost
end module outermost_mod
""",
        'intermediate_1_mod': """
module intermediate_1_mod
implicit none
contains
subroutine intermediate_1()
use innermost_1_mod, only: innermost_1
use innermost_2_mod, only: innermost_2
use innermost_3_mod, only: innermost_3

!$loki inline
call innermost_1()

call innermost_2()

!$loki inline
call innermost_3()

end subroutine intermediate_1
end module intermediate_1_mod
""",
        'intermediate_2_mod': """
module intermediate_2_mod
implicit none
contains
subroutine intermediate_2()
use innermost_1_mod, only: innermost_1

!$loki inline
call innermost_1()

call innermost_1()

end subroutine intermediate_2
end module intermediate_2_mod
""",
        'innermost_1_mod': """
module innermost_1_mod
implicit none
contains
subroutine innermost_1()
use innerinnermost_1_mod, only: innerinnermost_1

!$loki inline
call innerinnermost_1()

end subroutine innermost_1
end module innermost_1_mod
""",
        'innermost_2_mod': """
module innermost_2_mod
implicit none
contains
subroutine innermost_2()

end subroutine innermost_2
end module innermost_2_mod
""",
        'innermost_3_mod': """
module innermost_3_mod
implicit none
contains
subroutine innermost_3()

end subroutine innermost_3
end module innermost_3_mod
""",
        'innerinnermost_1_mod': """
module innerinnermost_1_mod
implicit none
contains
subroutine innerinnermost_1()
use innerinnerinnermost_1_mod, only: innerinnerinnermost_1

call innerinnerinnermost_1()

end subroutine innerinnermost_1
end module innerinnermost_1_mod
""",
        'innerinnerinnermost_1_mod': """
module innerinnerinnermost_1_mod
implicit none
contains
subroutine innerinnerinnermost_1()

end subroutine innerinnerinnermost_1
end module innerinnerinnermost_1_mod
""",
    }
    for modname, source in sources.items():
        (tmp_path/f'{modname}.F90').write_text(source)

    config = SchedulerConfig.from_dict({
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'replicate': True,
        },
        'routines': {
            'outermost': {'role': 'driver', 'replicate': False}
        }
    })
    scheduler = Scheduler(
        paths=[tmp_path], config=config,
        frontend=frontend, xmods=[tmp_path], full_parse=full_parse
    )

    pipeline = Pipeline(classes=(InlineTransformation, FileWriteTransformation))

    plan_file = tmp_path/'plan.cmake'
    scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

    # Every routine lives in a module named ``<routine>_mod``
    routine_names = (
        'outermost', 'intermediate_1', 'intermediate_2', 'innermost_1',
        'innermost_2', 'innermost_3', 'innerinnermost_1', 'innerinnerinnermost_1',
    )
    items = {name: scheduler[f'{name}_mod#{name}'] for name in routine_names}

    # Items expected to carry ``plan_data`` after planning; all others must not
    planned = {'outermost', 'innermost_1', 'innermost_2', 'innerinnerinnermost_1'}
    for name, item in items.items():
        assert hasattr(item, 'plan_data') == (name in planned)

    outermost_plan = items['outermost'].plan_data
    assert set(outermost_plan['additional_dependencies']) == {
        'innerinnerinnermost_1_mod#innerinnerinnermost_1', 'innermost_2_mod#innermost_2',
        'innermost_1_mod#innermost_1'
    }
    assert set(outermost_plan['removed_dependencies']) == {
        'intermediate_2_mod#intermediate_2', 'intermediate_1_mod#intermediate_1'
    }

    innermost_1_plan = items['innermost_1'].plan_data
    assert set(innermost_1_plan['additional_dependencies']) == {
        'innerinnerinnermost_1_mod#innerinnerinnermost_1'
    }
    assert set(innermost_1_plan['removed_dependencies']) == {
        'innerinnermost_1_mod#innerinnermost_1'
    }

    # Parse each ``set(NAME entries...)`` of the plan file into a set of file stems
    plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan_dict = {
        key: {Path(entry).stem for entry in value.split()}
        for key, value in plan_pattern.findall(plan_file.read_text())
    }

    assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {
        'innermost_2_mod', 'outermost_mod', 'innermost_1_mod', 'innerinnerinnermost_1_mod'
    }
    assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {
        'innerinnerinnermost_1_mod.idem', 'outermost_mod.idem', 'innermost_2_mod.idem',
        'innermost_1_mod.idem'
    }
    assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'outermost_mod'}
loki-ecmwf-0.3.6/loki/transformations/inline/tests/test_procedures.py0000664000175000017500000007312415167130205026343 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Module, Subroutine
from loki.jit_build import jit_compile
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, FindInlineCalls
)
from loki.types import BasicType, DerivedType

from loki.transformations.inline import (
    inline_member_procedures, inline_marked_subroutines
)
from loki.transformations.sanitise import ResolveAssociatesTransformer


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_member_routines(tmp_path, frontend):
    """
    Test inlining of member subroutines.

    The routine is JIT-compiled and executed both before and after inlining,
    and both versions must produce the same results.
    """
    fcode = """
subroutine member_routines(a, b)
  real(kind=8), intent(inout) :: a(3), b(3)
  integer :: i

  do i=1, size(a)
    call add_one(a(i))
  end do

  call add_to_a(b)

  do i=1, size(a)
    call add_one(a(i))
  end do

  contains

    subroutine add_one(a)
      real(kind=8), intent(inout) :: a
      a = a + 1
    end subroutine

    subroutine add_to_a(b)
      real(kind=8), intent(inout) :: b(:)
      integer :: n

      n = size(a)
      do i = 1, n
        a(i) = a(i) + b(i)
      end do
    end subroutine
end subroutine member_routines
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Compile and run the unmodified routine as a reference
    filepath = tmp_path/(f'ref_transform_inline_member_routines_{frontend}.f90')
    reference = jit_compile(routine, filepath=filepath, objname='member_routines')

    a = np.array([1., 2., 3.], order='F')
    b = np.array([3., 3., 3.], order='F')
    reference(a, b)

    assert (a == [6., 7., 8.]).all()
    assert (b == [3., 3., 3.]).all()

    # Now inline the member routines and check again
    inline_member_procedures(routine=routine)

    assert not routine.members
    assert not FindNodes(ir.CallStatement).visit(routine.body)
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 3
    # The local ``n`` of ``add_to_a`` is now declared in the parent routine
    assert 'n' in routine.variables

    # And verify compiled behaviour
    filepath = tmp_path/(f'transform_inline_member_routines_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='member_routines')

    a = np.array([1., 2., 3.], order='F')
    b = np.array([3., 3., 3.], order='F')
    function(a, b)

    assert (a == [6., 7., 8.]).all()
    assert (b == [3., 3., 3.]).all()


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_member_functions(tmp_path, frontend):
    """
    Test inlining of member functions.

    The routine is JIT-compiled and executed both before and after inlining,
    and both versions must produce the same results.
    """
    fcode = """
subroutine member_functions(a, b, c)
  implicit none
  real(kind=8), intent(inout) :: a(3), b(3), c(3)
  integer :: i

  do i=1, size(a)
    a(i) = add_one(a(i))
  end do

  c = add_to_a(b, 3)

  do i=1, size(a)
    a(i) = add_one(a(i))
  end do

  contains

    function add_one(a)
      real(kind=8) :: a
      real(kind=8) :: add_one
      add_one = a + 1
    end function

    function add_to_a(b, n)
      integer, intent(in) :: n
      real(kind=8), intent(in) :: b(n)
      real(kind=8) :: add_to_a(n)

      do i = 1, n
        add_to_a(i) = a(i) + b(i)
      end do
    end function
end subroutine member_functions
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Compile and run the unmodified routine as a reference
    filepath = tmp_path/(f'ref_transform_inline_member_functions_{frontend}.f90')
    reference = jit_compile(routine, filepath=filepath, objname='member_functions')

    a = np.array([1., 2., 3.], order='F')
    b = np.array([3., 3., 3.], order='F')
    c = np.array([0., 0., 0.], order='F')
    reference(a, b, c)

    assert (a == [3., 4., 5.]).all()
    assert (b == [3., 3., 3.]).all()
    assert (c == [5., 6., 7.]).all()

    # Now inline the member functions and check again
    inline_member_procedures(routine=routine)

    assert not routine.members
    assert not FindNodes(ir.CallStatement).visit(routine.body)
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 3

    # And verify compiled behaviour
    filepath = tmp_path/(f'transform_inline_member_functions_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='member_functions')

    a = np.array([1., 2., 3.], order='F')
    b = np.array([3., 3., 3.], order='F')
    c = np.array([0., 0., 0.], order='F')
    function(a, b, c)

    assert (a == [3., 4., 5.]).all()
    assert (b == [3., 3., 3.]).all()
    assert (c == [5., 6., 7.]).all()


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_member_routines_arg_dimensions(frontend):
    """
    Test inlining of member subroutines when sub-arrays of rank less
    than the formal argument are passed.
    """
    fcode = """
subroutine member_routines_arg_dimensions(matrix, tensor)
  real(kind=8), intent(inout) :: matrix(3, 3), tensor(3, 3, 4)
  integer :: i
  do i=1, 3
    call add_one(3, matrix(1:3,i), tensor(:,i,:))
  end do
  contains
    subroutine add_one(n, a, b)
      integer, intent(in) :: n
      real(kind=8), intent(inout) :: a(3), b(3,1:n)
      integer :: j
      do j=1, n
        a(j) = a(j) + 1
        b(j,:) = 66.6
      end do
    end subroutine
end subroutine member_routines_arg_dimensions
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Ensure initial member arguments
    assert len(routine.routines) == 1
    assert routine.routines[0].name == 'add_one'
    assert len(routine.routines[0].arguments) == 3
    assert routine.routines[0].arguments[0].name == 'n'
    assert routine.routines[0].arguments[1].name == 'a'
    assert routine.routines[0].arguments[2].name == 'b'

    # Now inline the member routines and check again
    inline_member_procedures(routine=routine)

    # Ensure member has been inlined and arguments adapted: the dummy
    # arguments ``a`` and ``b`` must no longer appear anywhere in the parent
    assert len(routine.routines) == 0
    for section in (routine.body, routine.spec):
        names = [v.name for v in FindVariables().visit(section)]
        assert 'a' not in names
        assert 'b' not in names

    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 2
    assert assigns[0].lhs == 'matrix(j, i)' and assigns[0].rhs =='matrix(j, i) + 1'
    assert assigns[1].lhs == 'tensor(j, i, :)'

    # Ensure the `n` in the inner loop bound has been substituted too
    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 2
    assert loops[0].bounds == '1:3'
    assert loops[1].bounds == '1:3'


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'No header information in test')]))
def test_inline_member_routines_derived_type_member(frontend):
    """
    Inline a member subroutine whose actual arguments are derived-type
    components. Without the type definition these components may carry a
    DEFERRED type, and they must still be substituted correctly, including
    when one of them is used as an array index.
    """
    fcode = """
subroutine outer(x, a)
  real, intent(inout) :: x
  type(my_type), intent(in) :: a

  ! Pass derived type arrays as arguments
  call inner(a%b(:), a%c, a%k, a%n)

contains
  subroutine inner(y, z, k, n)
    integer, intent(in) :: k, n
    real, intent(inout) :: y(n), z(:,:)
    integer :: j

    do j=1, n
      x = x + y(j)
      ! Use derived-type variable as index
      ! to test for nested substitution
      y(j) = z(k,j)
    end do
  end subroutine inner
end subroutine outer
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    var_map = routine.variable_map
    assert var_map['x'].type.dtype == BasicType.REAL
    assert isinstance(var_map['a'].type.dtype, DerivedType)

    # Only ``a%b(:)`` is known as an array; ``a%c`` and ``a%k`` are deferred
    call_args = FindNodes(ir.CallStatement).visit(routine.body)[0].arguments
    assert isinstance(call_args[0], sym.Array)
    assert isinstance(call_args[1], sym.DeferredTypeSymbol)
    assert isinstance(call_args[2], sym.DeferredTypeSymbol)

    inline_member_procedures(routine=routine)

    # Dummy arguments are replaced by the derived-type components everywhere
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 2
    accumulate, copy = assignments
    assert accumulate.rhs =='x + a%b(j)'
    assert copy.lhs == 'a%b(j)' and copy.rhs == 'a%c(a%k, j)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_member_routines_variable_shadowing(frontend):
    """
    Inline a member subroutine whose local ``x`` (an array) shadows the
    parent's scalar ``x``: the member's local must be moved to the parent
    under a prefixed name, while the dummy argument ``y`` is substituted.
    """
    fcode = """
subroutine outer()
     real :: x = 3 ! 'x' is real in outer.
     real :: y

     y = 1.0
     call inner(y)
     x = x + y

contains
    subroutine inner(y)
        real, intent(inout) :: Y
        real :: x(3) ! 'x' is array in inner.
        x = [1, 2, 3]
        y = y + sum(x)
    end subroutine inner
end subroutine outer
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def check_parent_x():
        # The parent's scalar 'x', initialised to 3
        parent_x = routine.variable_map['x']
        assert parent_x == 'x'
        assert isinstance(parent_x, sym.Scalar)
        assert parent_x.type.initial == 3

    def check_array_x(var, name):
        # A rank-1 array of extent 3 with the given name
        assert var in [f'{name}(3)', f'{name}(1:3)']
        assert isinstance(var, sym.Array)
        assert var.type.shape == (3,)

    check_parent_x()
    check_array_x(routine['inner'].variable_map['x'], 'x')

    inline_member_procedures(routine=routine)

    # The parent is untouched, the member's 'x' now lives on as 'inner_x'
    check_parent_x()
    check_array_x(routine.variable_map['inner_x'], 'inner_x')

    # The argument 'y' was substituted by the actual argument, not renamed
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert routine.variable_map['y'] == 'y'
    assert assignments[2].lhs == 'y' and assignments[2].rhs == 'y + sum(inner_x)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_internal_routines_aliasing_declaration(frontend):
    """
    Inline an internal procedure whose combined declaration ``jlon, jg``
    contains an allowed alias (``jlon``): the declaration has to be split so
    that only ``jg`` is added to the parent.
    """
    fcode = """
subroutine outer()
  integer :: z
  integer :: jlon
  z = 0
  jlon = 0

  call inner(z)

  jlon = z + 4
contains
  subroutine inner(z)
    integer, intent(inout) :: z
    integer :: jlon, jg ! These two need to get separated
    jlon = 1
    jg = 2
    z = jlon + jg
  end subroutine inner
end subroutine outer
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def assert_exact_variables(var_map, names):
        # ``var_map`` holds exactly the given names
        assert len(var_map) == len(names)
        for name in names:
            assert name in var_map

    assert_exact_variables(routine.variable_map, ('z', 'jlon'))
    assert_exact_variables(routine['inner'].variable_map, ('z', 'jlon', 'jg'))

    inline_member_procedures(routine, allowed_aliases=('jlon',))

    assert_exact_variables(routine.variable_map, ('z', 'jlon', 'jg'))

    # The three inlined statements sit between the two original leading
    # assignments and the trailing one
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 6
    expected = (('jlon', '1'), ('jg', '2'), ('z', 'jlon + jg'))
    for node, (lhs, rhs) in zip(assignments[2:5], expected):
        assert node.lhs == lhs and node.rhs == rhs

@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_member_routines_indexing_of_shadowed_array(frontend):
    """
    Inline a member routine whose local array and loop indices shadow
    variables of the parent. Not only the array but also the index
    variables used to access it must be renamed in the inlined code.
    """
    fcode = """
    subroutine outer(klon)
        integer :: jg, jlon
        integer :: arr(3, 3)

        jg = 70000
        call inner2()

        contains

        subroutine inner2()
            integer :: jlon, jg
            integer :: arr(3, 3)
            do jg=1,3
                do jlon=1,3
                   arr(jlon, jg) = 11
                end do
            end do
        end subroutine inner2

    end subroutine outer
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    inline_member_procedures(routine)

    # Loops are found outermost first, so index 1 is the inlined ``jlon`` loop
    jlon_loop = FindNodes(ir.Loop).visit(routine.body)[1]
    assert 'inner2_arr(inner2_jlon,inner2_jg)' in FindVariables().visit(jlon_loop)


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_member_routines_sequence_assoc(frontend):
    """
    Test inlining of member subroutines in the presence of sequence
    associations. As this is not supported, we check for the
    appropriate error.
    """
    fcode = """
subroutine member_routines_sequence_assoc(vector)
  real(kind=8), intent(inout) :: vector(6)
  integer :: i

  i = 2
  call inner(3, vector(i))

  contains
    subroutine inner(n, a)
      integer, intent(in) :: n
      real(kind=8), intent(inout) :: a(3)
      integer :: j
      do j=1, n
        a(j) = a(j) + 1
      end do
    end subroutine
end subroutine member_routines_sequence_assoc
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Expect inlining to fail here: the scalar element ``vector(i)`` is
    # passed to the array dummy ``a(3)`` via sequence association
    with pytest.raises(RuntimeError):
        inline_member_procedures(routine=routine)


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_member_routines_with_associate(frontend):
    """
    Inline an internal routine whose body is wrapped in an
    :any:`Associate` construct, called once with a full array and once
    with an array section inside a loop.
    """
    fcode = """
subroutine acraneb_transt(klon, klev, kidia, kfdia, ktdia)
  implicit none

  integer(kind=4), intent(in) :: klon, klev, kidia, kfdia, ktdia
  integer(kind=4) :: jlon, jlev

  real(kind=8) :: zq1(klon)
  real(kind=8) :: zq2(klon, klev)

  call delta_t(zq1)

  do jlev = ktdia, klev
    call delta_t(zq2(1:klon,jlev))

  enddo

contains

subroutine delta_t(pq)
  implicit none

  real(kind=8), intent(in) :: pq(klon)
  real(kind=8) :: x, z

  associate(zz => z)

  do jlon = 1,klon
    x = x + pq(jlon)
  enddo
  end associate
end subroutine

end subroutine acraneb_transt
    """

    routine = Subroutine.from_source(fcode, frontend=frontend)
    inline_member_procedures(routine=routine)

    assert not routine.members

    # The ``jlev`` loop plus one inlined ``jlon`` loop per call site
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 3

    # ``pq`` is replaced by the respective actual argument at each call site
    rhs_values = [node.rhs for node in FindNodes(ir.Assignment).visit(routine.body)]
    assert len(rhs_values) == 2
    assert rhs_values[0] == 'x + zq1(jlon)'
    assert rhs_values[1] == 'x + zq2(jlon, jlev)'

    # Each inlined copy keeps its own associate block
    assert len(FindNodes(ir.Associate).visit(routine.body)) == 2


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_inline_member_routines_with_optionals(frontend):
    """
    Ensure that internal routines with optional arguments get
    inlined as expected (esp. present intrinsics are correctly
    evaluated for all variables types)
    """
    fcode = """
subroutine test_inline(klon, ydxfu, ydmf_phys_out)

  use yomxfu                  , only : txfu
  use mf_phys_type_mod        , only : mf_phys_out_type

  implicit none

  integer(kind=4), intent(in) :: klon
  type(txfu)              ,intent(inout)            :: ydxfu
  type(mf_phys_out_type)  ,intent(in)               :: ydmf_phys_out

  call member_rout (ydxfu%visicld, pvmin=ydmf_phys_out%visicld, psmax=1.0_8)

  contains

  subroutine member_rout (x, pvmin, pvmax, psmin, psmax)

    real(kind=8)         ,intent(inout)            :: x(1:klon)
    real(kind=8)         ,intent(in)    ,optional  :: pvmin(1:klon)
    real(kind=8)         ,intent(in)    ,optional  :: pvmax(1:klon)
    real(kind=8)         ,intent(in)    ,optional  :: psmin
    real(kind=8)         ,intent(in)    ,optional  :: psmax

    if (present (psmin)) x = psmin
    if (present (psmax)) x = psmax
    if (present (pvmin)) x = minval(pvmin(:))
    if (present (pvmax)) x = maxval(pvmax(:))

  end subroutine member_rout

end subroutine test_inline
    """

    routine = Subroutine.from_source(fcode, frontend=frontend)

    inline_member_procedures(routine=routine)

    assert not routine.members

    # The call passes only ``pvmin`` (derived-type component) and ``psmax``
    # (literal), so each ``present(...)`` check must fold to a constant:
    # psmin -> False, psmax -> True, pvmin -> True, pvmax -> False
    conds = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conds) == 4
    for cond, expected in zip(conds, ('False', 'True', 'True', 'False')):
        assert cond.condition == expected


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('adjust_imports', [True, False])
def test_inline_marked_subroutines(frontend, adjust_imports, tmp_path):
    """
    Test subroutine inlining via marker pragmas.

    Only calls tagged with ``!$loki inline`` are replaced by the callee body;
    the untagged ``add_one`` call in the last loop must survive. With
    ``adjust_imports`` the import list is trimmed to the routines that are
    still called and the callees' explicit interfaces are brought into the
    driver.
    """

    fcode_driver = """
subroutine test_pragma_inline(a, b)
  use util_mod, only: add_one, add_a_to_b
  implicit none

  real(kind=8), intent(inout) :: a(3), b(3)
  integer, parameter :: n = 3
  integer :: i

  do i=1, n
    !$loki inline
    call add_one(a(i))
  end do

  !$loki inline
  call add_a_to_b(a(:), b(:), 3)

  do i=1, n
    call add_one(b(i))
  end do

end subroutine test_pragma_inline
    """

    fcode_module = """
module util_mod
implicit none

contains
  subroutine add_one(a)
    interface
      subroutine do_something()
      end subroutine do_something
    end interface
    real(kind=8), intent(inout) :: a
    a = a + 1
  end subroutine add_one

  subroutine add_a_to_b(a, b, n)
    interface
      subroutine do_something_else()
      end subroutine do_something_else
    end interface
    real(kind=8), intent(inout) :: a(:), b(:)
    integer, intent(in) :: n
    integer :: i

    do i = 1, n
      a(i) = a(i) + b(i)
    end do
  end subroutine add_a_to_b
end module util_mod
"""
    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, xmods=[tmp_path])
    driver.enrich(module)

    # After enrichment, each call is linked to its definition in ``util_mod``
    calls = FindNodes(ir.CallStatement).visit(driver.body)
    for idx, callee in enumerate(('add_one', 'add_a_to_b', 'add_one')):
        assert calls[idx].routine == module[callee]

    inline_marked_subroutines(
        routine=driver, allowed_aliases=('I',), adjust_imports=adjust_imports
    )

    # Two loops from the driver plus the one inlined from ``add_a_to_b``
    assert len(FindNodes(ir.Loop).visit(driver.body)) == 3
    assignments = FindNodes(ir.Assignment).visit(driver.body)
    assert len(assignments) == 2
    expected_assignments = (('a(i)', 'a(i) + 1'), ('a(i)', 'a(i) + b(i)'))
    for assignment, (lhs, rhs) in zip(assignments, expected_assignments):
        assert assignment.lhs == lhs and assignment.rhs == rhs

    # The unmarked call in the last loop is left untouched
    remaining = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(remaining) == 1
    assert remaining[0].routine.name == 'add_one'
    assert remaining[0].arguments == ('b(i)',)

    imports = FindNodes(ir.Import).visit(driver.spec)
    assert len(imports) == 1
    if not adjust_imports:
        assert imports[0].symbols == ('add_one', 'add_a_to_b')
    else:
        assert imports[0].symbols == ('add_one',)
        # The callees' explicit interfaces must have been carried over
        assert len(driver.interfaces) == 1
        assert all(isinstance(s, sym.ProcedureSymbol) for s in driver.interface_symbols)
        assert 'do_something' in driver.interface_symbols
        assert 'do_something_else' in driver.interface_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_marked_subroutines_with_interfaces(frontend, tmp_path):
    """
    Test inlining of subroutines declared via explicit interfaces.

    The driver declares its callees through interface blocks. After inlining
    the marked calls, only the interface for ``add_one`` (which is still
    called once) may remain.
    """

    fcode_driver = """
subroutine test_pragma_inline(a, b)
  implicit none

  interface
    subroutine add_a_to_b(a, b, n)
      real(kind=8), intent(inout) :: a(:), b(:)
      integer, intent(in) :: n
    end subroutine add_a_to_b
    subroutine add_one(a)
      real(kind=8), intent(inout) :: a
    end subroutine add_one
  end interface

  interface
    subroutine add_two(a)
      real(kind=8), intent(inout) :: a
    end subroutine add_two
  end interface

  real(kind=8), intent(inout) :: a(3), b(3)
  integer, parameter :: n = 3
  integer :: i

  do i=1, n
    !$loki inline
    call add_one(a(i))
  end do

  !$loki inline
  call add_a_to_b(a(:), b(:), 3)

  do i=1, n
    call add_one(b(i))
    !$loki inline
    call add_two(b(i))
  end do

end subroutine test_pragma_inline
    """

    fcode_module = """
module util_mod
implicit none

contains
  subroutine add_one(a)
    real(kind=8), intent(inout) :: a
    a = a + 1
  end subroutine add_one

  subroutine add_two(a)
    real(kind=8), intent(inout) :: a
    a = a + 2
  end subroutine add_two

  subroutine add_a_to_b(a, b, n)
    real(kind=8), intent(inout) :: a(:), b(:)
    integer, intent(in) :: n
    integer :: i

    do i = 1, n
      a(i) = a(i) + b(i)
    end do
  end subroutine add_a_to_b
end module util_mod
"""

    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, xmods=[tmp_path])
    driver.enrich(module.subroutines)

    # After enrichment, each call is linked to the matching module procedure
    calls = FindNodes(ir.CallStatement).visit(driver.body)
    for idx, callee in enumerate(('add_one', 'add_a_to_b', 'add_one', 'add_two')):
        assert calls[idx].routine == module[callee]

    inline_marked_subroutines(routine=driver, allowed_aliases=('I',))

    # Inlined loops and assignments
    assert len(FindNodes(ir.Loop).visit(driver.body)) == 3
    assignments = FindNodes(ir.Assignment).visit(driver.body)
    assert len(assignments) == 3
    expected_assignments = (
        ('a(i)', 'a(i) + 1'),
        ('a(i)', 'a(i) + b(i)'),
        ('b(i)', 'b(i) + 2'),
    )
    for assignment, (lhs, rhs) in zip(assignments, expected_assignments):
        assert assignment.lhs == lhs and assignment.rhs == rhs

    # The unmarked call to ``add_one`` is left untouched
    remaining = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(remaining) == 1
    assert remaining[0].routine.name == 'add_one'
    assert remaining[0].arguments == ('b(i)',)

    # Only the interface still required by the remaining call survives
    interfaces = FindNodes(ir.Interface).visit(driver.spec)
    assert len(interfaces) == 1
    assert interfaces[0].symbols == ('add_one',)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('adjust_imports', [True, False])
def test_inline_marked_routine_with_optionals(frontend, adjust_imports, tmp_path):
    """
    Test subroutine inlining via marker pragmas with omitted optionals.

    The first call passes the optional ``two`` and the second omits it, so
    the inlined ``present(two)`` checks must resolve to ``True`` and
    ``False`` respectively. With ``adjust_imports`` the ``use util_mod``
    import is removed; without it, the import must be kept.
    """

    fcode_driver = """
subroutine test_pragma_inline_optionals(a, b)
  use util_mod, only: add_one
  implicit none

  real(kind=8), intent(inout) :: a(3), b(3)
  integer, parameter :: n = 3
  integer :: i

  do i=1, n
    !$loki inline
    call add_one(a(i), two=2.0)
  end do

  do i=1, n
    !$loki inline
    call add_one(b(i))
  end do

end subroutine test_pragma_inline_optionals
    """

    fcode_module = """
module util_mod
implicit none

contains
  subroutine add_one(a, two)
    real(kind=8), intent(inout) :: a
    real(kind=8), optional, intent(inout) :: two
    a = a + 1

    if (present(two)) then
      a = a + two
    end if
  end subroutine add_one
end module util_mod
"""
    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, xmods=[tmp_path])
    driver.enrich(module)

    calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert calls[0].routine == module['add_one']
    assert calls[1].routine == module['add_one']

    inline_marked_subroutines(routine=driver, adjust_imports=adjust_imports)

    # Check inlined loops and assignments
    assert len(FindNodes(ir.Loop).visit(driver.body)) == 2
    assign = FindNodes(ir.Assignment).visit(driver.body)
    assert len(assign) == 4
    assert assign[0].lhs == 'a(i)' and assign[0].rhs == 'a(i) + 1'
    assert assign[1].lhs == 'a(i)' and assign[1].rhs == 'a(i) + 2.0'
    assert assign[2].lhs == 'b(i)' and assign[2].rhs == 'b(i) + 1'
    # TODO: The omitted optional ``two`` is still referenced in the (dead)
    # branch of the second inlined body, but it is no longer declared
    assert assign[3].lhs == 'b(i)' and assign[3].rhs == 'b(i) + two'

    # Check that the PRESENT checks have been resolved: ``two`` is passed
    # in the first call and omitted in the second
    assert len(FindNodes(ir.CallStatement).visit(driver.body)) == 0
    assert len(FindInlineCalls().visit(driver.body)) == 0
    checks = FindNodes(ir.Conditional).visit(driver.body)
    assert len(checks) == 2
    assert checks[0].condition == 'True'
    assert checks[1].condition == 'False'

    # The conditional expression must be parenthesised: without the
    # parentheses, ``n == 0 if c else 1`` parses as ``(n == 0) if c else 1``
    # and asserts the always-truthy constant ``1`` whenever ``c`` is false.
    imports = FindNodes(ir.Import).visit(driver.spec)
    assert len(imports) == (0 if adjust_imports else 1)


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'OMNI has no sense of humour!')])
)
def test_inline_marked_subroutines_with_associates(frontend):
    """
    Test subroutine inlining via marker pragmas with nested associates.

    Both caller and callee open two nested ``associate`` blocks. After
    inlining, the four blocks must form one perfectly nested chain, and
    resolving the associations must produce the fully-qualified argument
    of the innermost call.
    """

    fcode_outer = """
subroutine test_pragma_inline_associates(never)
  use peter_pan, only: neverland
  implicit none
  type(neverland), intent(inout) :: never

  associate(going=>never%going_to)

  associate(up=>give_you%up)

  !$loki inline
  call dave(going, up)

  end associate

  end associate
end subroutine test_pragma_inline_associates
    """

    fcode_inner = """
subroutine dave(going)
  use your_imagination, only: astley
  implicit none
  type(astley), intent(inout) :: going

  associate(give_you=>going%give_you)

  associate(up=>give_you%up)

  call rick_is(up)

  end associate

  end associate
end subroutine dave
    """

    outer = Subroutine.from_source(fcode_outer, frontend=frontend)
    inner = Subroutine.from_source(fcode_inner, frontend=frontend)
    outer.enrich(inner)

    marked_call = FindNodes(ir.CallStatement).visit(outer.body)[0]
    assert marked_call.routine == inner

    inline_marked_subroutines(routine=outer, adjust_imports=True)

    # Each associate must sit directly inside its predecessor
    assocs = FindNodes(ir.Associate).visit(outer.body)
    assert len(assocs) == 4
    for enclosing, nested in zip(assocs[:-1], assocs[1:]):
        assert nested.parent == enclosing

    # Resolving every association yields the fully-qualified argument
    outer.body = ResolveAssociatesTransformer().visit(outer.body)
    rick_call = FindNodes(ir.CallStatement).visit(outer.body)[0]
    assert rick_call.name == 'rick_is'
    assert rick_call.arguments == ('never%going_to%give_you%up',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_marked_subroutines_declarations(frontend, tmp_path):
    """
    Test that declarations hoisted from an inlined routine are rewritten in
    terms of the caller's symbols (``dims`` -> ``bnds``).
    """
    fcode = """
module inline_declarations
  implicit none

  type bounds
    integer :: start, end
  end type bounds

  contains

  subroutine outer(a, bnds)
    real(kind=8), intent(inout) :: a(bnds%end)
    type(bounds), intent(in) :: bnds
    real(kind=8) :: b(bnds%end)

    b(bnds%start:bnds%end) = a(bnds%start:bnds%end) + 42.0

    !$loki inline
    call inner(a, dims=bnds)
  end subroutine outer

  subroutine inner(c, dims)
    real(kind=8), intent(inout) :: c(dims%end)
    type(bounds), intent(in) :: dims
    real(kind=8) :: d(dims%end)

    d(dims%start:dims%end) = c(dims%start:dims%end) - 66.6
    c(dims%start) = sum(d)
  end subroutine inner
end module inline_declarations
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    outer = module['outer']

    inline_marked_subroutines(routine=outer, adjust_imports=True)

    # The hoisted local ``d`` must now be dimensioned via ``bnds``
    symbols = outer.symbols
    assert symbols[0] == 'a(bnds%end)'
    assert symbols[2] == 'b(bnds%end)'
    assert symbols[3] == 'd(bnds%end)'
    for var in symbols:
        if isinstance(var, sym.Array):
            assert var.shape == ('bnds%end',)


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'No header information in test')]
))
def test_inline_marked_subroutines_imports(frontend, tmp_path):
    """
    Test that imports needed by inlined routines are propagated to the caller
    and that imports of fully inlined routines are removed.
    """
    fcode = """
subroutine inline_routine_imports(n, a, b)
  use rick_mod, only: rick
  use dave_mod, only: dave
implicit none

  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a(n), b(n)
  integer :: i

  !$loki inline
  call rick(a)

  call rick(b)

  !$loki inline
  call dave(a)
end subroutine inline_routine_imports
"""

    fcode_rick = """
module rick_mod
  use type_mod, only: a_type
  implicit none
contains
  subroutine rick(a)
    use type_mod, only: a_type

    real(kind=8), intent(inout) :: a(:)
    type(a_type) :: my_obj

    my_obj%a = a(1)
    a(:) = my_obj%a
  end subroutine rick
end module rick_mod
"""

    fcode_dave = """
module dave_mod
  implicit none
contains
  subroutine dave(a)
    use type_mod, only: a_type, a_kind

    real(kind=8), intent(inout) :: a(:)
    type(a_type) :: my_obj
    real(kind=a_kind) :: my_number

    my_obj%a = a(1)
    my_number = real(a(1), kind=a_kind)
    a(1) = my_obj%a + my_number
  end subroutine dave
end module dave_mod
"""
    rick_mod = Module.from_source(fcode_rick, frontend=frontend, xmods=[tmp_path])
    dave_mod = Module.from_source(fcode_dave, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(
        fcode, definitions=[rick_mod, dave_mod], frontend=frontend, xmods=[tmp_path]
    )

    def check_imports(expected):
        """Assert that the spec imports exactly ``expected`` (module, symbols) pairs, in order."""
        found = FindNodes(ir.Import).visit(routine.spec)
        assert len(found) == len(expected)
        for imprt, (module_name, symbols) in zip(found, expected):
            assert imprt.module == module_name
            assert imprt.symbols == symbols

    check_imports((('rick_mod', ('rick',)), ('dave_mod', ('dave',))))

    inline_marked_subroutines(routine=routine, adjust_imports=True)

    # ``dave`` is no longer called, ``rick`` still is (once), and the
    # ``type_mod`` symbols used by the inlined bodies are now imported here
    check_imports((('type_mod', ('a_type', 'a_kind')), ('rick_mod', ('rick',))))
loki-ecmwf-0.3.6/loki/transformations/inline/tests/test_constants.py0000664000175000017500000002251615167130205026203 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Module, Subroutine
from loki.jit_build import jit_compile_lib, Builder, Obj, jit_compile, clean_test
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
from loki.expression import symbols as sym, parse_expr

from loki.transformations.inline import inline_constant_parameters
from loki.transformations.utilities import replace_selected_kind


@pytest.fixture(name='builder')
def fixture_builder(tmp_path):
    """
    Provide a :class:`Builder` whose sources and build artefacts both live in
    the test's ``tmp_path``; the ``Obj`` cache is cleared after the test.
    """
    jit_builder = Builder(source_dirs=tmp_path, build_dir=tmp_path)
    yield jit_builder
    Obj.clear_cache()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_inline_constant_parameters(tmp_path, builder, frontend):
    """
    Test correct inlining of constant parameters.

    A reference build is compared against a build in which the constants
    imported from ``parameters_mod`` have been inlined, which must also
    remove the now-unneeded ``use`` statement.
    """
    fcode_module = """
module parameters_mod
  implicit none
  integer, parameter :: a = 1
  integer, parameter :: b = -1
contains
  subroutine dummy
  end subroutine dummy
end module parameters_mod
"""

    fcode = """
module inline_const_param_mod
  ! TODO: use parameters_mod, only: b
  implicit none
  integer, parameter :: c = 1+1
contains
  subroutine inline_const_param(v1, v2, v3)
    use parameters_mod, only: a, b
    integer, intent(in) :: v1
    integer, intent(out) :: v2, v3

    v2 = v1 + b - a
    v3 = c
  end subroutine inline_const_param
end module inline_const_param_mod
"""
    # Reference build of the untransformed code in its own directory
    ref_dir = tmp_path/'orig'
    ref_dir.mkdir()
    param_module = Module.from_source(fcode_module, frontend=frontend, xmods=[ref_dir])
    module = Module.from_source(fcode, frontend=frontend, xmods=[ref_dir])
    refname = f'ref_{module.name}_{frontend}'
    reference = jit_compile_lib([module, param_module], path=ref_dir, name=refname, builder=builder)

    v2, v3 = reference.inline_const_param_mod.inline_const_param(10)
    assert v2 == 8
    assert v3 == 2

    # Re-parse in a fresh directory with ``parameters_mod`` supplied via
    # ``definitions`` so that its imported constants can be inlined
    new_dir = tmp_path/'new'
    new_dir.mkdir()
    param_module = Module.from_source(fcode_module, frontend=frontend, xmods=[new_dir])
    module = Module.from_source(fcode, definitions=param_module, frontend=frontend, xmods=[new_dir])
    assert len(FindNodes(ir.Import).visit(module['inline_const_param'].spec)) == 1
    for routine in module.subroutines:
        inline_constant_parameters(routine, external_only=True)
    assert not FindNodes(ir.Import).visit(module['inline_const_param'].spec)

    # Rename the module so the build writes a different file
    module.name = f'{module.name}_'
    obj = jit_compile_lib([module], path=new_dir, name=f'{module.name}_{frontend}', builder=builder)

    v2, v3 = obj.inline_const_param_mod_.inline_const_param(10)
    assert v2 == 8
    assert v3 == 2


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_inline_constant_parameters_kind(tmp_path, builder, frontend):
    """
    Test correct inlining of constant parameters used as kind specifiers.

    The imported kind ``jprb`` is inlined, which must remove the ``use``
    statement while keeping the computed result unchanged.
    """
    fcode_module = """
module kind_parameters_mod
  implicit none
  integer, parameter :: jprb = selected_real_kind(13, 300)
end module kind_parameters_mod
"""

    fcode = """
module inline_const_param_kind_mod
  implicit none
contains
  subroutine inline_const_param_kind(v1)
    use kind_parameters_mod, only: jprb
    real(kind=jprb), intent(out) :: v1

    v1 = real(2, kind=jprb) + 3.
  end subroutine inline_const_param_kind
end module inline_const_param_kind_mod
"""
    # Reference build of the untransformed code
    param_module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    refname = f'ref_{module.name}_{frontend}'
    reference = jit_compile_lib([module, param_module], path=tmp_path, name=refname, builder=builder)

    v1 = reference.inline_const_param_kind_mod.inline_const_param_kind()
    assert v1 == 5.
    # Remove the reference sources before the second build
    for mod in (module, param_module):
        (tmp_path/f'{mod.name}.f90').unlink()

    # Re-parse with ``kind_parameters_mod`` supplied via ``definitions`` so the
    # imported kind parameter can be inlined
    module = Module.from_source(fcode, definitions=param_module, frontend=frontend, xmods=[tmp_path])
    assert len(FindNodes(ir.Import).visit(module['inline_const_param_kind'].spec)) == 1
    for routine in module.subroutines:
        inline_constant_parameters(routine, external_only=True)
    assert not FindNodes(ir.Import).visit(module['inline_const_param_kind'].spec)

    # Rename the module so the build writes a different file
    module.name = f'{module.name}_'
    obj = jit_compile_lib([module], path=tmp_path, name=f'{module.name}_{frontend}', builder=builder)

    v1 = obj.inline_const_param_kind_mod_.inline_const_param_kind()
    assert v1 == 5.

    (tmp_path/f'{module.name}.f90').unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_inline_constant_parameters_replace_kind(tmp_path, builder, frontend):
    """
    Test inlining of an imported kind parameter combined with
    :func:`replace_selected_kind`.

    After both transformations the only remaining import must come from
    ``iso_fortran_env``, and the compiled result must be unchanged.
    """
    fcode_module = """
module replace_kind_parameters_mod
  implicit none
  integer, parameter :: jprb = selected_real_kind(13, 300)
end module replace_kind_parameters_mod
"""

    fcode = """
module inline_param_repl_kind_mod
  implicit none
contains
  subroutine inline_param_repl_kind(v1)
    use replace_kind_parameters_mod, only: jprb
    real(kind=jprb), intent(out) :: v1
    real(kind=jprb) :: a = 3._JPRB

    v1 = 1._jprb + real(2, kind=jprb) + a
  end subroutine inline_param_repl_kind
end module inline_param_repl_kind_mod
"""

    def kernel_of(lib):
        """Look up the compiled first subroutine of the current ``module`` in ``lib``."""
        return getattr(getattr(lib, module.name), module.subroutines[0].name)

    # Reference build of the untransformed code
    param_module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    refname = f'ref_{module.name}_{frontend}'
    reference = jit_compile_lib([module, param_module], path=tmp_path, name=refname, builder=builder)
    func = kernel_of(reference)

    v1 = func()
    assert v1 == 6.
    for mod in (module, param_module):
        (tmp_path/f'{mod.name}.f90').unlink()

    # Re-parse with the kind module supplied, then inline ``jprb`` and
    # replace the ``selected_real_kind`` expression
    module = Module.from_source(fcode, definitions=param_module, frontend=frontend, xmods=[tmp_path])
    imports = FindNodes(ir.Import).visit(module.subroutines[0].spec)
    assert len(imports) == 1 and imports[0].module.lower() == param_module.name.lower()
    for routine in module.subroutines:
        inline_constant_parameters(routine, external_only=True)
        replace_selected_kind(routine)
    imports = FindNodes(ir.Import).visit(module.subroutines[0].spec)
    assert len(imports) == 1 and imports[0].module.lower() == 'iso_fortran_env'

    # Rename the module so the build writes a different file
    module.name = f'{module.name}_'
    obj = jit_compile_lib([module], path=tmp_path, name=f'{module.name}_{frontend}', builder=builder)

    func = kernel_of(obj)
    v1 = func()
    assert v1 == 6.

    (tmp_path/f'{module.name}.f90').unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_constant_replacement_internal(tmp_path, frontend):
    """
    Test constant replacement for constants declared inside the routine.

    The kernel is compiled and run before and after
    :func:`inline_constant_parameters` with ``external_only=False``; both
    runs must agree, and the IR must contain literals in place of the
    local parameters.
    """
    fcode = """
subroutine kernel(a, b)
  integer, parameter :: par = 10
  integer, parameter :: par2 = 0
  integer, intent(inout) :: a, b
  real, parameter :: par_x = 1.3
  real :: x, y
  logical, parameter :: flag1 = .true.

  x = 0.0
  y = 1.4
  if (flag1 .and. par2 .eq. 0) then
    x = y + par_x
  endif

  a = b + par + ceiling(x)
end subroutine kernel
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Reference run of the untransformed kernel
    filepath = tmp_path/f'{routine.name}_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    a = np.array(1, dtype=np.int32)
    b = np.array(2, dtype=np.int32)
    function(a=a, b=b)
    assert a == 15

    inline_constant_parameters(routine=routine, external_only=False)

    # The transformed kernel must compute the same result
    transf_filepath = tmp_path/f'{routine.name}_transf_{frontend}.f90'
    transf_function = jit_compile(routine, filepath=transf_filepath, objname=routine.name)
    transf_a = np.array(1, dtype=np.int32)
    transf_b = np.array(2, dtype=np.int32)
    transf_function(a=transf_a, b=transf_b)
    assert transf_a == 15

    # Only the four non-parameter variables remain
    assert len(routine.variables) == 4
    for name in ('a', 'b', 'x', 'y'):
        assert name in routine.variables

    stmts = FindNodes(ir.Assignment).visit(routine.body)
    assert len(stmts) == 4
    assert stmts[2].rhs == 'y + 1.3'
    assert '10' in stmts[3].rhs

    # The logical and integer parameters in the condition became literals
    condition = FindNodes(ir.Conditional).visit(routine.body)[0].condition
    assert isinstance(condition, sym.LogicalAnd)
    assert condition.children[0] == parse_expr('.true.')
    assert condition.children[1] == '0 == 0'

    for path in (filepath, transf_filepath):
        clean_test(path)
loki-ecmwf-0.3.6/loki/transformations/inline/tests/test_functions.py0000664000175000017500000004104515167130205026175 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Module, Subroutine
from loki.jit_build import jit_compile_lib, Builder, Obj
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, FindInlineCalls
)
from loki.types import ProcedureType

from loki.transformations.inline import (
    inline_elemental_functions, inline_statement_functions
)


@pytest.fixture(name='builder')
def fixture_builder(tmp_path):
    """
    Provide a :class:`Builder` whose sources and build artefacts both live in
    the test's ``tmp_path``; the ``Obj`` cache is cleared after the test.
    """
    jit_builder = Builder(source_dirs=tmp_path, build_dir=tmp_path)
    yield jit_builder
    Obj.clear_cache()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('return_type_via_var', (True, False))
def test_transform_inline_elemental_functions(tmp_path, builder, frontend, return_type_via_var):
    """
    Test correct inlining of elemental functions.

    The function's return type is declared either in the function prefix or
    via a result-variable declaration, depending on ``return_type_via_var``.
    """
    # Exactly one of these two fragments is non-empty
    return_type_prefix = '' if return_type_via_var else 'real(kind=real64)'
    result_decl = 'real(kind=real64) :: multiply' if return_type_via_var else ''
    fcode_module = f"""
module multiply_mod
  use iso_fortran_env, only: real64
  implicit none
contains

  elemental {return_type_prefix} function multiply(a, b)
    {result_decl}
    real(kind=real64), intent(in) :: a, b
    real(kind=real64) :: temp
    !$loki routine seq

    ! simulate multi-line function
    temp = a * b
    multiply = temp
  end function multiply
end module multiply_mod
"""

    fcode = """
subroutine transform_inline_elemental_functions(v1, v2, v3)
  use iso_fortran_env, only: real64
  use multiply_mod, only: multiply
  real(kind=real64), intent(in) :: v1
  real(kind=real64), intent(out) :: v2, v3

  v2 = multiply(v1, 6._real64)
  v3 = 600. + multiply(6._real64, 11._real64)
end subroutine transform_inline_elemental_functions
"""

    # Reference build with genuine function calls
    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    variant = "return_var" if return_type_via_var else ""
    refname = f'ref_{routine.name}_{variant}_{frontend}'
    reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder)

    v2, v3 = reference.transform_inline_elemental_functions(11.)
    assert v2 == 66.
    assert v3 == 666.

    for name in (module.name, routine.name):
        (tmp_path/f'{name}.f90').unlink()

    # Re-parse with the module definition available and inline the calls
    routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend, xmods=[tmp_path])
    inline_elemental_functions(routine)

    # No inline calls remain, and every variable is scoped to ``routine``
    assert not FindInlineCalls().visit(routine.body)
    for var in FindVariables().visit(routine.body):
        assert var.scope is routine

    # The ``!$loki routine`` pragma of the callee must not be carried over
    assert not FindNodes(ir.Pragma).visit(routine.body)

    # Rename the routine so the build writes a different file
    routine.name = f'{routine.name}_'
    kernel = jit_compile_lib([routine], path=tmp_path, name=routine.name, builder=builder)

    v2, v3 = kernel.transform_inline_elemental_functions_(11.)
    assert v2 == 66.
    assert v3 == 666.

    builder.clean()
    (tmp_path/f'{routine.name}.f90').unlink()

@pytest.fixture(name='multiply_extended_mod', params=available_frontends())
def fixture_multiply_extended_mod(request, tmp_path):
    """
    Parse ``multiply_extended_mod`` with the frontend given by the fixture
    parameter and return the ``(module, frontend)`` pair.
    """
    frontend = request.param

    fcode_module = """
module multiply_extended_mod
  use iso_fortran_env, only: real64
  implicit none
contains

  elemental function multiply(a, b) ! result (ret_mult)
    ! real(kind=real64) :: ret_mult
    real(kind=real64) :: multiply
    real(kind=real64), intent(in) :: a, b
    real(kind=real64) :: temp

    ! simulate multi-line function
    temp = a * b
    multiply = temp
    ! ret_mult = temp
  end function multiply

  elemental function multiply_single_line(a, b)
    real(kind=real64) :: multiply_single_line
    real(kind=real64), intent(in) :: a, b
    real(kind=real64) :: temp

    multiply_single_line = a * b
  end function multiply_single_line

  elemental function add(a, b)
    real(kind=real64) :: add
    real(kind=real64), intent(in) :: a, b
    real(kind=real64) :: temp

    ! simulate multi-line function
    temp = a + b
    add = temp
  end function add
end module multiply_extended_mod
"""

    parsed = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    return parsed, frontend

def test_transform_inline_elemental_functions_extended_scalar(multiply_extended_mod, builder, tmp_path):
    """
    Test inlining of several elemental functions called with scalar
    arguments, including a ``parameter`` argument and a single-line function.
    """
    module, frontend = multiply_extended_mod

    fcode = """
subroutine transform_inline_elemental_functions_extended_scalar(v1, v2, v3)
  use iso_fortran_env, only: real64
  use multiply_extended_mod, only: multiply, multiply_single_line, add
  real(kind=real64), intent(in) :: v1
  real(kind=real64), intent(out) :: v2, v3
  real(kind=real64), parameter :: param1 = 100.

  v2 = multiply(v1, 6._real64) + multiply_single_line(v1, 3._real64)
  v3 = add(param1, 200._real64) + add(150._real64, 150._real64) + multiply(6._real64, 11._real64)
end subroutine transform_inline_elemental_functions_extended_scalar
"""

    # Reference build with genuine function calls
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=[module], xmods=[tmp_path])
    refname = f'ref_{routine.name}_{frontend}'
    reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder)
    v2, v3 = reference.transform_inline_elemental_functions_extended_scalar(11.)
    assert v2 == 99.
    assert v3 == 666.

    (tmp_path/f'{routine.name}.f90').unlink()

    # Inline the elemental functions into a freshly parsed routine
    routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend, xmods=[tmp_path])
    inline_elemental_functions(routine)
    # No inline calls remain, and every variable is scoped to ``routine``
    assert not FindInlineCalls().visit(routine.body)
    for var in FindVariables().visit(routine.body):
        assert var.scope is routine

    # Rename the routine so the build writes a different file
    routine.name = f'{routine.name}_'
    kernel = jit_compile_lib([routine, module], path=tmp_path, name=routine.name, builder=builder)
    v2, v3 = kernel.transform_inline_elemental_functions_extended_scalar_(11.)
    assert v2 == 99.
    assert v3 == 666.

    builder.clean()
    for name in (routine.name, module.name):
        (tmp_path/f'{name}.f90').unlink()

def test_transform_inline_elemental_functions_extended_arr(multiply_extended_mod, builder, tmp_path):
    """
    Test inlining of elemental functions called with array arguments.

    The kernel is built and run once with genuine calls and once after
    :func:`inline_elemental_functions`; both runs must produce the same
    arrays.
    """
    module, frontend = multiply_extended_mod

    fcode_arr = """
subroutine transform_inline_elemental_functions_extended_array(v1, v2, v3, len)
  use iso_fortran_env, only: real64
  use multiply_extended_mod, only: multiply, multiply_single_line, add
  integer, intent(in) :: len
  real(kind=real64), intent(in) :: v1(len)
  real(kind=real64), intent(inout) :: v2(len), v3(len)
  real(kind=real64), parameter :: param1 = 100.
  integer, parameter :: arr_index = 1

  v2 = multiply(v1(:), 6._real64) + multiply_single_line(v1(:), 3._real64)
  v3 = add(param1, 200._real64) + add(v1, 150._real64) + multiply(v1(arr_index), v2(1))
end subroutine transform_inline_elemental_functions_extended_array
"""

    arr_len = 5
    # v2 = 6*v1 + 3*v1; v3 = (100 + 200) + (v1 + 150) + v1(1) * v2(1)
    expected_v2 = np.array([9., 18., 27., 45., 27.], dtype=np.float64, order='F')
    expected_v3 = np.array([460., 461., 462., 464., 462.], dtype=np.float64, order='F')

    def run_and_check(func):
        """Call the compiled kernel on fresh input/output buffers and verify them."""
        v1 = np.array([1.0, 2.0, 3.0, 5.0, 3.0], dtype=np.float64, order='F')
        v2 = np.zeros((arr_len,), dtype=np.float64, order='F')
        v3 = np.zeros((arr_len,), dtype=np.float64, order='F')
        func(v1, v2, v3, arr_len)
        assert (v2 == expected_v2).all()
        assert (v3 == expected_v3).all()

    # Reference build with genuine function calls; the library name embeds
    # the frontend, consistent with the scalar variant of this test
    routine = Subroutine.from_source(fcode_arr, frontend=frontend, definitions=[module], xmods=[tmp_path])
    refname = f'ref_{routine.name}_{frontend}'
    reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder)
    run_and_check(reference.transform_inline_elemental_functions_extended_array)

    (tmp_path/f'{routine.name}.f90').unlink()

    # Inline the elemental functions into a freshly parsed routine
    routine = Subroutine.from_source(fcode_arr, definitions=module, frontend=frontend, xmods=[tmp_path])
    inline_elemental_functions(routine)
    # TODO: Also assert that no inline calls remain in the body. This is
    # currently not achievable, as calls to elemental functions with array
    # arguments can't be properly inlined and are therefore skipped.
    for var in FindVariables().visit(routine.body):
        assert var.scope is routine

    # Rename the routine so the build writes a different file
    routine.name = f'{routine.name}_'
    kernel = jit_compile_lib([routine, module], path=tmp_path, name=routine.name, builder=builder)
    run_and_check(kernel.transform_inline_elemental_functions_extended_array_)

    builder.clean()
    (tmp_path/f'{routine.name}.f90').unlink()
    (tmp_path/f'{module.name}.f90').unlink()


@pytest.mark.parametrize('frontend', available_frontends(
    skip={OMNI: "OMNI automatically inlines Statement Functions"}
))
@pytest.mark.parametrize('stmt_decls', (True, False))
def test_inline_statement_functions(frontend, stmt_decls):
    """
    Statement functions declared in the spec are replaced by their
    right-hand side. If they are only pulled in via an include file
    (declaration not visible), the calls remain in the body.
    """
    stmt_decls_code = """
    real :: PTARE
    real :: FOEDELTA
    FOEDELTA ( PTARE ) = PTARE + 1.0
    real :: FOEEW
    FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE) + EXP(PTARE)
    """.strip()
    decl_section = stmt_decls_code if stmt_decls else '#include "fcttre.func.h"'

    fcode = f"""
subroutine stmt_func(arr, ret)
    implicit none
    real, intent(in) :: arr(:)
    real, intent(inout) :: ret(:)
    real :: ret2
    real, parameter :: rtt = 1.0
    {decl_section}

    ret = foeew(arr)
    ret2 = foedelta(3.0)
end subroutine stmt_func
     """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    # Statement functions only show up in the spec when declared inline
    assert bool(FindNodes(ir.StatementFunction).visit(routine.spec)) == stmt_decls
    assert FindInlineCalls().visit(routine.body)

    inline_statement_functions(routine)

    assert not FindNodes(ir.StatementFunction).visit(routine.spec)
    remaining_calls = FindInlineCalls().visit(routine.body)
    if not stmt_decls:
        # Without visible declarations, nothing can be inlined
        assert remaining_calls
        return

    # Only the intrinsic ``exp`` call survives the inlining
    assert len(remaining_calls) == 1
    ret_assign, ret2_assign = FindNodes(ir.Assignment).visit(routine.body)[:2]
    assert ret_assign.lhs  == 'ret'
    assert ret_assign.rhs  ==  "arr + arr + 1.0 + exp(arr)"
    assert ret2_assign.lhs  == 'ret2'
    assert ret2_assign.rhs  ==  "3.0 + 1.0"

@pytest.mark.parametrize('frontend', available_frontends(
    skip={OMNI: "OMNI automatically inlines Statement Functions"}
))
@pytest.mark.parametrize('provide_myfunc', ('import', 'module', 'interface', 'intfb', 'routine'))
def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_path):
    """
    Statement functions that call another function ``MYFUNC``, whose
    definition is made available to the parser in different ways
    (module import, module definition, interface block, include file,
    routine definition).
    """
    fcode_myfunc = """
elemental function myfunc(a)
    real, intent(in) :: a
    real :: myfunc
    myfunc = a * 2.0
end function myfunc
    """.strip()

    if provide_myfunc == 'module':
        fcode_myfunc = f"""
module my_mod
implicit none
contains
{fcode_myfunc}
end module my_mod
        """.strip()

    # Source snippets that expose (or not) MYFUNC to the parsed routine
    module_import = 'use my_mod, only: myfunc' if provide_myfunc in ('import', 'module') else ''
    intf = {
        'interface': """
            interface
            elemental function myfunc(a)
                implicit none
                real a
                real myfunc
            end function myfunc
            end interface
        """,
        'intfb': '#include "myfunc.intfb.h"',
        'routine': '#include "myfunc.intfb.h"',
    }.get(provide_myfunc, '')

    fcode = f"""
subroutine stmt_func(arr, val, ret)
    {module_import}
    implicit none
    real, intent(in) :: arr(:)
    real, intent(in) :: val
    real, intent(inout) :: ret(:)
    real :: ret2
    real, parameter :: rtt = 1.0
    real :: PTARE
    real :: FOEDELTA
    FOEDELTA ( PTARE ) = PTARE + 1.0 + MYFUNC(PTARE)
    real :: FOEEW
    FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE) + MYFUNC(PTARE)
    {intf}

    ret = foeew(arr)
    ret2 = foedelta(3.0) + foedelta(val)
end subroutine stmt_func
    """.strip()

    definitions = None
    if provide_myfunc == 'module':
        definitions = (Module.from_source(fcode_myfunc, xmods=[tmp_path]),)
    elif provide_myfunc == 'routine':
        definitions = (Subroutine.from_source(fcode_myfunc, xmods=[tmp_path]),)
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=definitions, xmods=[tmp_path])

    # Check the spec
    assert len(FindNodes(ir.StatementFunction).visit(routine.spec)) == 2

    spec_calls = FindInlineCalls(unique=False).visit(routine.spec)
    if provide_myfunc not in ('module', 'interface', 'routine'):
        # No information available about MYFUNC, so fparser treats it as an ArraySubscript
        assert len(spec_calls) == 1
        assert spec_calls[0].function == 'foedelta'
        assert isinstance(spec_calls[0].function.type.dtype, ProcedureType)
    else:
        # Enough information available that MYFUNC is recognized as a procedure call
        assert len(spec_calls) == 3
        assert all(isinstance(c.function.type.dtype, ProcedureType) for c in spec_calls)

    # Check the body
    assert len(FindInlineCalls().visit(routine.body)) == 3

    # Apply the transformation
    inline_statement_functions(routine)

    # Check the outcome
    assert not FindNodes(ir.StatementFunction).visit(routine.spec)
    body_calls = FindInlineCalls(unique=False).visit(routine.body)
    ret_assign, ret2_assign = FindNodes(ir.Assignment).visit(routine.body)[:2]

    # 'import'/'intfb': MYFUNC(arr) is misclassified as array subscript
    # 'module'/'routine': MYFUNC(arr) is eliminated due to inlining
    expected_call_count = 0 if provide_myfunc in ('import', 'intfb', 'module', 'routine') else 4
    assert len(body_calls) == expected_call_count

    assert ret_assign.lhs  == 'ret'
    assert ret2_assign.lhs  == 'ret2'
    if provide_myfunc not in ('module', 'routine'):
        # myfunc not inlined
        assert ret_assign.rhs  ==  "arr + arr + 1.0 + myfunc(arr) + myfunc(arr)"
        assert ret2_assign.rhs  ==  "3.0 + 1.0 + myfunc(3.0) + val + 1.0 + myfunc(val)"
    else:
        # Fully inlined due to definition of myfunc available
        assert ret_assign.rhs  ==  "arr + arr + 1.0 + arr*2.0 + arr*2.0"
        assert ret2_assign.rhs  ==  "3.0 + 1.0 + 3.0*2.0 + val + 1.0 + val*2.0"


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_elemental_functions_intrinsic_procs(frontend):
    """
    Intrinsic calls inside an inlined internal elemental function must be
    rescoped to the calling routine after inlining.
    """
    fcode = """
subroutine test_inline_elementals(a)
implicit none
  integer, parameter :: jprb = 8
  real(kind=jprb), intent(inout) :: a

  a = fminj(0.5, a)
contains
  pure elemental function fminj(x,y) result(m)
    real(kind=jprb), intent(in) :: x, y
    real(kind=jprb) :: m

    m = y - 0.5_jprb*(abs(x-y)-(x-y))
  end function fminj
end subroutine test_inline_elementals
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    member = routine.members[0]

    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 1
    called = assignments[0].rhs.function
    assert isinstance(called, sym.ProcedureSymbol)
    assert called.type.dtype.procedure == member

    # Ensure we have an intrinsic in the internal elemental function
    member_calls = tuple(FindInlineCalls().visit(member.body))
    assert len(member_calls) == 1
    assert member_calls[0].function.type.is_intrinsic
    assert member_calls[0].function.scope == member

    inline_elemental_functions(routine)

    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 2

    # Ensure that the intrinsic function has been rescoped
    rescoped_calls = tuple(FindInlineCalls().visit(assignments[0]))
    assert len(rescoped_calls) == 1
    assert rescoped_calls[0].function.type.is_intrinsic
    assert rescoped_calls[0].function.scope == routine
loki-ecmwf-0.3.6/loki/transformations/inline/transformation.py0000664000175000017500000002101315167130205025023 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.transformations.remove_code import do_remove_dead_code

from loki.ir import pragmas_attached, CallStatement, FindNodes, is_loki_pragma
from loki.tools.util import as_tuple, CaseInsensitiveDict, OrderedSet
from loki.transformations.inline.constants import inline_constant_parameters
from loki.transformations.inline.functions import (
    inline_elemental_functions, inline_statement_functions
)
from loki.transformations.inline.procedures import (
    inline_internal_procedures, inline_marked_subroutines,
    resolve_sequence_association_for_inlined_calls
)


# Public API of this module
__all__ = ['InlineTransformation']


class InlineTransformation(Transformation):
    """
    :any:`Transformation` class to apply several types of source inlining
    when batch-processing large source trees via the :any:`Scheduler`.

    Parameters
    ----------
    inline_constants : bool
        Replace instances of variables with known constant values by
        :any:`Literal` (see :any:`inline_constant_parameters`); default: False.
    inline_elementals : bool
        Replace :any:`InlineCall` expressions to elemental functions
        with the called function's body (see :any:`inline_elemental_functions`);
        default: True.
    inline_stmt_funcs : bool
        Replace :any:`InlineCall` expressions to statement functions
        with the corresponding rhs of the statement function if
        the statement function declaration is available
        (see :any:`inline_statement_functions`); default: False.
    inline_internals : bool
        Inline internal procedures (see :any:`inline_internal_procedures`);
        default: False.
    inline_marked : bool
        Inline :any:`Subroutine` objects marked by pragma annotations
        (see :any:`inline_marked_subroutines`); default: True.
    remove_dead_code : bool
        Perform dead code elimination, where unreachable branches are
        trimmed from the code (see :any:`do_remove_dead_code`); default: True
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        List of variables that will not be renamed in the parent scope during
        internal and pragma-driven inlining.
    adjust_imports : bool
        Adjust imports by removing the symbol of the inlined routine or adding
        imports needed by the imported routine (optional, default: True)
    external_only : bool, optional
        Do not replace variables declared in the local scope when
        inlining constants (default: True)
    resolve_sequence_association : bool
        Resolve sequence association for routines that contain calls to inline
        (see :any:`resolve_sequence_association_for_inlined_calls`); default: False
    """

    # Ensure correct recursive inlining by traversing from the leaves
    reverse_traversal = True

    # Only pragma-driven inlining changes the edges in the callgraph, so this
    # flag is switched on per instance in ``__init__`` if ``inline_marked`` is set
    creates_items = False

    def __init__(
            self, inline_constants=False, inline_elementals=True,
            inline_stmt_funcs=False, inline_internals=False,
            inline_marked=True, remove_dead_code=True,
            allowed_aliases=None, adjust_imports=True,
            external_only=True, resolve_sequence_association=False
    ):
        self.inline_constants = inline_constants
        self.inline_elementals = inline_elementals
        self.inline_stmt_funcs = inline_stmt_funcs
        self.inline_internals = inline_internals
        self.inline_marked = inline_marked
        self.remove_dead_code = remove_dead_code
        self.allowed_aliases = allowed_aliases
        self.adjust_imports = adjust_imports
        self.external_only = external_only
        self.resolve_sequence_association = resolve_sequence_association
        if self.inline_marked:
            self.creates_items = True

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the enabled inlining steps to ``routine``.

        The steps run in a fixed order: sequence association resolution,
        constant parameters, elemental functions, statement functions,
        internal procedures, pragma-marked subroutines and, finally,
        dead code removal.
        """

        # Resolve sequence association in calls that are about to be inlined,
        # if requested by the user. The selection of the affected calls
        # (internal and/or pragma-marked procedures) is passed on via the
        # ``inline_internals`` and ``inline_marked`` flags.
        if self.resolve_sequence_association:
            resolve_sequence_association_for_inlined_calls(
                routine, self.inline_internals, self.inline_marked
            )

        # Replace constant parameter variables with explicit values
        if self.inline_constants:
            inline_constant_parameters(routine, external_only=self.external_only)

        # Inline elemental functions
        if self.inline_elementals:
            inline_elemental_functions(routine)

        # Inline Statement Functions
        if self.inline_stmt_funcs:
            inline_statement_functions(routine)

        # Inline internal (contained) procedures
        if self.inline_internals:
            inline_internal_procedures(routine, allowed_aliases=self.allowed_aliases)

        # Inline explicitly pragma-marked subroutines
        if self.inline_marked:
            inline_marked_subroutines(
                routine, allowed_aliases=self.allowed_aliases,
                adjust_imports=self.adjust_imports
            )

        # After inlining, attempt to trim unreachable code paths
        if self.remove_dead_code:
            do_remove_dead_code(routine)

    def plan_subroutine(self, routine, **kwargs):
        """
        Record in the item's plan data how pragma-driven inlining changes
        the dependencies of the current item (only if ``inline_marked``).

        Successors that are only ever called with a ``!$loki inline`` pragma
        are appended to ``'removed_dependencies'``. Their own successors
        (and additional dependencies), except those they removed themselves,
        are appended to ``'additional_dependencies'``.
        """

        if not self.inline_marked:
            return

        item = kwargs.get('item')
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = sub_sgraph.successors(item) if sub_sgraph is not None else ()
        successor_map = CaseInsensitiveDict(
            (successor.local_name, successor) for successor in successors
        )
        # look for call statements with pragmas attached
        with pragmas_attached(routine, node_type=CallStatement):
            inline_calls = OrderedSet()
            not_inline_calls = OrderedSet()
            calls = FindNodes(CallStatement).visit(routine.ir)
            # for all calls sort those having '!$loki inline' and those not having it
            for call in calls:
                if is_loki_pragma(call.pragma, starts_with='inline'):
                    inline_calls.add(str(call.name).lower())
                else:
                    not_inline_calls.add(str(call.name).lower())
        # Determine the list of routines that will be completely inlined and therefore no longer be dependencies
        # of the current item. If calls to the same routine remain non-inlined, the dependency remains, too.
        removed_calls = inline_calls - not_inline_calls
        # Map local (possibly renamed) imported names to their original names.
        # Imports are traversed in reverse, presumably so that the first import
        # of a name takes precedence when the dict is built.
        rename_map = CaseInsensitiveDict(
            (s.name, s.type.use_name or s.name)
            for imprt in reversed(getattr(routine, 'imports', ()))
            for s in imprt.symbols or [r[1] for r in imprt.rename_list or ()]
        )
        inline_items = []
        for call in removed_calls:
            target = rename_map.get(call, call)
            # this shouldn't be necessary, however, if for example a call is marked as to be inlined
            # within a loki remove pragma region this could otherwise end up throwing an error
            if target in successor_map:
                inline_items.append(successor_map[target])
        # Add fully inlined dependencies to the 'removed_dependencies' list in the plan data to indicate that
        # they will no longer be dependents. At the same time add any dependencies of inlined successors
        # to the current item's dependencies as these will be inherited as dependents (unless they are also
        # inlined)
        if inline_items:
            item.plan_data.setdefault('removed_dependencies', ())
            item.plan_data.setdefault('additional_dependencies', ())
            item.plan_data['removed_dependencies'] += as_tuple(inline_items)
            additional_dep = ()
            for inline_item in inline_items:
                inlined_successors = (
                    as_tuple(sub_sgraph.successors(inline_item)) +
                    as_tuple(inline_item.plan_data.get('additional_dependencies', ()))
                )
                # Loop-invariant: dependencies the inlined item drops itself
                inlined_removed = inline_item.plan_data.get('removed_dependencies', ())
                for inlined_successor in inlined_successors:
                    if inlined_successor not in inlined_removed:
                        additional_dep += (inlined_successor,)
            item.plan_data['additional_dependencies'] += as_tuple(OrderedSet(additional_dep))
loki-ecmwf-0.3.6/loki/transformations/inline/procedures.py0000664000175000017500000004622715167130205024146 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import defaultdict, ChainMap

from loki.ir import (
    Import, Comment, VariableDeclaration, CallStatement, Transformer,
    FindNodes, FindVariables, FindInlineCalls, SubstituteExpressions,
    pragmas_attached, is_loki_pragma, Interface, Pragma, AttachScopes
)
from loki.expression import symbols as sym, simplify
from loki.types import BasicType
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.logging import error
from loki.subroutine import Subroutine

from loki.transformations.sanitise import SequenceAssociationTransformer
from loki.transformations.utilities import (
    single_variable_declaration, recursive_expression_map_update
)


# Public API of this module; ``inline_member_procedures`` is an alias of
# ``inline_internal_procedures``
__all__ = [
    'inline_internal_procedures', 'inline_member_procedures',
    'inline_marked_subroutines',
    'resolve_sequence_association_for_inlined_calls'
]


def resolve_sequence_association_for_inlined_calls(routine, inline_internals, inline_marked):
    """
    Resolve sequence association in calls to all member procedures (if ``inline_internals = True``)
    or in calls to procedures that have been marked with an inline pragma (if ``inline_marked = True``).
    If both ``inline_internals`` and ``inline_marked`` are ``False``, no processing is done.

    All other calls are left untouched.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine whose body is processed (the body is replaced in-place)
    inline_internals : bool
        Process calls to internal (member) procedures of ``routine``
    inline_marked : bool
        Process calls annotated with a ``!$loki inline`` pragma

    Raises
    ------
    ValueError
        If a selected call does not have the called procedure's definition attached
    """
    class SequenceAssociationForInlineCallsTransformer(SequenceAssociationTransformer):
        """
        Restrict sequence association resolution to the calls that are about to be inlined.
        """

        def visit_CallStatement(self, call, **kwargs):
            to_be_inlined = (
                (inline_marked and is_loki_pragma(call.pragma, starts_with='inline')) or
                (inline_internals and call.routine in routine.routines)
            )
            if not to_be_inlined:
                # Calls that are not inlined keep their original arguments
                return call

            if call.routine == BasicType.DEFERRED:
                # NOTE: Throwing error here instead of continuing, because the user has explicitly
                # asked sequence assoc to happen with inlining, so source for routine should be
                # found in calls to be inlined.
                raise ValueError(
                    f"Cannot resolve sequence association for call to ``{call.name}`` " +
                    f"to be inlined in routine ``{routine.name}``, because " +
                    f"the ``CallStatement`` referring to ``{call.name}`` does not contain " +
                    "the source code of the procedure. " +
                    "If running in batch processing mode, please recheck Scheduler configuration."
                )

            return super().visit_CallStatement(call, **kwargs)

    with pragmas_attached(routine, node_type=CallStatement):
        routine.body = SequenceAssociationForInlineCallsTransformer(inplace=True).visit(routine.body)


def map_call_to_procedure_body(call, caller, callee=None):
    """
    Resolve arguments of a call and map to the called procedure body.

    Parameters
    ----------
    call : :any:`CallStatement` or :any:`InlineCall`
         Call object that defines the argument mapping
    caller : :any:`Subroutine`
         Procedure (scope) into which the callee's body gets mapped
    callee : :any:`Subroutine`, optional
         Procedure (scope) called. Provide if it differs from
         call.routine.

    Returns
    -------
    tuple
        The callee's body with all arguments substituted and symbols
        rescoped to ``caller``, wrapped in a pair of marker comments.

    Raises
    ------
    RuntimeError
        If no procedure definition is available for the callee, or if an
        array argument is passed without any free (range) dimension
        (sequence association).
    """

    def _map_unbound_dims(var, val):
        """
        Maps all unbound dimension ranges in the passed array value
        ``val`` with the indices from the local variable ``var``. It
        returns the re-mapped symbol.

        For example, mapping the passed array ``m(:,j)`` to the local
        expression ``a(i)`` yields ``m(i,j)``.

        The i-th index of ``var`` is mapped onto the i-th free (range)
        dimension of ``val``, taking into account an explicit lower bound
        of the passed range and differing declared lower bounds of the
        dummy argument and the passed array.
        """

        def _offset_lbound(local_lbound, decl_lbound, v):
            _sum = sym.Product((-1, decl_lbound))
            _sum = sym.Sum((_sum, local_lbound, v))
            return simplify(_sum)

        new_dimensions = list(val.dimensions)

        # Positions of the free (range) dimensions in ``val``: the i-th index
        # of ``var`` is mapped onto position ``indices[i]`` of ``val``
        indices = [index for index, dim in enumerate(val.dimensions) if isinstance(dim, sym.Range)]

        lbounds_diff = [sym.IntLiteral(0) for _ in var.shape]
        var_ubounds = [getattr(dim, 'upper', dim) for dim in var.shape]
        # Declared lower bounds of the dummy argument (``None`` if not given explicitly)
        var_lbounds = [getattr(dim, 'lower', sym.IntLiteral(1)) for dim in var.shape]
        # Declared lower bounds of the passed array in the dimension that receives
        # the respective index of ``var``; stays ``None`` where unknown, in which
        # case no lower bound difference is applied (``lbounds_diff`` stays 0)
        val_lbounds = [None for _ in var.shape]
        if var.shape and val.shape:
            for i, lb_var in enumerate(var_lbounds):
                pos = indices[i] if i < len(indices) else i
                if pos >= len(val.shape):
                    continue
                lb_val = getattr(val.shape[pos], 'lower', sym.IntLiteral(1))
                # we can't simply check if lb_val here as that would return a false negative if lb_val == 0
                if lb_val is not None and lb_var is not None:
                    val_lbounds[i] = lb_val
                    lbounds_diff[i] = simplify(sym.Sum((lb_val, sym.Product((lb_var, sym.IntLiteral(-1))))))

        for (index, dim), lbdiff in zip(enumerate(var.dimensions), lbounds_diff):
            pos = indices[index]
            lower = val.dimensions[pos].lower
            if lower is None:
                # The passed range starts at the declared lower bound of ``val``,
                # so only shift by the difference of the declared lower bounds
                new_dimensions[pos] = simplify(sym.Sum((dim, lbdiff)))
                continue

            # The passed range has an explicit lower bound (possibly 0, hence the
            # ``is None`` checks rather than truthiness), so offset the local index
            # relative to the dummy's lower bound, which defaults to 1
            lb_var = var_lbounds[index] if var_lbounds[index] is not None else sym.IntLiteral(1)
            # If the declared lower bound of ``val`` is unknown, ``lbdiff`` is zero
            # and the dummy's lower bound takes its place in the offset
            decl_lbound = val_lbounds[index] if val_lbounds[index] is not None else lb_var
            lower = simplify(sym.Sum((lower, lbdiff)))
            if isinstance(dim, sym.Range):
                _lower = dim.lower if dim.lower is not None else lb_var
                _upper = dim.upper if dim.upper is not None else var_ubounds[index]

                _lower = _offset_lbound(lower, decl_lbound, _lower)
                _upper = _offset_lbound(lower, decl_lbound, _upper)

                new_dimensions[pos] = sym.Range((_lower, _upper))
            else:
                new_dimensions[pos] = _offset_lbound(lower, decl_lbound, dim)

        return val.clone(dimensions=tuple(new_dimensions))

    # Get callee from the procedure type
    callee = callee or call.routine
    if callee is BasicType.DEFERRED:
        error(
            '[Loki::TransformInline] Need procedure definition to resolve '
            f'call to {call.name} from {caller}'
        )
        raise RuntimeError('Procedure definition not found! ')

    argmap = {}
    callee_vars = FindVariables().visit(callee.body)

    # Match dimension indexes between the argument and the given value
    # for all occurrences of the argument in the body
    for arg, val in call.arg_map.items():
        if isinstance(arg, sym.Array):
            # Resolve implicit dimension ranges of the passed value,
            # eg. when passing a two-dimensional array `a` as `call(arg=a)`
            # Check if val is a DeferredTypeSymbol, as it does not have a `dimensions` attribute
            if not isinstance(val, sym.DeferredTypeSymbol) and val.dimensions:
                qualified_value = val
            else:
                qualified_value = val.clone(
                    dimensions=tuple(sym.Range((None, None)) for _ in arg.shape)
                )

            # If sequence association (scalar-to-array argument passing) is used,
            # we cannot determine the right re-mapped iteration space, so we bail here!
            if not any(isinstance(d, sym.Range) for d in qualified_value.dimensions):
                error(
                    '[Loki::TransformInline] Cannot find free dimension resolving '
                    f' array argument for value "{qualified_value}"'
                )
                raise RuntimeError(
                    f'[Loki::TransformInline] Cannot resolve procedure call to {call.name}'
                )
            arg_vars = tuple(v for v in callee_vars if v.name == arg.name)
            argmap.update((v, _map_unbound_dims(v, qualified_value)) for v in arg_vars)
        else:
            argmap[arg] = val

    # Deal with PRESENT check for optional arguments: each check resolves to
    # a constant depending on whether the argument is passed in this call
    passed_names = [arg.name for arg in call.arg_map]
    present_checks = tuple(
        check for check in FindInlineCalls().visit(callee.body) if check.function == 'PRESENT'
    )
    present_map = {
        check: sym.Literal('.true.') if check.arguments[0] in passed_names
                                     else sym.Literal('.false.')
        for check in present_checks
    }
    argmap.update(present_map)

    # Recursive update of the map in case of nested variables to map
    argmap = recursive_expression_map_update(argmap, max_iterations=10)

    # Substitute argument calls into a copy of the body
    callee_body = SubstituteExpressions(argmap, rebuild_scopes=True).visit(
        callee.body.body, scope=caller
    )

    # Remove 'loki routine' pragmas
    callee_body = Transformer(
        {pragma: None for pragma in FindNodes(Pragma).visit(callee_body)
         if is_loki_pragma(pragma, starts_with='routine')}
    ).visit(callee_body)

    # Ensure all symbols are rescoped to the caller
    AttachScopes().visit(callee_body, scope=caller)

    # Inline substituted body within a pair of marker comments
    comment = Comment(f'! [Loki] inlined child subroutine: {callee.name}')
    c_line = Comment('! =========================================')
    return (comment, c_line) + as_tuple(callee_body) + (c_line, )


def inline_subroutine_calls(routine, calls, callee, allowed_aliases=None):
    """
    Inline a set of calls to an individual :any:`Subroutine` at source level.

    This will replace all given :any:`CallStatement` objects to the specified
    subroutine with an adjusted equivalent of the callee's body. For this,
    argument matching, including partial dimension matching for array
    references is performed, and all callee-specific (non-argument)
    declarations are hoisted to the containing :any:`Subroutine`.
    Local variables of the callee whose names clash with variables of
    ``routine`` are renamed with a ``<callee name>_`` prefix, unless they
    are listed in ``allowed_aliases``.

    Note that ``callee`` itself is modified in-place (renamed variables in
    its spec and body) before its body is inlined.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to inline all calls to the callee
    calls : tuple or list of :any:`CallStatement`
        The calls to be replaced by the callee's body; all of them must
        refer to ``callee``
    callee : :any:`Subroutine`
        The called target subroutine to be inlined in the parent
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        List of variables that will not be renamed in the parent scope, even
        if they alias with a local declaration.
    """
    allowed_aliases = as_tuple(allowed_aliases)

    # Ensure we process sets of calls to the same callee
    assert all(call.routine == callee for call in calls)
    assert isinstance(callee, Subroutine)

    # Prevent shadowing of callee's variables by renaming them a priori
    parent_variables = routine.variable_map
    duplicates = tuple(
        v for v in callee.variables
        if v.name in parent_variables and v.name.lower() not in callee._dummies
    )
    # Filter out allowed aliases to prevent renaming
    duplicates = tuple(v for v in duplicates if v.symbol not in allowed_aliases)
    shadow_mapper = SubstituteExpressions(
        {v: v.clone(name=f'{callee.name}_{v.name}') for v in duplicates}
    )
    callee.spec = shadow_mapper.visit(callee.spec)

    var_map = {}
    duplicate_names = {dl.name.lower() for dl in duplicates}
    for v in FindVariables(unique=False).visit(callee.body):
        if v.name.lower() in duplicate_names:
            var_map[v] = v.clone(name=f'{callee.name}_{v.name}')
    var_map = recursive_expression_map_update(var_map)
    callee.body = SubstituteExpressions(var_map).visit(callee.body)

    # Separate allowed aliases from other variables to ensure clean hoisting
    if allowed_aliases:
        single_variable_declaration(callee, variables=allowed_aliases)

    # Get local variable declarations and hoist them; the caller's variables
    # are looked up once instead of once per declared symbol
    routine_variables = routine.variables
    decls = FindNodes(VariableDeclaration).visit(callee.spec)
    decls = tuple(d for d in decls if all(s.name.lower() not in callee._dummies for s in d.symbols))
    decls = tuple(d for d in decls if all(s not in routine_variables for s in d.symbols))
    # Rescope the declaration symbols
    decls = tuple(d.clone(symbols=tuple(s.clone(scope=routine) for s in d.symbols)) for d in decls)

    # Find and apply symbol remappings for array size expressions
    symbol_map = dict(ChainMap(*[call.arg_map for call in calls]))
    decls = SubstituteExpressions(symbol_map).visit(decls)

    routine.spec.append(decls)

    # Resolve the call by mapping arguments into the called procedure's body
    call_map = {
        call: map_call_to_procedure_body(call, caller=routine) for call in calls
    }

    # Replace calls to child procedure with the child's body
    routine.body = Transformer(call_map).visit(routine.body)

    # We need this to ensure that symbols, as well as nested scopes
    # are correctly attached to each other (eg. nested associates).
    routine.rescope_symbols()


def inline_internal_procedures(routine, allowed_aliases=None):
    """
    Inline internal procedures contained in an individual :any:`Subroutine`.

    Calls to internal subroutines are replaced by the subroutine's body
    (see :any:`inline_subroutine_calls`), while internal functions are
    handed to :any:`inline_functions`. Each processed member procedure is
    afterwards removed from the ``contains`` section of ``routine``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to inline all member routines
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        List of variables that will not be renamed in the parent scope, even
        if they alias with a local declaration.
    """

    # Imported locally to avoid a cyclic import within the inline package
    from loki.transformations.inline import inline_functions  # pylint: disable=cyclic-import,import-outside-toplevel

    # Run through all members and invoke individual inlining transforms
    for member in routine.members:
        if member.is_function:
            inline_functions(routine, functions=(member,))
        else:
            member_calls = tuple(
                call for call in FindNodes(CallStatement).visit(routine.body)
                if call.routine == member
            )
            inline_subroutine_calls(routine, member_calls, member, allowed_aliases=allowed_aliases)

        # Can't use transformer to replace subroutine/function, so strip it manually
        remaining = tuple(node for node in routine.contains.body if not node == member)
        routine.contains._update(body=remaining)


# Alias of ``inline_internal_procedures``, exported under this name in ``__all__``
inline_member_procedures = inline_internal_procedures


def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True):
    """
    Inline :any:`Subroutine` objects guided by pragma annotations.

    When encountering :any:`CallStatement` objects that are marked with a
    ``!$loki inline`` pragma, this utility will attempt to replace the call
    with the body of the called procedure and remap all passed arguments
    into the calling procedures scope.

    Please note that this utility requires :any:`CallStatement` objects
    to be "enriched" with external type information.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to look for pragma-marked procedures to inline
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        List of variables that will not be renamed in the parent scope, even
        if they alias with a local declaration.
    adjust_imports : bool
        Adjust imports by removing the symbol of the inlined routine or adding
        imports needed by the imported routine (optional, default: True).
        Symbols of a callee stay imported if the callee is also called
        somewhere without the ``inline`` pragma.
    """

    with pragmas_attached(routine, node_type=CallStatement):

        # Group the calls by callee routine, separating marked from unmarked ones
        call_sets = defaultdict(list)
        no_call_sets = defaultdict(list)
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.routine == BasicType.DEFERRED:
                continue

            if is_loki_pragma(call.pragma, starts_with='inline'):
                call_sets[call.routine].append(call)
            else:
                no_call_sets[call.routine].append(call)

        # Trigger per-call inlining on collected sets
        for callee, calls in call_sets.items():
            if not callee:
                # Skip the unattached calls (collected under None); there is
                # neither a body to inline nor imports to propagate.
                continue

            inline_subroutine_calls(
                routine, calls, callee, allowed_aliases=allowed_aliases
            )

            if adjust_imports:
                # Move imports that the callee uses up to the caller
                propagate_callee_imports(routine, callee)

    # Remove imported symbols that have become obsolete
    if adjust_imports:
        # Unattached (None) keys carry no procedure symbol and were not inlined
        inlined_callees = tuple(callee for callee in call_sets if callee)
        callees = tuple(callee.procedure_symbol for callee in inlined_callees)
        not_inlined = tuple(callee.procedure_symbol for callee in no_call_sets if callee)

        import_map = {}
        for impt in FindNodes(Import).visit(routine.spec):
            # Remove interface header imports
            if any(f'{c.name.lower()}.intfb.h' == impt.module for c in callees):
                import_map[impt] = None

            if any(s.name in callees for s in impt.symbols):
                # Keep symbols that were not inlined or are still called without the pragma
                new_symbols = tuple(
                    s for s in impt.symbols if s.name not in callees or s.name in not_inlined
                )
                # Remove import if no further symbols used, otherwise clone with new symbols
                import_map[impt] = impt.clone(symbols=new_symbols) if new_symbols else None

        # Remove explicit interfaces of inlined routines
        for intf in routine.interfaces:
            if not intf.spec:
                _body = tuple(
                    s.type.dtype.procedure for s in intf.symbols
                    if s.name not in callees or s.name in not_inlined
                )
                if _body:
                    import_map[intf] = intf.clone(body=_body)
                else:
                    import_map[intf] = None

        # Finally, apply the import remapping
        routine.spec = Transformer(import_map).visit(routine.spec)

        # Add missing explicit interfaces from inlined subroutines; track what has
        # been added so an interface shared by several callees is only added once.
        new_intfs = []
        known_intf_symbols = list(routine.interface_symbols)
        for callee in inlined_callees:
            for intf in callee.interfaces:
                for s in intf.symbols:
                    if s not in known_intf_symbols:
                        new_intfs.append(s.type.dtype.procedure)
                        known_intf_symbols.append(s)

        if new_intfs:
            routine.spec.append(Interface(body=as_tuple(new_intfs)))


def propagate_callee_imports(routine, callee):
    """
    Move any :any:`Import` nodes from the :data:`callee` routine to
    the caller, trimming symbols where needed.

    Imports of modules the caller does not yet use are added (Fortran
    imports at the top of the spec, C-style headers at the bottom).
    For modules the caller already imports with an explicit symbol
    list, missing callee symbols are added in-place to that import.
    A caller import without a symbol list is left untouched, since it
    already provides the whole module.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine to which to propagate imports.
    callee : :any:`Subroutine`
        The subroutine from which to get the relevant imports.
    """

    new_imports = tuple()
    imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)

    for impt in callee.imports:
        # Add any callee module we do not yet know
        if impt.module not in imported_module_map:
            new_imports += (impt,)
            continue

        existing = imported_module_map.get(impt.module)
        if not existing or not existing.symbols:
            # The caller imports the whole module already; adding an explicit
            # symbol list here would restrict it and hide the caller's names.
            continue

        # We're importing the same module, check for missing symbols
        if not all(s in existing.symbols for s in impt.symbols):
            # Add new, rescoped symbols in-place (order-preserving de-duplication)
            new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
            existing._update(symbols=tuple(dict.fromkeys(existing.symbols + new_symbols)))

    # Add Fortran imports to the top, and C-style interface headers at the bottom
    c_imports = tuple(im for im in new_imports if im.c_import)
    f_imports = tuple(im for im in new_imports if not im.c_import)
    routine.spec.prepend(f_imports)
    routine.spec.append(c_imports)
loki-ecmwf-0.3.6/loki/transformations/inline/constants.py0000664000175000017500000000725415167130205024004 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.ir import (
    Import, Comment, Transformer, FindNodes, FindVariables,
    FindLiterals, SubstituteExpressions
)
from loki.expression import symbols as sym


__all__ = ['inline_constant_parameters']


def inline_constant_parameters(routine, external_only=True):
    """
    Replace instances of variables with known constant values by `Literals`.

    Notes
    -----
    The ``.type.initial`` property is used to derive the replacement
    value, which means for symbols imported from external modules,
    the parent :any:`Module` needs to be supplied in the
    ``definitions`` to the constructor when creating the
    :any:`Subroutine`.

    Variables that are replaced are also removed from their
    corresponding import statements. Import statements that end up
    empty are replaced by a ``! Loki: parameters from <module> inlined``
    comment, and declarations whose symbols have all been replaced by
    literals are removed.

    Parameters
    ----------
    routine : :any:`Subroutine`
         Procedure in which to inline/resolve constant parameters.
    external_only : bool, optional
        Do not replace variables declared in the local scope (default: True)
    """
    # Find all variable instances in spec and body
    variables = FindVariables().visit(routine.ir)

    # Filter out variables declared locally; evaluate the computed
    # ``routine.variables`` property once instead of once per candidate
    if external_only:
        local_variables = routine.variables
        variables = [v for v in variables if v not in local_variables]

    def is_inline_parameter(v):
        return hasattr(v, 'type') and v.type.parameter and v.type.initial is not None

    # Create mapping for variables and imports
    vmap = {v: v.type.initial for v in variables if is_inline_parameter(v)}

    # Replace kind parameters in variable types and in initializer literals
    for variable in routine.variables:
        vtype = variable.type
        updates = {}
        if is_inline_parameter(vtype.kind):
            updates['kind'] = vtype.kind.type.initial
        if vtype.initial is not None:
            # Substitute kind specifier in literals in initializers (I know...)
            init_map = {literal.kind: literal.kind.type.initial
                        for literal in FindLiterals().visit(vtype.initial)
                        if hasattr(literal, 'kind') and is_inline_parameter(literal.kind)}
            if init_map:
                updates['initial'] = SubstituteExpressions(init_map).visit(vtype.initial)
        if updates:
            # Apply kind and initializer updates in one clone so that neither
            # update can overwrite the other
            routine.symbol_attrs[variable.name] = vtype.clone(**updates)

    # Update imports
    imprtmap = {}
    substituted_names = {v.name.lower() for v in vmap}
    for imprt in FindNodes(Import).visit(routine.spec):
        if imprt.symbols:
            symbols = tuple(s for s in imprt.symbols if s.name.lower() not in substituted_names)
            if not symbols:
                imprtmap[imprt] = Comment(f'! Loki: parameters from {imprt.module} inlined')
            elif len(symbols) < len(imprt.symbols):
                imprtmap[imprt] = imprt.clone(symbols=symbols)

    # Flush mappings through spec and body
    routine.spec = Transformer(imprtmap).visit(routine.spec)
    routine.spec = SubstituteExpressions(vmap).visit(routine.spec)
    routine.body = SubstituteExpressions(vmap).visit(routine.body)

    # Clean up declarations that are about to become defunct
    decl_map = {
        decl: None for decl in routine.declarations
        if all(isinstance(s, sym._Literal) for s in decl.symbols)
    }
    routine.spec = Transformer(decl_map).visit(routine.spec)
loki-ecmwf-0.3.6/loki/transformations/inline/mapper.py0000664000175000017500000000757115167130205023256 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.expression import LokiIdentityMapper
from loki.ir import (
    FindNodes, Assignment, StatementFunction, SubstituteExpressions
)
from loki.logging import detail
from loki.types import BasicType


__all__ = ['InlineSubstitutionMapper']


class InlineSubstitutionMapper(LokiIdentityMapper):
    """
    An expression mapper that defines symbolic substitution for inlining.

    Symbols are re-scoped to the ``scope`` keyword argument (if given and
    different from the symbol's current scope). Inline calls to statement
    functions and to functions with an available body are replaced by the
    corresponding right-hand side expression.
    """

    def map_algebraic_leaf(self, expr, *args, **kwargs):
        raise NotImplementedError

    def map_scalar(self, expr, *args, **kwargs):
        parent = self.rec(expr.parent, *args, **kwargs) if expr.parent is not None else None

        scope = kwargs.get('scope') or expr.scope
        # We're re-scoping an imported symbol
        if expr.scope != scope:
            return expr.clone(scope=scope, type=expr.type.clone(), parent=parent)
        return expr.clone(parent=parent)

    map_deferred_type_symbol = map_scalar

    def map_array(self, expr, *args, **kwargs):
        if expr.dimensions:
            dimensions = self.rec(expr.dimensions, *args, **kwargs)
        else:
            dimensions = None
        parent = self.rec(expr.parent, *args, **kwargs) if expr.parent is not None else None

        scope = kwargs.get('scope') or expr.scope
        # We're re-scoping an imported symbol
        if expr.scope != scope:
            return expr.clone(scope=scope, type=expr.type.clone(), parent=parent, dimensions=dimensions)
        return expr.clone(parent=parent, dimensions=dimensions)

    def map_procedure_symbol(self, expr, *args, **kwargs):
        parent = self.rec(expr.parent, *args, **kwargs) if expr.parent is not None else None

        scope = kwargs.get('scope') or expr.scope
        # We're re-scoping an imported symbol
        if expr.scope != scope:
            return expr.clone(scope=scope, type=expr.type.clone(), parent=parent)
        return expr.clone(parent=parent)

    def map_inline_call(self, expr, *args, **kwargs):
        if expr.procedure_type in (None, BasicType.DEFERRED) or expr.procedure_type.is_intrinsic:
            # Unknown inline call, potentially an intrinsic
            # We still need to recurse and ensure re-scoping
            return super().map_inline_call(expr, *args, **kwargs)

        # if it is an inline call to a Statement Function
        if isinstance(expr.routine, StatementFunction):
            function = expr.routine
            # Substitute all arguments through the elemental body
            arg_map = dict(expr.arg_iter())
            fbody = SubstituteExpressions(arg_map).visit(function.rhs)
            return fbody

        function = expr.procedure_type.procedure

        scope = kwargs.get('scope') or expr.function.scope
        if scope and function.name in scope.interface_map:
            # Inline call to a function that is provided via an interface
            # We don't have the function body available for inlining
            detail(f'Cannot inline {expr.function.name} into {scope.name}. Only interface available.')
            return super().map_inline_call(expr, *args, **kwargs)

        # Look up the result variable by the function's result name, which
        # differs from the function name when a RESULT(...) clause is used
        v_result = [v for v in function.variables if v == function.result_name][0]

        # Substitute all arguments through the elemental body
        arg_map = dict(expr.arg_iter())
        fbody = SubstituteExpressions(arg_map).visit(function.body)

        # Extract the RHS of the final result variable assignment
        stmts = [s for s in FindNodes(Assignment).visit(fbody) if s.lhs == v_result]
        assert len(stmts) == 1
        rhs = self.rec(stmts[0].rhs, *args, **kwargs)
        return rhs
loki-ecmwf-0.3.6/loki/transformations/inline/functions.py0000664000175000017500000004025715167130205024000 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import ChainMap

from loki.logging import warning
from loki.expression import symbols as sym, ExpressionRetriever, ExpressionDimensionsMapper
from loki.ir import (
    Transformer, FindNodes, FindVariables, Import, StatementFunction,
    FindInlineCalls, ExpressionFinder, SubstituteExpressions,
    VariableDeclaration
)
from loki.subroutine import Subroutine
from loki.types import BasicType
from loki.tools import as_tuple, OrderedSet

from loki.transformations.inline.mapper import InlineSubstitutionMapper
from loki.transformations.inline.procedures import map_call_to_procedure_body
from loki.transformations.utilities import (
    single_variable_declaration, recursive_expression_map_update
)


__all__ = [
    'inline_elemental_functions', 'inline_functions',
    'inline_statement_functions', 'inline_function_calls'
]


def inline_elemental_functions(routine):
    """
    Inline every call to an elemental function in :data:`routine`,
    substituting each `InlineCall` expression with the body of the
    called function.

    Parameters
    ----------
    routine : :any:`Subroutine`
         Procedure in which to inline elemental functions.
    """
    # Thin wrapper: the generic inliner, restricted to elemental functions
    elementals_only = True
    inline_functions(routine, inline_elementals_only=elementals_only)


def inline_functions(routine, inline_elementals_only=False, functions=None):
    """
    Substitute `InlineCall` expressions to functions with the body of
    the called function.

    Each pass of :any:`_inline_functions` only handles calls that are not
    themselves arguments of another inline call; passes are repeated
    until no further calls are reported, which resolves nested calls.

    Parameters
    ----------
    routine : :any:`Subroutine`
         Procedure in which to inline functions.
    inline_elementals_only : bool, optional
        Restrict inlining to elemental functions (default: False).
    functions : tuple, optional
        Restrict inlining to the functions given here
        (default: None, i.e. all functions are inlined).
    """
    while _inline_functions(
            routine, inline_elementals_only=inline_elementals_only, functions=functions
    ):
        pass

def _inline_functions(routine, inline_elementals_only=False, functions=None):
    """
    Replaces `InlineCall` expression to functions with the
    called functions body, but doesn't include nested calls!

    Inline calls that appear as arguments of another inline call (that is
    itself inlined in this pass) are skipped; after the outer call has been
    inlined, they are picked up by the next call to this function.

    Parameters
    ----------
    routine : :any:`Subroutine`
         Procedure in which to inline functions.
    inline_elementals_only : bool, optional
        Inline elemental routines/functions only (default: False).
    functions : tuple, optional
        Inline only functions that are provided here
        (default: None, thus inline all functions).

    Returns
    -------
    bool
        Whether inline calls are (potentially) left to be
        inlined in the next call to this function.
    """

    def is_array(expr):
        """
        Check whether expr evaluates to an array.
        E.g., for arr(:, :) return True, for arr(1, 1) or arr(jl, jk) return False.
        """
        return any(d != '1' for d in ExpressionDimensionsMapper()(expr))

    class ExpressionRetrieverSkipInlineCallParameters(ExpressionRetriever):
        """
        Expression retriever skipping parameters of inline calls.
        """
        # pylint: disable=abstract-method

        def __init__(self, query, recurse_query=None, inline_elementals_only=False,
                functions=None, **kwargs):
            self.inline_elementals_only = inline_elementals_only
            self.functions = as_tuple(functions)
            super().__init__(query, recurse_query, **kwargs)

        def map_inline_call(self, expr, *args, **kwargs):
            if not self.visit(expr, *args, **kwargs):
                return
            if expr.procedure_type is not BasicType.DEFERRED and expr.procedure_type.is_elemental:
                if any(is_array(val) for val in expr.arg_map.values() if isinstance(val, sym.Array)):
                    warning(f"Call to elemental function '{expr.routine.name}' with array arguments."
                            ' There is currently no support to inline those calls!')
                    return
            self.rec(expr.function, *args, **kwargs)
            # SKIP parameters/args/kwargs on purpose, unless this call is not
            # going to be inlined in this pass (unresolved, not an elemental
            # function in elementals-only mode, or filtered out)
            if expr.procedure_type is BasicType.DEFERRED or\
                    (self.inline_elementals_only and\
                    not(expr.procedure_type.is_function and expr.procedure_type.is_elemental)) or\
                    (self.functions and expr.routine not in self.functions):
                for child in expr.parameters:
                    self.rec(child, *args, **kwargs)
                for child in list(expr.kw_parameters.values()):
                    self.rec(child, *args, **kwargs)

            self.post_visit(expr, *args, **kwargs)

    class FindInlineCallsSkipInlineCallParameters(ExpressionFinder):
        """
        Find inline calls but skip/ignore parameters of inline calls.
        """
        retriever = ExpressionRetrieverSkipInlineCallParameters(
            query=lambda e: isinstance(e, sym.InlineCall),
            inline_elementals_only=inline_elementals_only, functions=functions
        )

    # functions are provided, however functions is empty, thus early exit
    if functions is not None and not functions:
        return False
    functions = as_tuple(functions)

    # Keep track of removed symbols
    removed_functions = set()

    # Find and filter inline calls and corresponding nodes
    function_calls = {}
    # Find inline calls but skip/ignore inline calls being parameters of other inline calls
    # to ensure correct ordering of inlining. Those skipped/ignored inline calls will be handled
    # in the next call to this function.
    for node, calls in FindInlineCallsSkipInlineCallParameters(with_ir_node=True).visit(routine.body):
        for call in calls:
            if call.procedure_type is BasicType.DEFERRED or isinstance(call.routine, StatementFunction):
                continue
            if not call.procedure_type.is_function:
                continue
            if inline_elementals_only and not call.procedure_type.is_elemental:
                continue
            if functions:
                if call.routine not in functions:
                    continue
            function_calls.setdefault(str(call.name).lower(),[]).append((call, node))

    if not function_calls:
        return False

    # inline functions
    node_prepend_map = {}
    call_map = {}
    for calls_nodes in function_calls.values():
        calls, nodes = list(zip(*calls_nodes))
        for call in calls:
            removed_functions.add(call.procedure_type)
        # collect nodes to be prepended as well as expression replacement for inline call
        inline_node_map, inline_call_map = inline_function_calls(routine, as_tuple(calls),
                                                                 calls[0].routine, as_tuple(nodes))
        for node, nodes_to_prepend in inline_node_map.items():
            node_prepend_map.setdefault(node, []).extend(list(nodes_to_prepend))
        call_map.update(inline_call_map)

    # collect nodes to be prepended for each node that contains (at least one) inline call to a function
    node_map = {}
    for node, prepend_nodes in node_prepend_map.items():
        node_map[node] = as_tuple(prepend_nodes) + (SubstituteExpressions(call_map[node]).visit(node),)
    # inline via prepending the relevant functions
    routine.body = Transformer(node_map).visit(routine.body)
    # We need this to ensure that symbols, as well as nested scopes
    # are correctly attached to each other (eg. nested associates).
    routine.rescope_symbols()

    # Remove all module imports that have become obsolete now. A function keeps
    # its import as long as calls to it remain in the body (e.g. elemental calls
    # with array arguments, which are never inlined, or nested calls that are
    # only inlined in the next pass). A list is used for the remaining types
    # so that no hashing of arbitrary dtypes is required.
    remaining_types = [call.procedure_type for call in FindInlineCalls().visit(routine.body)]
    obsolete_functions = {f for f in removed_functions if f not in remaining_types}
    import_map = {}
    for im in FindNodes(Import).visit(routine.spec):
        if im.symbols and all(s.type.dtype in obsolete_functions for s in im.symbols):
            import_map[im] = None
    routine.spec = Transformer(import_map).visit(routine.spec)
    return True


def inline_statement_functions(routine):
    """
    Replaces :any:`InlineCall` expression to statement functions with the
    called statement functions rhs.

    Afterwards, the :any:`StatementFunction` definitions are removed from
    the spec, and the names of the statement functions and of their dummy
    arguments are removed from the routine's declarations (dropping a
    declaration entirely once it has no symbols left).

    NOTE(review): a dummy argument name that is also used as a regular
    variable elsewhere in the routine loses its declaration as well.

    Parameters
    ----------
    routine : :any:`Subroutine`
        Procedure in which to inline statement functions.
    """
    stmt_func_decls = FindNodes(StatementFunction).visit(routine.spec)
    exprmap = {}
    for call in FindInlineCalls().visit(routine.body):
        proc_type = call.procedure_type
        if proc_type is BasicType.DEFERRED:
            continue
        if proc_type.is_function and isinstance(call.routine, StatementFunction):
            exprmap[call] = InlineSubstitutionMapper()(call, scope=routine)
    # Apply the map to itself to handle nested statement function calls
    exprmap = recursive_expression_map_update(exprmap, max_iterations=10, mapper_cls=InlineSubstitutionMapper)
    # Apply expression-level substitution to routine
    routine.body = SubstituteExpressions(exprmap).visit(routine.body)

    # remove statement function declarations as well as statement function argument(s) declarations
    vars_to_remove = {stmt_func.variable.name.lower() for stmt_func in stmt_func_decls}
    vars_to_remove |= {arg.name.lower() for stmt_func in stmt_func_decls for arg in stmt_func.arguments}
    spec_map = {stmt_func: None for stmt_func in stmt_func_decls}
    for decl in routine.declarations:
        if any(var in vars_to_remove for var in decl.symbols):
            symbols = tuple(var for var in decl.symbols if var not in vars_to_remove)
            if symbols:
                decl._update(symbols=symbols)
            else:
                spec_map[decl] = None
    routine.spec = Transformer(spec_map).visit(routine.spec)


def _get_callee_result_var(routine):
    """
    Get or create the result variable for a function, necessary/useful
    since there are multiple ways of specifying the return variable/type.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The function for which to get the result variable.
    """
    if routine.result_name.lower() in routine.variable_map:
        callee_result_var = routine.variable_map[routine.result_name]
    else:
        callee_result_var_type = routine.symbol_attrs[routine.result_name]
        callee_result_var = sym.Variable(name=routine.result_name, type=callee_result_var_type, scope=routine)
    return callee_result_var


def inline_function_calls(routine, calls, callee, nodes, allowed_aliases=None):
    """
    Inline a set of calls to an individual :any:`Subroutine` that is a
    function, at source level.

    This will replace all :any:`InlineCall` objects to the specified
    subroutine with an adjusted equivalent of the member routines'
    body. For this, argument matching, including partial dimension
    matching for array references is performed, and all
    member-specific declarations are hoisted to the containing
    :any:`Subroutine`.

    Note that :data:`callee` itself is modified: local variables that
    clash with variables of :data:`routine` are renamed with a
    ``<callee name>_`` prefix, and ``single_variable_declaration`` is
    applied to the result variable (and to :data:`allowed_aliases`).

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to inline all calls to the member routine
    calls : tuple or list of :any:`InlineCall`
        Set of calls (to the same callee) to be inlined.
    callee : :any:`Subroutine`
        The called target function to be inlined in the parent
    nodes : :any:`Node`
        The corresponding nodes the functions are called from
        (``nodes[i]`` contains ``calls[i]``).
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        List of variables that will not be renamed in the parent scope, even
        if they alias with a local declaration.

    Returns
    -------
    node_map : dict
        Maps each node in :data:`nodes` to the list of IR nodes (the adapted
        function bodies) that need to be prepended to that node.
    call_map : dict
        Maps each node in :data:`nodes` to an expression map
        ``{call: result_variable}`` substituting each inline call by the
        variable that holds its result.
    """

    def rename_result_name(function, rename):
        # Work on a clone so each call gets its own body with a uniquely
        # named result variable, leaving ``function`` itself untouched
        renamed = function.clone()
        result_var = _get_callee_result_var(renamed)
        new_result_var = result_var.clone(name=rename)
        result_name = result_var.name.lower()
        var_map = {result_var: new_result_var}
        var_map.update({
            var: var.clone(name=rename) for var in FindVariables().visit(renamed.body)
            if var.name.lower() == result_name
        })
        var_map = recursive_expression_map_update(var_map)
        renamed.body = SubstituteExpressions(var_map).visit(renamed.body)
        return renamed, new_result_var

    allowed_aliases = as_tuple(allowed_aliases)

    # Ensure we process sets of calls to the same callee
    assert all(call.routine == callee for call in calls)
    assert isinstance(callee, Subroutine)

    # Prevent shadowing of callee's variables by renaming them a priori
    parent_variables = routine.variable_map
    duplicates = tuple(
        v for v in callee.variables
        if v.name.lower() != callee.result_name.lower()
        and v.name in parent_variables and v.name.lower() not in callee._dummies
    )
    # Filter out allowed aliases to prevent suffixing
    duplicates = tuple(v for v in duplicates if v.symbol not in allowed_aliases)
    shadow_mapper = SubstituteExpressions(
        {v: v.clone(name=f'{callee.name}_{v.name}') for v in duplicates}
    )
    callee.spec = shadow_mapper.visit(callee.spec)

    var_map = {}
    duplicate_names = {dl.name.lower() for dl in duplicates}
    for v in FindVariables(unique=False).visit(callee.body):
        if v.name.lower() in duplicate_names:
            var_map[v] = v.clone(name=f'{callee.name}_{v.name}')

    var_map = recursive_expression_map_update(var_map)
    callee.body = SubstituteExpressions(var_map).visit(callee.body)

    # Separate allowed aliases from other variables to ensure clean hoisting
    if allowed_aliases:
        single_variable_declaration(callee, variables=allowed_aliases)

    single_variable_declaration(callee, variables=callee.result_name)
    # Get local variable declarations and hoist them (excluding the result
    # variable, dummy arguments and variables the caller already has)
    result_name = callee.result_name.lower()
    decls = FindNodes(VariableDeclaration).visit(callee.spec)
    decls = tuple(d for d in decls if all(s.name.lower() != result_name for s in d.symbols))
    decls = tuple(d for d in decls if all(s.name.lower() not in callee._dummies for s in d.symbols))
    decls = tuple(d for d in decls if all(s not in routine.variables for s in d.symbols))
    # Rescope the declaration symbols
    decls = tuple(d.clone(symbols=tuple(s.clone(scope=routine) for s in d.symbols)) for d in decls)

    # Find and apply symbol remappings for array size expressions
    symbol_map = dict(ChainMap(*[call.arg_map for call in calls]))
    decls = SubstituteExpressions(symbol_map).visit(decls)
    routine.spec.append(decls)

    # Handle result/return var/value
    new_symbols = OrderedSet()
    result_var_map = {}
    adapted_calls = []
    # Several calls within the same node need distinct result variables
    rename_result_var = len(set(nodes)) != len(nodes)
    callee_result_var = _get_callee_result_var(callee)
    result_base_name = f'result_{callee.result_name.lower()}'
    for i_call, call in enumerate(calls):
        new_callee_result_var_name = f'{result_base_name}_{i_call}' if rename_result_var else result_base_name
        new_callee, new_symbol = rename_result_name(callee, new_callee_result_var_name)
        adapted_calls.append(new_callee)
        new_symbols.add(new_symbol)
        if isinstance(callee_result_var, sym.Array):
            # The call is replaced by the whole (un-indexed) result array
            result_var = callee_result_var.clone(name=new_callee_result_var_name, dimensions=None)
        else:
            result_var = callee_result_var.clone(name=new_callee_result_var_name)
        result_var_map[(nodes[i_call], call)] = result_var
    new_symbols = SubstituteExpressions(symbol_map).visit(as_tuple(new_symbols), recurse_to_declaration_attributes=True)
    routine.variables += as_tuple([symbol.clone(scope=routine) for symbol in new_symbols])

    # create node map to map nodes to be prepended (representing the functions) for each node
    node_map = {}
    call_map = {}
    for i_call, call in enumerate(calls):
        node = nodes[i_call]
        node_map.setdefault(node, []).extend(
                list(map_call_to_procedure_body(call, caller=routine, callee=adapted_calls[i_call]))
        )
        call_map.setdefault(node, {}).update({call: result_var_map[(node, call)]})
    return node_map, call_map
loki-ecmwf-0.3.6/loki/transformations/transpile/0000775000175000017500000000000015167130205022131 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/transpile/__init__.py0000664000175000017500000000110415167130205024236 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.transformations.transpile.fortran_c import * # noqa
from loki.transformations.transpile.fortran_iso_c_wrapper import * # noqa
from loki.transformations.transpile.fortran_python import * # noqa
loki-ecmwf-0.3.6/loki/transformations/transpile/tests/0000775000175000017500000000000015167130205023273 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/transpile/tests/__init__.py0000664000175000017500000000057015167130205025406 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/transpile/tests/test_scc_cuda.py0000664000175000017500000002565315167130205026463 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki import Scheduler, Dimension, read_file
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes

from loki.transformations.transpile import (
    FortranCTransformation, FortranISOCWrapperTransformation
)
from loki.transformations.single_column import (
    SCCLowLevelHoist, SCCLowLevelParametrise
)


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """
    Horizontal dimension: size ``nlon``, loop index ``jl`` and
    loop bounds ``start``/``iend``.
    """
    horizontal_dim = Dimension(
        name='horizontal', size='nlon', index='jl', bounds=('start', 'iend')
    )
    return horizontal_dim


@pytest.fixture(scope='module', name='vertical')
def fixture_vertical():
    """
    Vertical dimension: size ``nz`` with loop index ``jk``.
    """
    vertical_dim = Dimension(name='vertical', size='nz', index='jk')
    return vertical_dim


@pytest.fixture(scope='module', name='blocking')
def fixture_blocking():
    """
    Blocking dimension: size ``nb`` with loop index ``b``.
    """
    blocking_dim = Dimension(name='blocking', size='nb', index='b')
    return blocking_dim


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """
    Directory that contains this test module.
    """
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(name='config')
def fixture_config():
    """
    Default scheduler configuration for the SCC CUDA tests.

    Every routine defaults to the ``kernel`` role in ``idem`` mode with
    ``expand`` enabled. ``strict`` is disabled because of the ``cudafor``
    import. Only the ``driver`` routine gets the ``driver`` role.
    """
    default_options = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': False,  # cudafor import
    }
    routine_options = {
        'driver': {'role': 'driver'}
    }
    return {'default': default_options, 'routines': routine_options}

def remove_whitespace_linebreaks(text):
    """
    Normalise generated source text for substring assertions.

    Spaces, tabs and carriage returns are dropped, each newline becomes a
    single space, and the result is lower-cased. Keeping newlines as
    separators presumably stops tokens on adjacent lines from being glued
    together.
    """
    # str.translate works in a single pass, so the spaces produced from
    # newlines are not removed again. This matches replacing ' ' before '\n'.
    table = str.maketrans({' ': None, '\t': None, '\r': None, '\n': ' '})
    return text.translate(table).lower()

@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_cuda_parametrise(tmp_path, here, frontend, config, horizontal, vertical, blocking):
    """
    Test the SCC low-level CUDA pipeline with parametrisation of array
    dimensions, followed by C/CUDA transpilation and ISO-C wrapper generation.

    ``dic2p`` maps ``nz`` to 137, so ``nz`` does not appear as an argument of
    the generated ``kernel_iso_c`` call.
    """

    proj = here / '../../tests/sources/projSccCuf/module'

    scheduler = Scheduler(
        paths=[proj], config=config, seed_routines=['driver'],
        output_dir=tmp_path, frontend=frontend, xmods=[tmp_path]
    )

    dic2p = {'nz': 137}
    cuda_transform = SCCLowLevelParametrise(
        horizontal=horizontal, vertical=vertical, block_dim=blocking,
        transformation_type='parametrise',
        dim_vars=(vertical.size,), as_kwarguments=True, remove_vector_section=True,
        use_c_ptr=True, dic2p=dic2p, path=here, mode='cuda'
    )
    scheduler.process(transformation=cuda_transform)
    f2c_transformation = FortranCTransformation(language='cuda')
    scheduler.process(transformation=f2c_transformation)
    f2cwrap = FortranISOCWrapperTransformation(language='cuda', use_c_ptr=True)
    scheduler.process(transformation=f2cwrap)

    kernel = scheduler['kernel_mod#kernel'].ir
    kernel_variable_map = kernel.variable_map
    assert kernel_variable_map[horizontal.index].type.intent is None
    assert kernel_variable_map[horizontal.index].scope == kernel
    device = scheduler['kernel_mod#device'].ir
    device_variable_map = device.variable_map
    assert device_variable_map[horizontal.index].type.intent.lower() == 'in'
    assert device_variable_map[horizontal.index].scope == device

    fc_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_fc.F90'))
    c_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.c'))
    c_kernel_header = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.h'))
    c_kernel_launch = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c_launch.h'))
    c_device = remove_whitespace_linebreaks(read_file(tmp_path/'device_c.c'))
    c_elemental_device = remove_whitespace_linebreaks(read_file(tmp_path/'elemental_device_c.c'))
    c_some_func = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.c'))
    c_some_func_header = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.h'))

    calls = FindNodes(ir.CallStatement).visit(scheduler["driver_mod#driver"].ir.body)
    assert len(calls) == 3
    for call in calls:
        assert str(call.name).lower() == 'kernel'
        assert call.pragma[0].keyword == 'loki'
        assert 'removed_loop' in call.pragma[0].content
    # kernel_fc.F90
    assert '!$acchost_datause_device(q,t,z)' in fc_kernel
    assert 'kernel_iso_c(start,nlon,c_loc(q),c_loc(t),c_loc(z),nb,tot,iend)' in fc_kernel
    assert 'bind(c,name="kernel_c_launch")' in fc_kernel
    assert 'useiso_c_binding' in fc_kernel
    # kernel_c.c
    assert '#include' in c_kernel
    assert '#include' in c_kernel
    assert '#include"kernel_c.h"' in c_kernel
    assert '#include"kernel_c_launch.h"' in c_kernel
    assert '#include"elemental_device_c.h"' in c_kernel
    assert '#include"device_c.h"' in c_kernel
    assert 'include"some_func_c.h"' in c_kernel
    assert '__global__voidkernel_c' in c_kernel
    assert 'jl=threadidx.x;' in c_kernel
    assert 'b=blockidx.x;' in c_kernel
    assert 'device_c(' in c_kernel
    assert 'elemental_device_c(' in c_kernel
    assert '=some_func_c(' in c_kernel
    # kernel_c.h
    assert '__global__voidkernel_c' in c_kernel_header
    assert 'jl=threadidx.x;' not in c_kernel_header
    assert 'b=blockidx.x;' not in c_kernel_header
    # kernel_c_launch.h
    assert 'extern"c"' in c_kernel_launch
    assert 'voidkernel_c_launch(' in c_kernel_launch
    assert 'structdim3blockdim;' in c_kernel_launch
    assert 'structdim3griddim;' in c_kernel_launch
    assert 'griddim=dim3(' in c_kernel_launch
    assert 'blockdim=dim3(' in c_kernel_launch
    assert 'kernel_c<<>>(' in c_kernel_launch
    assert 'cudadevicesynchronize();' in c_kernel_launch
    # device_c.c
    assert '#include' in c_device
    assert '#include' in c_device
    assert '#include"device_c.h"' in c_device
    assert '__device__voiddevice_c(' in c_device
    # elemental_device_c.c
    assert '#include' in c_elemental_device
    assert '#include' in c_elemental_device
    assert '#include"elemental_device_c.h"' in c_elemental_device
    # some_func_c.c
    assert 'doublesome_func_c(doublea)' in c_some_func
    assert 'returnsome_func' in c_some_func
    # some_func_c.h
    assert 'doublesome_func_c(doublea);' in c_some_func_header


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_cuda_hoist(tmp_path, here, frontend, config, horizontal, vertical, blocking):
    """
    Test the SCC low-level CUDA pipeline with hoisting of local arrays,
    followed by C/CUDA transpilation and ISO-C wrapper generation.

    Unlike the parametrise variant, ``nz`` remains an argument of the
    generated ``kernel_iso_c`` call. The hoisted arrays ``local_z`` and
    ``device_local_x`` appear as additional device-pointer arguments.
    """

    proj = here / '../../tests/sources/projSccCuf/module'

    scheduler = Scheduler(
        paths=[proj], config=config, seed_routines=['driver'],
        output_dir=tmp_path, frontend=frontend, xmods=[tmp_path]
    )

    # NOTE(review): transformation_type='parametrise' is also passed to the
    # hoist variant. Confirm whether SCCLowLevelHoist uses it.
    cuda_transform = SCCLowLevelHoist(
        horizontal=horizontal, vertical=vertical, block_dim=blocking,
        transformation_type='parametrise',
        dim_vars=(vertical.size,), as_kwarguments=True, remove_vector_section=True,
        use_c_ptr=True, path=here, mode='cuda'
    )
    scheduler.process(transformation=cuda_transform)
    f2c_transformation = FortranCTransformation(language='cuda')
    scheduler.process(transformation=f2c_transformation)
    f2cwrap = FortranISOCWrapperTransformation(language='cuda', use_c_ptr=True)
    scheduler.process(transformation=f2cwrap)

    kernel = scheduler['kernel_mod#kernel'].ir
    kernel_variable_map = kernel.variable_map
    assert kernel_variable_map[horizontal.index].type.intent is None
    assert kernel_variable_map[horizontal.index].scope == kernel
    device = scheduler['kernel_mod#device'].ir
    device_variable_map = device.variable_map
    assert device_variable_map[horizontal.index].type.intent.lower() == 'in'
    assert device_variable_map[horizontal.index].scope == device

    fc_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_fc.F90'))
    c_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.c'))
    c_kernel_header = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.h'))
    c_kernel_launch = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c_launch.h'))
    c_device = remove_whitespace_linebreaks(read_file(tmp_path/'device_c.c'))
    c_elemental_device = remove_whitespace_linebreaks(read_file(tmp_path/'elemental_device_c.c'))
    c_some_func = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.c'))
    c_some_func_header = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.h'))

    calls = FindNodes(ir.CallStatement).visit(scheduler["driver_mod#driver"].ir.body)
    assert len(calls) == 3
    for call in calls:
        assert str(call.name).lower() == 'kernel'
        assert call.pragma[0].keyword == 'loki'
        assert 'removed_loop' in call.pragma[0].content
    # kernel_fc.F90
    assert '!$acchost_datause_device(q,t,z,local_z,device_local_x)' in fc_kernel
    assert 'kernel_iso_c(start,nlon,nz,c_loc(q),c_loc(t),c_loc(z)' in fc_kernel
    assert 'c_loc(z),nb,tot,iend,c_loc(local_z),c_loc(device_local_x))' in fc_kernel
    assert 'bind(c,name="kernel_c_launch")' in fc_kernel
    assert 'useiso_c_binding' in fc_kernel
    # kernel_c.c
    assert '#include' in c_kernel
    assert '#include' in c_kernel
    assert '#include"kernel_c.h"' in c_kernel
    assert '#include"kernel_c_launch.h"' in c_kernel
    assert '#include"elemental_device_c.h"' in c_kernel
    assert '#include"device_c.h"' in c_kernel
    assert 'include"some_func_c.h"' in c_kernel
    assert '__global__voidkernel_c' in c_kernel
    assert 'jl=threadidx.x;' in c_kernel
    assert 'b=blockidx.x;' in c_kernel
    assert 'device_c(' in c_kernel
    assert 'elemental_device_c(' in c_kernel
    assert '=some_func_c(' in c_kernel
    # kernel_c.h
    assert '__global__voidkernel_c' in c_kernel_header
    assert 'jl=threadidx.x;' not in c_kernel_header
    assert 'b=blockidx.x;' not in c_kernel_header
    # kernel_c_launch.h
    assert 'extern"c"' in c_kernel_launch
    assert 'voidkernel_c_launch(' in c_kernel_launch
    assert 'structdim3blockdim;' in c_kernel_launch
    assert 'structdim3griddim;' in c_kernel_launch
    assert 'griddim=dim3(' in c_kernel_launch
    assert 'blockdim=dim3(' in c_kernel_launch
    assert 'kernel_c<<>>(' in c_kernel_launch
    assert 'cudadevicesynchronize();' in c_kernel_launch
    # device_c.c
    assert '#include' in c_device
    assert '#include' in c_device
    assert '#include"device_c.h"' in c_device
    assert '__device__voiddevice_c(' in c_device
    # elemental_device_c.c
    assert '#include' in c_elemental_device
    assert '#include' in c_elemental_device
    assert '#include"elemental_device_c.h"' in c_elemental_device
    # some_func_c.c
    assert 'doublesome_func_c(doublea)' in c_some_func
    assert 'returnsome_func' in c_some_func
    # some_func_c.h
    assert 'doublesome_func_c(doublea);' in c_some_func_header
loki-ecmwf-0.3.6/loki/transformations/transpile/tests/test_sdfg.py0000664000175000017500000002727315167130205025642 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import importlib
import importlib.util
import itertools
from pathlib import Path
import numpy as np
import pytest

from loki import Subroutine
from loki.jit_build import jit_compile
from loki.frontend import available_frontends

from loki.transformations.transpile import FortranPythonTransformation


pytestmark = [
    # Skip tests if dace is not installed. ``importlib.util`` must be
    # imported explicitly, because ``import importlib`` alone does not
    # guarantee that the submodule is loaded.
    pytest.mark.skipif(
        importlib.util.find_spec('dace') is None,
        reason='DaCe is not installed'
    ),
    # Disable warnings from Dace about np.bool being deprecated
    pytest.mark.filterwarnings(
        "ignore:`np.bool` is a deprecated alias:DeprecationWarning"
    )
]


def load_module(path):
    """
    Import the module whose name is the file stem of ``path``.

    Only the stem is used. The module is presumably already importable,
    i.e. its directory is on ``sys.path``; verify against the transformation
    that writes it. If the first import fails with ``ModuleNotFoundError``,
    the import-system caches are invalidated and the import is retried once.
    """
    module_name = Path(path).stem
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        # A freshly written file may be hidden by stale finder caches
        importlib.invalidate_caches()
    return importlib.import_module(module_name)


def create_sdfg(routine, tmp_path):
    """
    Transpile ``routine`` to DaCe-enabled Python in ``tmp_path``, import the
    generated module and return the SDFG of the function named after the
    routine.
    """
    transformation = FortranPythonTransformation(with_dace=True, suffix='_py')
    routine.apply(transformation, path=tmp_path)

    generated_module = load_module(transformation.py_path)
    python_function = getattr(generated_module, routine.name)
    return python_function.to_sdfg()


@pytest.mark.parametrize('frontend', available_frontends())
def test_sdfg_routine_copy(tmp_path, frontend):
    """
    Copy one array into another, first with the compiled Fortran reference
    and then through the generated DaCe SDFG.
    """

    fcode = """
subroutine routine_copy(n, x, y)
  ! A simple routine that copies the values of x to y
  use iso_fortran_env, only: real64
  implicit none
  real(kind=real64), intent(in) :: x(n)
  real(kind=real64), intent(out) :: y(n)
  integer, intent(in) :: n
  integer :: i

  do i=1,n
    y(i) = x(i)
  enddo
end subroutine routine_copy
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Test the reference solution
    filepath = tmp_path/(f'routine_copy_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='routine_copy')

    n = 64
    x_ref = np.array(range(n), dtype=np.float64)
    x = np.zeros(n, dtype=np.float64)
    x[:] = x_ref[:]
    y = np.zeros(n, dtype=np.float64)
    function(n=n, x=x, y=y)
    assert all(x_ref == y)

    # Create and compile the SDFG
    sdfg = create_sdfg(routine, tmp_path)
    assert sdfg.validate() is None

    csdfg = sdfg.compile()
    assert csdfg

    # Run the SDFG with a fresh output buffer. Otherwise the values left in
    # ``y`` by the reference run would make the check below pass vacuously.
    x[:] = x_ref[:]
    y = np.zeros(n, dtype=np.float64)
    csdfg(n=np.int32(n), x=x, y=y)
    assert all(x_ref == y)


@pytest.mark.xfail(reason='Scalar inout arguments do not work in dace')
@pytest.mark.filterwarnings('ignore:The value of the smallest subnormal.*class \'numpy.float64\':UserWarning')
@pytest.mark.parametrize('frontend', available_frontends())
def test_sdfg_routine_axpy_scalar(tmp_path, frontend):
    """
    Scalar ``x = a * x + y`` with an ``inout`` scalar argument, checked for
    the Fortran reference and the generated SDFG. The SDFG part is expected
    to fail, see the ``xfail`` marker.
    """

    fcode = """
subroutine routine_axpy_scalar(a, x, y)
  ! A simple standard routine that computes x = a * x + y for
  ! scalar arguments
  use iso_fortran_env, only: real64
  implicit none
  real(kind=real64), intent(in) :: a, y
  real(kind=real64), intent(inout) :: x

  x = a * x + y
end subroutine routine_axpy_scalar
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Test the reference solution
    filepath = tmp_path/(f'sdfg_routine_axpy_scalar_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='routine_axpy_scalar')

    a = np.float64(23)
    x = np.float64(42)
    y = np.float64(5)

    # The inout scalar is passed as a one-element array so it can be updated
    inout_x = np.array([x], dtype=np.float64)
    function(a=a, x=inout_x, y=y)
    assert inout_x == a * x + y

    # Create and compile the SDFG
    sdfg = create_sdfg(routine, tmp_path)
    assert sdfg.validate() is None

    compiled = sdfg.compile()
    assert compiled

    # Run the SDFG on a fresh copy of the input
    inout_x = np.array([x], dtype=np.float64)
    compiled(a=a, x=inout_x, y=y)
    assert inout_x == a * x + y


@pytest.mark.parametrize('frontend', available_frontends())
def test_sdfg_routine_copy_stream(tmp_path, frontend):
    """
    Add ``alpha(1)`` to every entry of an integer vector inside a dataflow
    loop, checked for the Fortran reference and the generated SDFG.
    """

    fcode = """
subroutine routine_copy_stream(length, alpha, vector_in, vector_out)
  implicit none
  ! A simple standard looking routine to test argument declarations
  ! and generator toolchain
  integer, intent(in) :: length, alpha(1), vector_in(length)
  integer, intent(out) :: vector_out(length)
  integer :: i

  !$loki dataflow
  do i=1, length
    vector_out(i) = vector_in(i) + alpha(1)
  end do
end subroutine routine_copy_stream
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    # TODO: make alpha a true scalar, which doesn't seem to work with SDFG at the moment???

    # Test the reference solution
    filepath = tmp_path/(f'sdfg_routine_copy_stream_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='routine_copy_stream')

    length = 32
    alpha = np.array([7], dtype=np.int32)
    ref_in = np.array(range(length), order='F', dtype=np.int32)
    ref_out = np.zeros(length, order='F', dtype=np.int32)
    function(length=length, alpha=alpha, vector_in=ref_in, vector_out=ref_out)
    assert np.all(ref_out == np.array(range(length)) + alpha)

    # Create and compile the SDFG
    sdfg = create_sdfg(routine, tmp_path)
    assert sdfg.validate() is None

    compiled = sdfg.compile()
    assert compiled

    # Run the SDFG on fresh buffers of C ``int`` type
    sdfg_in = np.array(range(length), order='F', dtype=np.intc)
    sdfg_out = np.zeros(length, order='F', dtype=np.intc)
    compiled(length=length, alpha=alpha, vector_in=sdfg_in, vector_out=sdfg_out)
    assert np.all(sdfg_out == np.array(range(length)) + alpha)


@pytest.mark.parametrize('frontend', available_frontends())
def test_sdfg_routine_fixed_loop(tmp_path, frontend):
    """
    Nested dataflow loops with literal bounds, checked for the Fortran
    reference and the generated SDFG. The loops add the first tensor column
    (plus one) to a vector and write the transposed tensor.

    ``vector_out`` gets its own buffer in both runs. Passing the
    ``intent(in)`` ``vector`` as ``vector_out`` as well would alias two dummy
    arguments, one of which is modified, and that is not standard-conforming
    Fortran.
    """

    fcode = """
subroutine routine_fixed_loop(scalar, vector, vector_out, tensor, tensor_out)
  use iso_fortran_env, only: real64
  implicit none
  ! integer :: n=6, m=4
  real(kind=real64), intent(in) :: scalar(1)
  real(kind=real64), intent(in) :: tensor(6, 4), vector(6)
  real(kind=real64), intent(out) :: tensor_out(4, 6), vector_out(6)
  integer :: i, j

  ! For testing, the operation is:
  !$loki dataflow
  do j=1, 6
     vector_out(j) = vector(j) + tensor(j, 1) + 1.0
     !$loki dataflow
     do i=1, 4
        tensor_out(i, j) = tensor(j, i)
     end do
  end do
end subroutine routine_fixed_loop
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'sdfg_routine_fixed_loop_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='routine_fixed_loop')

    # Test the reference solution
    n, m = 6, 4
    scalar = np.array([2.0], dtype=np.float64)
    vector = np.zeros(shape=(n,), order='F') + 3.
    vector_out = np.zeros(shape=(n,), order='F')
    tensor = np.array([list(range(i, i+m)) for i in range(n)], order='F', dtype=np.float64)
    tensor_out = np.zeros(shape=(m, n), order='F')
    ref_vector = vector + np.array(list(range(n)), dtype=np.float64) + 1.
    ref_tensor = np.transpose(tensor)
    function(scalar=scalar, vector=vector, vector_out=vector_out, tensor=tensor, tensor_out=tensor_out)
    assert np.all(vector_out == ref_vector)
    assert np.all(tensor_out == ref_tensor)

    # Create and compile the SDFG
    sdfg = create_sdfg(routine, tmp_path)
    assert sdfg.validate() is None

    csdfg = sdfg.compile()
    assert csdfg

    # Test the transpiled kernel on fresh buffers
    scalar = np.array([2.0], dtype=np.float64)
    vector = np.zeros(shape=(n,), order='F') + 3.
    vector_out = np.zeros(shape=(n,), order='F')
    tensor = np.array([list(range(i, i+m)) for i in range(n)], order='F', dtype=np.float64)
    tensor_out = np.zeros(shape=(m, n), order='F')
    csdfg(scalar=scalar, vector=vector, vector_out=vector_out, tensor=tensor, tensor_out=tensor_out)
    assert np.all(vector_out == ref_vector)
    assert np.all(tensor_out == ref_tensor)


@pytest.mark.skip(reason=('This translates successfully but the generated OpenMP code does not '
                          'honour the loop-carried dependency, thus creating data races for more '
                          'than 1 thread.'))
@pytest.mark.parametrize('frontend', available_frontends())
def test_sdfg_routine_loop_carried_dependency(tmp_path, frontend):
    """
    In-place prefix sum, whose loop iterations depend on their predecessor,
    checked for the Fortran reference and the generated SDFG.
    """

    fcode = """
subroutine routine_loop_carried_dependency(vector)
  use iso_fortran_env, only: real64
  implicit none
  real(kind=real64), intent(inout) :: vector(32)
  integer :: i

  !$loki dataflow
  do i=2, 32
     vector(i) = vector(i) + vector(i-1)
  end do
end subroutine routine_loop_carried_dependency
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'sdfg_routine_loop_carried_dependency_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='routine_loop_carried_dependency')

    n = 32

    def initial_vector():
        # All entries start at 3.0
        return np.zeros(shape=(n,), order='F') + 3.

    # Test the reference solution
    data = initial_vector()
    expected = np.array(list(itertools.accumulate(data)))
    function(vector=data)
    assert np.all(data == expected)

    # Create and compile the SDFG
    sdfg = create_sdfg(routine, tmp_path)
    assert sdfg.validate() is None

    compiled = sdfg.compile()
    assert compiled

    # Test the transpiled kernel
    data = initial_vector()
    expected = np.array(list(itertools.accumulate(data)))
    compiled(vector=data)
    assert np.all(data == expected)


@pytest.mark.parametrize('frontend', available_frontends())
def test_sdfg_routine_moving_average(tmp_path, frontend):
    """
    Three-point moving average with two-point averages at both ends,
    checked for the Fortran reference and the generated SDFG.
    """
    # TODO: This needs more work to properly handle boundary values.
    # In the current form, these values seem to be handled in a way
    # that causes race conditions. Either this is a DaCe bug or we are
    # using DaCe wrong here. The SDFG result is therefore only checked on
    # the interior points.

    fcode = """
subroutine routine_moving_average(length, data_in, data_out)
  use iso_fortran_env, only: real64
  implicit none
  integer, intent(in) :: length
  real(kind=real64), intent(in) :: data_in(length)
  real(kind=real64), intent(out) :: data_out(length)
  integer :: i
  real(kind=real64) :: prev, next, divisor, incr

  data_out(1) = (data_in(1) + data_in(2)) / 2.0

  !$loki dataflow
  do i=2, length-1
    ! TODO: range check prohibits this for some reason
    incr = 1.0
    divisor = 2.0
    if (i > 1) then
      prev = data_in(i-1)
      ! divisor = 2.0
    else
      divisor = divisor - incr
      prev = 0
      ! divisor = 1.0
    end if
    if (i < length) then
      next = data_in(i+1)
      divisor = divisor + incr
    else
      next = 0
    end if
    data_out(i) = (prev + data_in(i) + next) / divisor
  end do

  data_out(length) = (data_in(length-1) + data_in(length)) / 2.0
end subroutine routine_moving_average
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'sdfg_routine_moving_average_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='routine_moving_average')

    # Create random input data
    n = 32
    data_in = np.array(np.random.rand(n), order='F')

    # Compute reference solution
    expected = np.zeros(shape=(n,), order='F')
    expected[0] = (data_in[0] + data_in[1]) / 2.
    expected[1:-1] = (data_in[:-2] + data_in[1:-1] + data_in[2:]) / 3.
    expected[-1] = (data_in[-2] + data_in[-1]) / 2.

    # Test the Fortran kernel. It has no data races, so the boundary values
    # are checked as well.
    data_out = np.zeros(shape=(n,), order='F')
    function(length=n, data_in=data_in, data_out=data_out)
    assert np.all(data_out == expected)

    # Create and compile the SDFG
    sdfg = create_sdfg(routine, tmp_path)
    assert sdfg.validate() is None

    csdfg = sdfg.compile()
    assert csdfg

    # Test the transpiled kernel (interior points only, see TODO above)
    data_out = np.zeros(shape=(n,), order='F')
    csdfg(length=n, data_in=data_in, data_out=data_out)
    assert np.all(data_out[1:-1] == expected[1:-1])
loki-ecmwf-0.3.6/loki/transformations/transpile/tests/test_transpile.py0000664000175000017500000015337315167130205026721 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Subroutine, Module, cgen, cppgen, cudagen, FindNodes
from loki.jit_build import jit_compile, jit_compile_lib, clean_test, Builder, Obj
import loki.expression.symbols as sym
from loki.frontend import available_frontends
from loki import ir

from loki.transformations.transpile import (
    FortranCTransformation, FortranISOCWrapperTransformation
)

# pylint: disable=too-many-lines


def wrapperpath(path, module_or_routine):
    """
    Return ``path / '<name>_fc.F90'``, with the name lower-cased. This is
    where the Fortran ISO-C wrapper of ``module_or_routine`` is written.
    """
    stem = f'{module_or_routine.name}_fc'.lower()
    return (path / stem).with_suffix('.F90')


def cpath(path, module_or_routine, suffix='.c'):
    """
    Utility that generates the ``<name>_c<suffix>`` path of a transpiled
    C file, with the name lower-cased.

    By default this is the C source (``.c``). Pass e.g. ``suffix='.h'`` for
    the corresponding header.
    """
    name = f'{module_or_routine.name}_c'
    return (path/name.lower()).with_suffix(suffix)


@pytest.fixture(scope='function', name='builder')
def fixture_builder(tmp_path):
    """
    Provide a ``Builder`` that uses the test's temporary directory as both
    source and build directory. The ``Obj`` cache is cleared on teardown.
    """
    test_builder = Builder(source_dirs=tmp_path, build_dir=tmp_path)
    yield test_builder
    Obj.clear_cache()


def test_transpile_unsupported_lang():
    """
    Constructing the transformation with an unknown target language must
    raise a ``ValueError``.
    """
    unsupported_language = 'not-supported'
    with pytest.raises(ValueError):
        FortranCTransformation(language=unsupported_language)


@pytest.mark.parametrize('case_sensitive', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('language', ('c', 'cuda'))
def test_transpile_case_sensitivity(tmp_path, frontend, case_sensitive, language):
    """
    Check which symbols keep their spelling in the transpiled code.

    Symbols created with ``case_sensitive=True`` must appear in the generated
    C/CUDA code exactly as written. All other symbols must appear
    lower-cased.
    """

    fcode = """
subroutine transpile_case_sensitivity(a)
    integer, intent(in) :: a

end subroutine transpile_case_sensitivity
"""
    def convert_case(_str, case_sensitive):
        if case_sensitive:
            return _str
        return _str.lower()

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # a = threadIdx%x
    thread_idx = sym.Variable(name="threadIdx", case_sensitive=case_sensitive)
    thread_idx_x = sym.Variable(name="x", parent=thread_idx, case_sensitive=case_sensitive)
    assignment = ir.Assignment(lhs=routine.variable_map['a'], rhs=thread_idx_x)

    # Two extra mixed-case arguments with the same type as ``a``
    first_arg = routine.arguments[0]
    extra_args = (
        first_arg.clone(name='sOmE_vAr', case_sensitive=case_sensitive),
        sym.Variable(name="oTher_VaR", case_sensitive=case_sensitive, type=first_arg.type.clone()),
    )
    routine.arguments = routine.arguments + extra_args

    # A subroutine call and an inline function call with mixed-case names
    call = ir.CallStatement(
        sym.Variable(name='somE_cALl', case_sensitive=case_sensitive),
        arguments=(routine.variable_map['a'],)
    )
    inline_call = sym.InlineCall(
        function=sym.Variable(name='somE_InlINeCaLl', case_sensitive=case_sensitive),
        parameters=(sym.IntLiteral(1),)
    )
    inline_call_assignment = ir.Assignment(lhs=routine.variable_map['a'], rhs=inline_call)
    routine.body = (routine.body, assignment, call, inline_call_assignment)

    FortranCTransformation(language=language).apply(source=routine, path=tmp_path)

    raw_ccode = cpath(tmp_path, routine).read_text()
    ccode = raw_ccode.replace(' ', '').replace('\n', ' ').replace('\r', '').replace('\t', '')
    expected_snippets = (
        'transpile_case_sensitivity_c(inta,intsOmE_vAr,intoTher_VaR)',
        'a=threadIdx%x;',
        'somE_cALl(a);',
        'a=somE_InlINeCaLl(1);',
    )
    for snippet in expected_snippets:
        assert convert_case(snippet, case_sensitive) in ccode


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_simple_loops(tmp_path, builder, frontend, use_c_ptr):
    """
    Transpile a routine with a single loop and a doubly nested loop to C,
    wrap it via ISO-C bindings and check it reproduces the Fortran results.
    """

    fcode = """
subroutine simple_loops(n, m, scalar, vector, tensor)
  use iso_fortran_env, only: real64
  implicit none
  integer, intent(in) :: n, m
  real(kind=real64), intent(inout) :: scalar
  real(kind=real64), intent(inout) :: vector(n), tensor(n, m)

  integer :: i, j

  ! For testing, the operation is:
  do i=1, n
     vector(i) = vector(i) + tensor(i, 1) + 1.0
  end do

  do j=1, m
     do i=1, n
        tensor(i, j) = 10.* j + i
     end do
  end do
end subroutine simple_loops
"""
    ptr_tag = "_c_ptr" if use_c_ptr else ""
    n, m = 3, 4
    expected_tensor = [[11., 21., 31., 41.],
                       [12., 22., 32., 42.],
                       [13., 23., 33., 43.]]

    def run_and_check(kernel):
        # Fresh inputs for every run: vector = 3, tensor = 4
        scalar = 2.0
        vector = np.zeros(shape=(n,), order='F') + 3.
        tensor = np.zeros(shape=(n, m), order='F') + 4.
        kernel(n, m, scalar, vector, tensor)
        assert np.all(vector == 8.)
        assert np.all(tensor == expected_tensor)

    # Generate reference code, compile run and verify
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'simple_loops{ptr_tag}_{frontend}.f90')
    run_and_check(jit_compile(routine, filepath=filepath, objname='simple_loops'))

    # Generate and test the transpiled C kernel
    FortranCTransformation().apply(source=routine, path=tmp_path)
    FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr).apply(source=routine, path=tmp_path)

    libname = f'fc_{routine.name}{ptr_tag}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.simple_loops_fc_mod.simple_loops_fc

    # Check the generated F2C wrapper: with C pointers, arrays become
    # assumed-shape TARGETs passed via C_LOC
    f2c_str = wrapperpath(tmp_path, routine).read_text().upper().replace(' ', '')
    if use_c_ptr:
        expected_target, expected_c_loc = 2, 3
        expected_decls = ('VECTOR(:)', 'TENSOR(:,:)')
    else:
        expected_target, expected_c_loc = 0, 0
        expected_decls = ('VECTOR(N)', 'TENSOR(N,M)')
    assert f2c_str.count('TARGET') == expected_target
    assert f2c_str.count('C_LOC') == expected_c_loc
    for decl in expected_decls:
        assert decl in f2c_str

    run_and_check(fc_function)


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_arguments(tmp_path, builder, frontend, use_c_ptr):
    """
    Check that arguments with varying intents and kinds are exchanged
    correctly between Python, the Fortran reference and the transpiled C
    kernel.
    """

    fcode = """
subroutine transpile_arguments(n, array, array_io, a, b, c, a_io, b_io, c_io)
  use iso_fortran_env, only: real32, real64
  implicit none

  integer, intent(in) :: n
  real(kind=real64), intent(inout) :: array(n)
  real(kind=real64), intent(out) :: array_io(n)

  integer, intent(out) :: a
  real(kind=real32), intent(out) :: b
  real(kind=real64), intent(out) :: c
  integer, intent(inout) :: a_io
  real(kind=real32), intent(inout) :: b_io
  real(kind=real64), intent(inout) :: c_io

  integer :: i

  do i=1, n
     array(i) = 3.
     array_io(i) = array_io(i) + 3.
  end do

  a = 2**3
  b = 3.2_real32
  c = 4.1_real64

  a_io = a_io + 2
  b_io = b_io + real(3.2, kind=real32)
  c_io = c_io + 4.1
end subroutine transpile_arguments
"""
    # NOTE(review): ``array_io`` is declared intent(out) but read before it
    # is defined, and the check below relies on its incoming value (3 + 3).
    # Strictly that is undefined in Fortran. The intents of ``array`` and
    # ``array_io`` look swapped; confirm before changing them, as the
    # wrapper checks may depend on them.
    n = 3
    ptr_tag = "_c_ptr" if use_c_ptr else ""

    def run_and_check(kernel):
        array = np.zeros(shape=(n,), order='F')
        array_io = np.zeros(shape=(n,), order='F') + 3.
        # Scalar inout arguments are passed as numpy arrays
        a_io, b_io, c_io = np.array(1), np.array(2.), np.array(3.)

        a, b, c = kernel(n, array, array_io, a_io, b_io, c_io)

        assert np.all(array == 3.) and array.size == n
        assert np.all(array_io == 6.)
        assert a_io == 3. and np.isclose(b_io, 5.2) and np.isclose(c_io, 7.1)
        assert a == 8 and np.isclose(b, 3.2) and np.isclose(c, 4.1)

    # Generate reference code, compile run and verify
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'transpile_arguments{ptr_tag}_{frontend}.f90')
    run_and_check(jit_compile(routine, filepath=filepath, objname='transpile_arguments'))

    # Generate and test the transpiled C kernel
    FortranCTransformation().apply(source=routine, path=tmp_path)
    FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr).apply(source=routine, path=tmp_path)

    libname = f'fc_{routine.name}{ptr_tag}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.transpile_arguments_fc_mod.transpile_arguments_fc

    # Check the generated F2C wrapper
    f2c_str = wrapperpath(tmp_path, routine).read_text().upper().replace(' ', '')
    if use_c_ptr:
        expected_target, expected_c_loc = 2, 3
        expected_decls = ('ARRAY(:)', 'ARRAY_IO(:)')
    else:
        expected_target, expected_c_loc = 0, 0
        expected_decls = ('ARRAY(N)', 'ARRAY_IO(N)')
    assert f2c_str.count('TARGET') == expected_target
    assert f2c_str.count('C_LOC') == expected_c_loc
    for decl in expected_decls:
        assert decl in f2c_str

    run_and_check(fc_function)


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_derived_type(tmp_path, builder, frontend, use_c_ptr):
    """
    Tests passing a derived-type argument through C transpilation.

    A routine that updates the integer, ``real32`` and ``real64`` components
    of a ``my_struct`` argument in place is first compiled natively as a
    reference. It is then transpiled to C with an ISO-C wrapper (with and
    without ``use_c_ptr``), and both builds must produce the same results.
    """

    fcode_type = """
module transpile_type_mod
    use iso_fortran_env, only: real32, real64
    implicit none

    type my_struct
        integer :: a
        real(kind=real32) :: b
        real(kind=real64) :: c
    end type my_struct
end module transpile_type_mod
    """.strip()

    fcode_routine = """
subroutine transp_der_type(a_struct)
    use transpile_type_mod, only: my_struct
    implicit none
    type(my_struct), intent(inout) :: a_struct

    a_struct%a = a_struct%a + 4
    a_struct%b = a_struct%b + 5.
    a_struct%c = a_struct%c + 6.
end subroutine transp_der_type
    """.strip()

    def check_struct_kernel(struct_type, kernel):
        # Run ``kernel`` on a freshly initialised struct and verify that the
        # components were incremented by 4, 5 and 6 respectively
        a_struct = struct_type()
        a_struct.a = 4
        a_struct.b = 5.
        a_struct.c = 6.
        kernel(a_struct)
        assert a_struct.a == 8
        assert a_struct.b == 10.
        assert a_struct.c == 12.

    module = Module.from_source(fcode_type, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode_routine, definitions=module, frontend=frontend, xmods=[tmp_path])
    refname = f'ref_{routine.name}{"_c_ptr" if use_c_ptr else ""}_{frontend}'
    reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder)

    # Test the reference solution
    check_struct_kernel(reference.transpile_type_mod.my_struct, reference.transp_der_type)

    # Translate the header module to expose parameters
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=module, path=tmp_path, role='header')

    # Create transformation object and apply
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path, role='kernel')
    f2cwrap.apply(source=routine, path=tmp_path, role='kernel')

    # Build and wrap the cross-compiled library
    sources = [module, wrapperpath(tmp_path, routine), cpath(tmp_path, routine)]
    libname = f'fc_{routine.name}{"_c_ptr" if use_c_ptr else ""}_{frontend}'
    c_kernel = jit_compile_lib(sources=sources, path=tmp_path, name=libname, builder=builder)

    # Test the transpiled C kernel through its Fortran ISO-C wrapper
    check_struct_kernel(
        c_kernel.transpile_type_mod.my_struct,
        c_kernel.transp_der_type_fc_mod.transp_der_type_fc
    )


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_associates(tmp_path, builder, frontend, use_c_ptr):
    """
    Verify that ``associate`` blocks aliasing derived-type components are
    resolved correctly when the routine is transpiled to C.
    """

    fcode_type = """
module assoc_type_mod
    use iso_fortran_env, only: real32, real64
    implicit none

    type my_struct
        integer :: a
        real(kind=real32) :: b
        real(kind=real64) :: c
    end type my_struct
end module assoc_type_mod
    """.strip()

    fcode_routine = """
subroutine transp_assoc(a_struct)
    use assoc_type_mod, only: my_struct
    implicit none
    type(my_struct), intent(inout) :: a_struct

    associate(a_struct_a=>a_struct%a, a_struct_b=>a_struct%b,&
            & a_struct_c=>a_struct%c)
        a_struct%a = a_struct_a + 4.
        a_struct_b = a_struct%b + 5.
        a_struct_c = a_struct_a + a_struct%b + a_struct_c
    end associate
end subroutine transp_assoc
    """.strip()

    def initialised_struct(struct_type):
        # Both runs start from the same component values
        instance = struct_type()
        instance.a = 4
        instance.b = 5.
        instance.c = 6.
        return instance

    def verify(instance):
        # a = 4 + 4, b = 5 + 5, c = (updated a) + (updated b) + 6
        assert instance.a == 8
        assert instance.b == 10.
        assert instance.c == 24.

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    module = Module.from_source(fcode_type, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode_routine, definitions=module, frontend=frontend, xmods=[tmp_path])
    reference = jit_compile_lib(
        [module, routine], path=tmp_path, name=f'ref_{routine.name}{suffix}', builder=builder
    )

    # Natively compiled Fortran serves as the ground truth
    ref_struct = initialised_struct(reference.assoc_type_mod.my_struct)
    reference.transp_assoc(ref_struct)
    verify(ref_struct)

    # Header module gets an ISO-C wrapper so its type is visible from C
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=module, path=tmp_path, role='header')

    # Transpile the kernel and generate its wrapper
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path, role='kernel')
    f2cwrap.apply(source=routine, path=tmp_path, role='kernel')

    # Compile the cross-language library
    c_kernel = jit_compile_lib(
        sources=[module, wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=f'fc_{routine.name}{suffix}', builder=builder
    )

    fc_struct = initialised_struct(c_kernel.assoc_type_mod.my_struct)
    fc_kernel = c_kernel.transp_assoc_fc_mod.transp_assoc_fc
    fc_kernel(fc_struct)
    verify(fc_struct)


@pytest.mark.skip(reason='More thought needed on how to test structs-of-arrays')
def test_transpile_derived_type_array():
    """
    Placeholder for testing multi-dimensional array components of derived
    types (structs-of-arrays) in C transpilation. It is currently skipped
    (see the ``skip`` marker) and has no implementation.

    The intended kernel would perform::

        a_struct%scalar = 3.
        a_struct%vector(i) = a_struct%scalar + 2.
        a_struct%matrix(j,i) = a_struct%vector(i) + 1.

    Draft of the Fortran routine to transpile, kept for reference::

! subroutine transpile_derived_type_array(a_struct)
!   use transpile_type, only: array_struct
!   implicit none
!      ! real(kind=real64) :: vector(:)
!      ! real(kind=real64) :: matrix(:,:)
!   type(array_struct), intent(inout) :: a_struct
!   integer :: i, j

!   a_struct%scalar = 3.
!   do i=1, 3
!     a_struct%vector(i) = a_struct%scalar + 2.
!   end do
!   do i=1, 3
!     do j=1, 3
!       a_struct%matrix(j,i) = a_struct%vector(i) + 1.
!     end do
!   end do

! end subroutine transpile_derived_type_array
    """


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_module_variables(tmp_path, builder, frontend, use_c_ptr):
    """
    Check that module variables imported by a kernel are reachable from
    the transpiled C code (through getter routines in the generated
    wrapper). The kernel's results must match the native Fortran build.
    """

    fcode_type = """
module mod_var_type_mod
    use iso_fortran_env, only: real32, real64
    implicit none

    save

    integer :: PARAM1
    real(kind=real32) :: param2
    real(kind=real64) :: param3
end module mod_var_type_mod
    """.strip()

    fcode_routine = """
subroutine transp_mod_var(a, b, c)
    use iso_fortran_env, only: real32, real64
    use mod_var_type_mod, only: PARAM1, param2, param3

    integer, intent(out) :: a
    real(kind=real32), intent(out) :: b
    real(kind=real64), intent(out) :: c

    a = 1 + PARAM1  ! Ensure downcasing is done right
    b = 1. + param2
    c = 1. + param3
end subroutine transp_mod_var
    """.strip()

    def set_module_params(mod):
        # Values the kernel reads from the module before adding one to each
        mod.param1 = 2
        mod.param2 = 4.
        mod.param3 = 3.

    def check_outputs(outputs):
        out_a, out_b, out_c = outputs
        assert out_a == 3 and out_b == 5. and out_c == 4.

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    module = Module.from_source(fcode_type, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode_routine, definitions=module, frontend=frontend, xmods=[tmp_path])
    reference = jit_compile_lib(
        [module, routine], path=tmp_path, name=f'ref_{routine.name}{suffix}', builder=builder
    )

    set_module_params(reference.mod_var_type_mod)
    check_outputs(reference.transp_mod_var())

    # The header module needs its own ISO-C wrapper to expose the variables
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=module, path=tmp_path, role='header')

    # Transpile the kernel and produce its wrapper
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path, role='kernel')
    f2cwrap.apply(source=routine, path=tmp_path, role='kernel')

    # Compile module, both wrappers and the C kernel into one library
    all_sources = [
        module, wrapperpath(tmp_path, module),
        wrapperpath(tmp_path, routine), cpath(tmp_path, routine)
    ]
    to_wrap = [tmp_path/'mod_var_type_mod.f90', wrapperpath(tmp_path, routine).name]
    c_kernel = jit_compile_lib(
        sources=all_sources, wrap=to_wrap, path=tmp_path,
        name=f'fc_{routine.name}{suffix}', builder=builder
    )

    set_module_params(c_kernel.mod_var_type_mod)
    check_outputs(c_kernel.transp_mod_var_fc_mod.transp_mod_var_fc())


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_vectorization(tmp_path, builder, frontend, use_c_ptr):
    """
    Check that Fortran array (vector) notation and a local 2D array are
    expanded into equivalent loops by the C transpilation.
    """

    fcode = """
subroutine transp_vect(n, m, scalar, v1, v2)
  use iso_fortran_env, only: real64
  implicit none
  integer, intent(in) :: n, m
  real(kind=real64), intent(inout) :: scalar
  real(kind=real64), intent(inout) :: v1(n), v2(n)

  real(kind=real64) :: matrix(n, m)

  integer :: i

  v1(:) = scalar + 1.0
  matrix(:, :) = scalar + 2.
  v2(:) = matrix(:, 2)
  v2(1) = 1.
end subroutine transp_vect
"""

    n, m = 3, 4

    def run_and_check(kernel):
        # Fresh inputs for every run, since v1/v2 are updated in place
        scalar = 2.0
        v1 = np.zeros(shape=(n,), order='F')
        v2 = np.zeros(shape=(n,), order='F')
        kernel(n, m, scalar, v1, v2)
        assert np.all(v1 == 3.)
        assert v2[0] == 1. and np.all(v2[1:] == 4.)

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    # Compile the plain Fortran routine and verify it as reference
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(routine, filepath=tmp_path/(f'transp_vect{suffix}.f90'), objname='transp_vect')
    run_and_check(function)

    # Produce the C kernel and its ISO-C wrapper
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=f'fc_{routine.name}{suffix}', builder=builder
    )

    # The transpiled C kernel must match the reference behaviour
    run_and_check(c_kernel.transp_vect_fc_mod.transp_vect_fc)


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_intrinsics(tmp_path, builder, frontend, use_c_ptr):
    """
    Check that the supported intrinsics (``min``, ``max``, ``abs``, including
    nested ``min``/``max`` calls) are transpiled to equivalent C code.
    """

    fcode = """
subroutine transpile_intrinsics(v1, v2, v3, v4, vmin, vmax, vabs, vmin_nested, vmax_nested)
  ! Test supported intrinsic functions
  use iso_fortran_env, only: real64
  real(kind=real64), intent(in) :: v1, v2, v3, v4
  real(kind=real64), intent(out) :: vmin, vmax, vabs, vmin_nested, vmax_nested

  vmin = min(v1, v2)
  vmax = max(v1, v2)
  vabs = abs(v1 - v2)
  vmin_nested = min(min(v1, v2), min(v3, v4))
  vmax_nested = max(max(v1, v2), max(v3, v4))
end subroutine transpile_intrinsics
"""

    v1, v2, v3, v4 = 2., 4., 1., 5.

    def check_intrinsics(results):
        got_min, got_max, got_abs, got_min_nested, got_max_nested = results
        assert got_min == 2. and got_max == 4. and got_abs == 2.
        assert got_min_nested == 1. and got_max_nested == 5.

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    # Native Fortran build provides the reference results
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(
        routine, filepath=tmp_path/(f'transpile_intrinsics{suffix}.f90'), objname='transpile_intrinsics'
    )
    check_intrinsics(function(v1, v2, v3, v4))

    # Produce the C kernel and its ISO-C wrapper
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=f'fc_{routine.name}{suffix}', builder=builder
    )
    fc_function = c_kernel.transpile_intrinsics_fc_mod.transpile_intrinsics_fc

    check_intrinsics(fc_function(v1, v2, v3, v4))


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_loop_indices(tmp_path, builder, frontend, use_c_ptr):
    """
    Test to ensure loop indexing translates correctly

    The Fortran loop runs over ``i=1, n`` and compares ``i`` against the
    1-based ``idx``. The transpiled C kernel must fill the masks identically
    to the native build, despite C's 0-based indexing.
    """

    fcode = """
subroutine transp_loop_ind(n, idx, mask1, mask2, mask3)
  ! Test to ensure loop indexing translates correctly
  use iso_fortran_env, only: real64
  integer, intent(in) :: n, idx
  integer, intent(inout) :: mask1(n), mask2(n)
  real(kind=real64), intent(inout) :: mask3(n)

  integer :: i

  do i=1, n
     if (i < idx) then
        mask1(i) = 1
     end if

     if (i == idx) then
        mask1(i) = 2
     end if

     mask2(i) = i
  end do
  mask3(n) = 3.0
end subroutine transp_loop_ind
"""

    n = 6
    # ``fidx`` is the 1-based position passed to Fortran; ``cidx`` is the
    # same position as a 0-based NumPy index
    cidx, fidx = 3, 4

    def run_and_check(kernel):
        # Fresh, zero-initialised masks for every kernel invocation
        mask1 = np.zeros(shape=(n,), order='F', dtype=np.int32)
        mask2 = np.zeros(shape=(n,), order='F', dtype=np.int32)
        mask3 = np.zeros(shape=(n,), order='F', dtype=np.float64)
        kernel(n=n, idx=fidx, mask1=mask1, mask2=mask2, mask3=mask3)

        # Every Fortran position i < idx is set to 1. Those are the 0-based
        # indices 0..cidx-1, i.e. the slice ``[:cidx]``. The previous
        # ``[:cidx-1]`` silently skipped the last of them.
        assert np.all(mask1[:cidx] == 1)
        assert mask1[cidx] == 2
        assert np.all(mask1[cidx+1:] == 0)
        assert np.all(mask2 == np.arange(n, dtype=np.int32) + 1)
        assert np.all(mask3[:-1] == 0.)
        assert mask3[-1] == 3.

    # Generate reference code, compile run and verify
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'transp_loop_ind{"_c_ptr" if use_c_ptr else ""}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='transp_loop_ind')

    # Test the reference solution
    run_and_check(function)

    # Generate and test the transpiled C kernel
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    libname = f'fc_{routine.name}{"_c_ptr" if use_c_ptr else ""}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.transp_loop_ind_fc_mod.transp_loop_ind_fc

    run_and_check(fc_function)


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_logical_statements(tmp_path, builder, frontend, use_c_ptr):
    """
    A simple test routine to test logical statements

    Covers ``.and.``/``.or.``/``.not.``, ``.eqv.``/``.neqv.`` and logical
    array literals. All four input combinations are checked for both the
    native Fortran build and the transpiled C kernel.
    """

    fcode = """
subroutine logical_stmts(v1, v2, v_xor, v_xnor, v_nand, v_neqv, v_val)
  logical, intent(in) :: v1, v2
  logical, intent(out) :: v_xor, v_nand, v_xnor, v_neqv, v_val(2)

  v_xor = (v1 .and. .not. v2) .or. (.not. v1 .and. v2)
  v_xnor = v1 .eqv. v2
  v_nand = .not. (v1 .and. v2)
  v_neqv = v1 .neqv. v2
  v_val(1) = .true.
  v_val(2) = .false.

end subroutine logical_stmts
"""

    def run_and_check(kernel):
        # Exhaustively check all combinations of the two logical inputs
        for v1 in range(2):
            for v2 in range(2):
                v_val = np.zeros(shape=(2,), order='F', dtype=np.int32)
                v_xor, v_xnor, v_nand, v_neqv = kernel(v1, v2, v_val)
                # The expected value must be parenthesised as a whole: ``==``
                # binds tighter than ``or``, so ``x == a or b`` would pass
                # whenever ``b`` is truthy, regardless of ``x``
                assert bool(v_xor) == bool((v1 and not v2) or (not v1 and v2))
                assert bool(v_xnor) == bool((v1 and v2) or not (v1 or v2))
                assert bool(v_nand) == (not (v1 and v2))
                assert bool(v_neqv) == bool((not (v1 and v2)) and (v1 or v2))
                assert v_val[0] and not v_val[1]

    # Generate reference code, compile run and verify
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'logical_stmts{"_c_ptr" if use_c_ptr else ""}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname='logical_stmts')

    # Test the reference solution
    run_and_check(function)

    # Generate and test the transpiled C kernel
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    libname = f'fc_{routine.name}{"_c_ptr" if use_c_ptr else ""}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.logical_stmts_fc_mod.logical_stmts_fc

    run_and_check(fc_function)


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multibody_conditionals(tmp_path, builder, frontend, use_c_ptr):
    """
    Verify that ``if``/``else``/``else if`` chains with several bodies
    (including a branch with multiple statements) keep their meaning after
    transpilation to C.
    """
    fcode = """
subroutine multibody_cond(in1, out1, out2)
  integer, intent(in) :: in1
  integer, intent(out) :: out1, out2

  if (in1 > 5) then
    out1 = 5
  else
    out1 = 1
  end if

  if (in1 < 0) then
    out2 = 0
  else if (in1 > 5) then
    out2 = 6
    out2 = out2 - 1
  else if (3 < in1 .and. in1 <= 5) then
    out2 = 4
  else
    out2 = in1
  end if
end subroutine multibody_cond
"""
    # (in1, expected out1, expected out2), evaluated in this order
    cases = (
        (5, 1, 4),
        (2, 1, 2),
        (-1, 1, 0),
        (10, 5, 5),
    )

    def check_cases(kernel):
        for in1, want1, want2 in cases:
            got1, got2 = kernel(in1)
            assert got1 == want1 and got2 == want2

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    # Native Fortran build provides the reference behaviour
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'multibody_cond{suffix}.f90')
    function = jit_compile(routine, filepath=filepath, objname='multibody_cond')
    check_cases(function)

    clean_test(filepath)

    # Produce the C kernel and its ISO-C wrapper
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=f'fc_{routine.name}{suffix}', builder=builder
    )
    check_cases(c_kernel.multibody_cond_fc_mod.multibody_cond_fc)


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_inline_elemental_functions(tmp_path, builder, frontend, use_c_ptr):
    """
    Check that calls to an imported elemental function are inlined when
    transpiling to C with ``inline_elementals=True``. The C library is
    built from the wrapper and C kernel only, without the defining module.
    """
    fcode_module = """
module multiply_mod_c
  use iso_fortran_env, only: real64
  implicit none
contains

  elemental function multiply(a, b)
    real(kind=real64) :: multiply
    real(kind=real64), intent(in) :: a, b

    multiply = a * b
  end function multiply
end module multiply_mod_c
"""

    fcode = """
subroutine inline_elemental(v1, v2, v3)
  use iso_fortran_env, only: real64
  use multiply_mod_c, only: multiply
  real(kind=real64), intent(in) :: v1
  real(kind=real64), intent(out) :: v2, v3

  v2 = multiply(v1, 6._real64)
  v3 = 600. + multiply(6._real64, 11._real64)
end subroutine inline_elemental
"""

    def expect_results(results):
        # multiply(11, 6) == 66 and 600 + multiply(6, 11) == 666
        got_v2, got_v3 = results
        assert got_v2 == 66.
        assert got_v3 == 666.

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    # Reference: compile module and routine together as plain Fortran
    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    reference = jit_compile_lib(
        [module, routine], path=tmp_path, name=f'ref_{routine.name}{suffix}', builder=builder
    )
    expect_results(reference.inline_elemental(11.))

    # Remove the Fortran sources written for the reference build
    for source_unit in (module, routine):
        (tmp_path/f'{source_unit.name}.f90').unlink()

    # Re-parse with the module as definitions so the elementals can be inlined
    routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend, xmods=[tmp_path])

    f2c = FortranCTransformation(inline_elementals=True)
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=f'fc_{routine.name}{suffix}', builder=builder
    )
    expect_results(c_kernel.inline_elemental_fc_mod.inline_elemental_fc(11.))


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_inline_elementals_recursive(tmp_path, builder, frontend, use_c_ptr):
    """
    Check inlining of elemental functions that themselves call other
    elementals (``multiply_plus_one`` -> ``multiply`` and ``plus_one``)
    during C transpilation.
    """
    fcode_module = """
module multiply_plus_one_mod
  use iso_fortran_env, only: real64
  implicit none
contains

  elemental function multiply(a, b)
    real(kind=real64) :: multiply
    real(kind=real64), intent(in) :: a, b

    multiply = a * b
  end function multiply

  elemental function plus_one(a)
    real(kind=real64) :: plus_one
    real(kind=real64), intent(in) :: a

    plus_one = a + 1._real64
  end function plus_one

  elemental function multiply_plus_one(a, b)
    real(kind=real64) :: multiply_plus_one
    real(kind=real64), intent(in) :: a, b

    multiply_plus_one = multiply(plus_one(a), b)
  end function multiply_plus_one
end module multiply_plus_one_mod
"""

    fcode = """
subroutine inline_elementals_rec(v1, v2, v3)
  use iso_fortran_env, only: real64
  use multiply_plus_one_mod, only: multiply_plus_one
  real(kind=real64), intent(in) :: v1
  real(kind=real64), intent(out) :: v2, v3

  v2 = multiply_plus_one(v1, 6._real64)
  v3 = 600. + multiply_plus_one(5._real64, 11._real64)
end subroutine inline_elementals_rec
"""

    def expect_results(results):
        # (10 + 1) * 6 == 66 and 600 + (5 + 1) * 11 == 666
        got_v2, got_v3 = results
        assert got_v2 == 66.
        assert got_v3 == 666.

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    # Reference: compile module and routine together as plain Fortran
    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    reference = jit_compile_lib(
        [module, routine], path=tmp_path, name=f'ref_{routine.name}{suffix}', builder=builder
    )
    expect_results(reference.inline_elementals_rec(10.))

    # Remove the Fortran sources written for the reference build
    for source_unit in (module, routine):
        (tmp_path/f'{source_unit.name}.f90').unlink()

    # Re-parse with the module as definitions so the elementals can be inlined
    routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend, xmods=[tmp_path])

    f2c = FortranCTransformation(inline_elementals=True)
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=f'fc_{routine.name}{suffix}', builder=builder
    )
    expect_results(c_kernel.inline_elementals_rec_fc_mod.inline_elementals_rec_fc(10.))


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_expressions(tmp_path, builder, frontend, use_c_ptr):
    """
    Check that parentheses and unary minus survive C transpilation. Array
    index arithmetic must also be shifted correctly to 0-based indexing.
    """

    fcode = """
subroutine transpile_expressions(n, scalar, vector)
  use iso_fortran_env, only: real64
  implicit none
  integer, intent(in) :: n
  real(kind=real64), intent(in) :: scalar
  real(kind=real64), intent(inout) :: vector(n)

  integer :: i

  vector(1) = scalar
  do i=2, n
     vector(i) = vector(i-1) - (-scalar)
  end do
end subroutine transpile_expressions
"""

    n, scalar = 10, 2.0
    # vector(i) accumulates ``scalar`` once per element: [1*s, 2*s, ..., n*s]
    expected = [k * scalar for k in range(1, n+1)]

    def run_and_check(kernel):
        vector = np.zeros(shape=(n,), order='F')
        kernel(n, scalar, vector)
        assert np.all(vector == expected)

    suffix = f'{"_c_ptr" if use_c_ptr else ""}_{frontend}'

    # Native Fortran build provides the reference behaviour
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(
        routine, filepath=tmp_path/f'{routine.name}{"_c_ptr" if use_c_ptr else ""}_{frontend!s}.f90',
        objname=routine.name
    )
    run_and_check(function)

    # Produce the C kernel and its ISO-C wrapper
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(use_c_ptr=use_c_ptr)
    f2cwrap.apply(source=routine, path=tmp_path)

    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=f'fc_{routine.name}{suffix}', builder=builder
    )
    fc_function = c_kernel.transpile_expressions_fc_mod.transpile_expressions_fc

    # Inspect the generated C for correct index shifts and sign handling
    ccode = cpath(tmp_path, routine).read_text()
    # ``i-1`` plus the 0-based shift yields a double subtraction
    assert 'vector[i - 1 - 1]' in ccode or 'vector[-1 + i - 1]' in ccode
    assert 'vector[i - 1]' in ccode
    assert '-scalar' in ccode  # negated scalar kept as such

    run_and_check(fc_function)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('language', ('c', 'cuda'))
@pytest.mark.parametrize('chevron', (False, True))
def test_transpile_call(tmp_path, frontend, language, chevron):
    """
    Check C/CUDA transpilation of a driver calling a module kernel.

    On the kernel side, ``intent(inout)`` scalars become pointers and
    ``intent(in)`` scalars stay by value. Writes through pointer arguments
    are dereferenced. Array arguments get ``restrict`` (C) or
    ``__restrict__`` (CUDA), and CUDA additionally adds ``const`` on the
    read-only array. On the driver side, scalars are passed by address and,
    for CUDA with a chevron attached, a ``<<<1,1>>>`` launch is emitted.
    """
    fcode_module = """
module transpile_call_kernel_mod
  implicit none
contains

  subroutine transpile_call_kernel(a, b, c, arr1, arr2, len)
    integer, intent(inout) :: a, c
    integer, intent(in) :: b
    integer, intent(in) :: len
    integer, intent(inout) :: arr1(len, len)
    integer, intent(in) :: arr2(len, len)
    a = b
    c = b
  end subroutine transpile_call_kernel
end module transpile_call_kernel_mod
"""

    fcode = """
subroutine transpile_call_driver(a)
  use transpile_call_kernel_mod, only: transpile_call_kernel
    integer, intent(inout) :: a
    integer, parameter :: len = 5
    integer :: arr1(len, len)
    integer :: arr2(len, len)
    integer :: b
    b = 2 * len
    call transpile_call_kernel(a, b, arr2(1, 1), arr1, arr2, len)
end subroutine transpile_call_driver
"""

    def squashed(text):
        # Drop spaces and newlines so the checks do not depend on layout
        return text.replace(' ', '').replace('\n', '')

    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module, xmods=[tmp_path])
    if chevron:
        # Attach a (1, 1) launch configuration to the first call in the driver
        first_call = FindNodes(ir.CallStatement).visit(routine.body)[0]
        first_call._update(chevron=(sym.IntLiteral(1), sym.IntLiteral(1)))

    f2c = FortranCTransformation(language=language)
    f2c.apply(source=module.subroutine_map['transpile_call_kernel'], path=tmp_path, role='kernel')
    ccode_kernel = squashed(cpath(tmp_path, module.routines[0]).read_text())
    f2c.apply(source=routine, path=tmp_path, role='kernel')
    ccode_driver = squashed(cpath(tmp_path, routine).read_text())

    # Scalar argument passing convention
    assert "int*a,intb,int*c" in ccode_kernel
    # Writes through pointer arguments are dereferenced
    assert "(*a)=b;" in ccode_kernel
    assert "(*c)=b;" in ccode_kernel

    # Qualifiers on array arguments
    expected_signature = (
        'int*a,intb,int*c,int*__restrict__arr1,constint*__restrict__arr2'
        if language == 'cuda'
        else 'int*a,intb,int*c,int*restrictarr1,int*restrictarr2'
    )
    assert expected_signature in ccode_kernel

    # Address-of on scalars at the call site, plus optional CUDA launch syntax
    launch = '<<<1,1>>>' if chevron and language == 'cuda' else ''
    assert f"transpile_call_kernel{launch}((&a),b,(&arr2[" in ccode_driver


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('codegen', (cgen, cppgen, cudagen))
@pytest.mark.parametrize('header', (False, True))
@pytest.mark.parametrize('guards', (False, True))
@pytest.mark.parametrize('extern', (False, True))
@pytest.mark.parametrize('guard_name', (None, 'random_guard_name'))
def test_transpile_simple_routine(tmp_path, frontend, codegen, guards, guard_name,
        header, extern):
    """
    Exercise the code-generation options of the C/C++/CUDA backends on a
    trivial routine: ``extern "C"`` linkage, header-only prototypes, and
    include guards (default or custom macro name).
    """

    fcode = """
subroutine add(a, b, result)
    real, intent(in) :: a, b
    real, intent(out) :: result

    result = a + b
end subroutine add
""".strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)

    generated = codegen(routine, guards=guards, guard_name=guard_name,
            header=header, extern=extern)

    # extern "C" only makes sense (and is only emitted) for C++-like backends
    wants_extern = extern and codegen in (cppgen, cudagen)
    assert ('extern "C"' in generated) == wants_extern

    # A header carries a prototype, otherwise a full definition is emitted
    signature = 'void add(double a, double b, double result)'
    assert signature + (';' if header else ' {') in generated

    if not guards:
        for directive in ('#ifndef', '#define', '#endif'):
            assert directive not in generated
        return

    macro = 'ADD_H' if guard_name is None else guard_name
    assert f'#ifndef {macro}' in generated
    assert f'#define {macro}' in generated
    assert '#endif' in generated

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('codegen', (cgen, cppgen, cudagen))
def test_transpile_routine_with_interface(tmp_path, frontend, codegen):
    """
    Check that an ``INTERFACE`` block inside a routine leaves no trace in the
    generated code: neither the keyword nor the declared procedure's name.
    """

    fcode = """
subroutine some_routine_with_interf(a, b, result)
  INTERFACE
    SUBROUTINE KERNEL(a, b, c)
      INTEGER, INTENT(INOUT) :: a, b, c
    END SUBROUTINE KERNEL
  END INTERFACE

    real, intent(in) :: a, b
    real, intent(out) :: result

    result = a + b
end subroutine some_routine_with_interf
""".strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    transformation = FortranCTransformation()
    transformation.apply(source=routine, path=tmp_path)

    # Compare case-insensitively, since the Fortran source uses upper case
    lowered = codegen(routine).lower()
    for forbidden in ('interface', 'kernel'):
        assert forbidden not in lowered

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('f_type', ['integer', 'real'])
@pytest.mark.parametrize('codegen', (cgen, cppgen, cudagen))
def test_transpile_inline_functions(tmp_path, frontend, f_type, codegen):
    """
    Test correct transpilation of functions in C transpilation.

    The function result uses the function's own name (``add``). The C code
    must return that variable and declare the mapped C return type.
    """

    # ``f_type`` is interpolated by the f-string itself. A trailing
    # ``.format(f_type)`` would be a redundant no-op on the result.
    fcode = f"""
function add(a, b)
    {f_type} :: add
    {f_type}, intent(in) :: a, b

    add = a + b
end function add
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)

    f_type_map = {'integer': 'int', 'real': 'double'}
    c_routine = codegen(routine)
    assert 'return add;' in c_routine
    assert f'{f_type_map[f_type]} add(' in c_routine


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('f_type', ['integer', 'real'])
@pytest.mark.parametrize('codegen', (cgen, cppgen, cudagen))
def test_transpile_inline_functions_return(tmp_path, frontend, f_type, codegen):
    """
    Test correct transpilation of functions with an explicit ``RESULT``
    variable in C transpilation.

    The function declares ``result(res)``, so the generated C function must
    ``return res;`` and be declared with the C type corresponding to ``f_type``.
    """

    # ``f_type`` is interpolated by the f-string itself
    fcode = f"""
function add(a, b) result(res)
    {f_type} :: res
    {f_type}, intent(in) :: a, b

    res = a + b
end function add
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)

    f_type_map = {'integer': 'int', 'real': 'double'}
    c_routine = codegen(routine)
    assert 'return res;' in c_routine
    assert f'{f_type_map[f_type]} add(' in c_routine


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('codegen', (cgen, cppgen, cudagen))
def test_transpile_multiconditional_simple(tmp_path, builder, frontend, codegen):
    """
    A simple test to verify multiconditionals/select case statements.

    The original Fortran routine and its transpiled C version (called through
    the ISO-C wrapper) must both produce the expected output for every input,
    and the generated code must contain a ``switch`` statement.
    """

    fcode = """
subroutine multi_cond_simple(in, out)
  implicit none
  integer, intent(in) :: in
  integer, intent(inout) :: out

  select case (in)
    case (1)
        out = 10
    case (2)
        out = 20
    case default
        out = 100
  end select

end subroutine multi_cond_simple
""".strip()

    # Input values and the value ``out`` must hold after each call
    test_vals = [0, 1, 2, 5]
    expected_results = [100, 10, 20, 100]
    out_var = np.array(0)

    def check_results(func):
        """Call ``func`` on every input and compare the updated ``out_var``."""
        for in_val, expected in zip(test_vals, expected_results):
            func(in_val, out_var)
            assert out_var == expected

    # Reference: compile and run the original Fortran routine
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'{routine.name}_{frontend!s}.f90'
    check_results(jit_compile(routine, filepath=filepath, objname=routine.name))

    # Transpile to C and generate the Fortran-side ISO-C wrapper
    FortranCTransformation().apply(source=routine, path=tmp_path)
    FortranISOCWrapperTransformation().apply(source=routine, path=tmp_path)

    # The select case must have become a C 'switch'
    assert 'switch' in codegen(routine)

    # Compile the C kernel together with its wrapper and re-run the checks
    libname = f'fc_{routine.name}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=libname, builder=builder
    )
    check_results(c_kernel.multi_cond_simple_fc_mod.multi_cond_simple_fc)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multiconditional(tmp_path, builder, frontend):
    """
    A test to verify multiconditionals/select case statements.

    Covers an open-ended range (``:5``), a value list mixed with a range
    (``6, 7, 10:15``), a single value, a closed range and the default case.
    Both the original Fortran and the transpiled C version are checked
    against the same expected values.
    """

    fcode = """
subroutine multi_cond(in, out)
  implicit none
  integer, intent(in) :: in
  integer, intent(inout) :: out

  select case (in)
    case (:5)
        out = 10
    case (6, 7, 10:15)
        out = 15
    case (8)
        out = 12
    case (20:30)
        out = 20
    case default
        out = 100
  end select

end subroutine multi_cond
""".strip()

    # Pairs of (input value, expected value of ``out`` after the call)
    test_results = [(0, 10), (1, 10), (5, 10), (6, 15), (10, 15), (11, 15),
                    (15, 15), (8, 12), (20, 20), (21, 20), (29, 20), (50, 100)]
    out_var = np.array(0)

    def check_results(func):
        """Call ``func`` for every test input and compare ``out_var``."""
        for in_val, expected in test_results:
            func(in_val, out_var)
            assert out_var == expected

    # compile and test original Fortran version
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'{routine.name}_{frontend!s}.f90'
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    check_results(function)

    # apply F2C trafo
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation()
    f2cwrap.apply(source=routine, path=tmp_path)

    # check whether 'switch' statement is within C code
    assert 'switch' in cgen(routine)

    # compile and test C version
    libname = f'fc_{routine.name}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.multi_cond_fc_mod.multi_cond_fc
    check_results(fc_function)



@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('dtype', ('integer', 'real',))
@pytest.mark.parametrize('add_float', (False, True))
def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_float):
    """
    Test transpilation of the ``mod`` intrinsic.

    With purely integer operands the generated C must use the ``%`` operator.
    With ``real(kind=real64)`` operands, or when a real64 value is added to the
    operands (``add_float``), it must use ``fmod`` instead. Both the Fortran
    and the C version must produce the same results.
    """
    if dtype == 'real':
        decl_type = f'{dtype}(kind=real64)'
        kind = '._real64'
    else:
        decl_type = dtype
        kind = ''

    fcode = f"""
subroutine transpile_special_functions(in, out)
  use iso_fortran_env, only: real64
  implicit none
  {decl_type}, intent(in) :: in
  {decl_type}, intent(inout) :: out
  if (mod(in{'+ 2._real64' if add_float else ''}, 2{kind}{'+ 0._real64' if add_float else ''}) .eq. 0) then
    out = 42{kind}
  else
    out = 11{kind}
  endif
end subroutine transpile_special_functions
""".strip()

    def init_var(dtype, val=0):
        if dtype == 'real':
            return np.array(np.float64(val))
        return np.array(np.int_(val))

    # Inputs and the value ``out`` must hold after each call
    test_vals = [2, 10, 5, 3]
    expected_results = [42, 42, 11, 11]
    out_var = init_var(dtype)

    # compile original Fortran version
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'{routine.name}_{frontend!s}.f90'
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    # test Fortran version
    for in_var, expected in zip(test_vals, expected_results):
        function(in_var, out_var)
        assert out_var == expected

    clean_test(filepath)

    # apply F2C trafo
    f2c = FortranCTransformation()
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation()
    f2cwrap.apply(source=routine, path=tmp_path)

    # check whether correct modulo was inserted
    ccode = cpath(tmp_path, routine).read_text()
    if dtype == 'integer' and not add_float:
        assert '%' in ccode
    if dtype == 'real' or add_float:
        assert 'fmod' in ccode

    # compile C version
    libname = f'fc_{routine.name}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine)],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.transpile_special_functions_fc_mod.transpile_special_functions_fc
    # test C version
    for in_var, expected in zip(test_vals, expected_results):
        fc_function(in_var, out_var)
        assert int(out_var) == expected


@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_interface_to_module(tmp_path, frontend):
    """
    Test that, for a driver, the ``INTERFACE`` of a transpilation target is
    replaced by an import of the generated ``KERNEL_FC_MOD`` wrapper module.

    Only ``kernel`` is a target, so the other two interfaces must remain.
    """
    driver_fcode = """
SUBROUTINE driver_interface_to_module(a, b, c)
  IMPLICIT NONE
  INTERFACE
    SUBROUTINE KERNEL(a, b, c)
      INTEGER, INTENT(INOUT) :: a, b, c
    END SUBROUTINE KERNEL
  END INTERFACE
  INTERFACE
    SUBROUTINE KERNEL2(a, b)
      INTEGER, INTENT(INOUT) :: a, b
    END SUBROUTINE KERNEL2
  END INTERFACE
  INTERFACE
    SUBROUTINE KERNEL3(a)
      INTEGER, INTENT(INOUT) :: a
    END SUBROUTINE KERNEL3
  END INTERFACE

  INTEGER, INTENT(INOUT) :: a, b, c

  CALL kernel(a, b ,c)
  CALL kernel2(a, b)
END SUBROUTINE driver_interface_to_module
    """.strip()

    routine = Subroutine.from_source(driver_fcode, frontend=frontend)

    # Before the transformation: three interface blocks and no imports
    interfaces_before = FindNodes(ir.Interface).visit(routine.spec)
    imports_before = FindNodes(ir.Import).visit(routine.spec)
    assert len(interfaces_before) == 3
    assert not imports_before

    FortranCTransformation().apply(
        source=routine, path=tmp_path, targets=('kernel',), role='driver'
    )

    # After: only the target's interface has been swapped for a module import
    assert len(routine.interfaces) == 2
    new_imports = routine.imports
    assert len(new_imports) == 1
    wrapper_import = new_imports[0]
    assert wrapper_import.module.upper() == 'KERNEL_FC_MOD'
    assert wrapper_import.symbols == ('KERNEL_FC',)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('language', ['c', 'cpp'])
def test_transpile_optional_args(tmp_path, builder, frontend, language):
    """
    Test transpilation of ``OPTIONAL`` arguments and ``PRESENT`` checks.

    The original Fortran routine is exercised with every combination of the
    optional arguments ``out2`` and ``opt_flag``. The transpiled version is
    exercised with both optional arguments present. For C++, calls that
    omit optional arguments are additionally checked.
    """

    fcode = """
subroutine transpile_optional_args(in, out, out2, opt_flag)
  implicit none
  integer, intent(in) :: in
  integer, intent(inout) :: out
  integer, intent(out), optional :: out2
  logical, intent(in), optional :: opt_flag

  out = in
  if (present(out2)) then
    out2 = 2*in
    if (present(opt_flag)) then
        if (opt_flag) then
            out = 2* out2
        else
            out = 4* out2
        endif
    else
        out = out2
    endif
  endif
  if (.not. present(out2) .and. present(opt_flag)) then
    if (opt_flag) then
      out = in + 1
    else
      out = in + 2
    endif
  endif

end subroutine transpile_optional_args
""".strip()

    def init_out_vars():
        return np.array(0), np.array(0)

    # for testing purposes
    in_var = 10

    # compile and test original Fortran version
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'{routine.name}_{frontend!s}.f90'
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    out_var, out_var2 = init_out_vars()
    function(in_var, out_var)
    assert out_var == 10 and out_var2 == 0
    out_var, out_var2 = init_out_vars()
    function(in_var, out_var, out_var2)
    assert out_var == 20 and out_var2 == 20
    opt_flag = 1
    out_var, out_var2 = init_out_vars()
    function(in_var, out_var, opt_flag=opt_flag)
    assert out_var == 11 and out_var2 == 0
    opt_flag = 0
    out_var, out_var2 = init_out_vars()
    function(in_var, out_var, opt_flag=opt_flag)
    assert out_var == 12 and out_var2 == 0
    opt_flag = 1
    out_var, out_var2 = init_out_vars()
    function(in_var, out_var, out_var2, opt_flag)
    assert out_var == 40 and out_var2 == 20
    opt_flag = 0
    out_var, out_var2 = init_out_vars()
    function(in_var, out_var, out_var2, opt_flag)
    assert out_var == 80 and out_var2 == 20

    clean_test(filepath)

    # transpile
    f2c = FortranCTransformation(language=language)
    f2c.apply(source=routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation(language=language)
    f2cwrap.apply(source=routine, path=tmp_path)

    # compile and test C/C++ version
    libname = f'fc_{routine.name}_{language}_{frontend}'
    c_kernel = jit_compile_lib(
        [wrapperpath(tmp_path, routine), cpath(tmp_path, routine, suffix=f'.{language}')],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.transpile_optional_args_fc_mod.transpile_optional_args_fc
    if language != 'c':
        # Calls omitting optional arguments are only checked for C++
        out_var, out_var2 = init_out_vars()
        fc_function(in_var, out_var)
        assert out_var == 10 and out_var2 == 0
        opt_flag = 1
        out_var, out_var2 = init_out_vars()
        fc_function(in_var, out_var, opt_flag=opt_flag)
        assert out_var == 11 and out_var2 == 0
        opt_flag = 0
        out_var, out_var2 = init_out_vars()
        fc_function(in_var, out_var, opt_flag=opt_flag)
        assert out_var == 12 and out_var2 == 0
    opt_flag = 1
    out_var, out_var2 = init_out_vars()
    fc_function(in_var, out_var, out_var2, opt_flag)
    assert out_var == 40 and out_var2 == 20
    opt_flag = 0
    out_var, out_var2 = init_out_vars()
    fc_function(in_var, out_var, out_var2, opt_flag)
    assert out_var == 80 and out_var2 == 20
loki-ecmwf-0.3.6/loki/transformations/transpile/fortran_c.py0000664000175000017500000004005515167130205024464 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path

from loki.backend import cgen, cudagen, cppgen
from loki.batch import Transformation
from loki.expression import (
    symbols as sym, Variable, InlineCall, Scalar, Array,
    ProcedureSymbol, Dereference, Reference, ExpressionRetriever,
    SubstituteExpressionsMapper
)
from loki.ir import (
    Import, Intrinsic, Interface, CallStatement, Assignment,
    Transformer, FindNodes, Comment, SubstituteExpressions,
    FindInlineCalls
)
from loki.logging import debug
from loki.sourcefile import Sourcefile
from loki.tools import as_tuple, flatten
from loki.types import BasicType, DerivedType

from loki.transformations.array_indexing import (
    shift_to_zero_indexing, invert_array_indices,
    resolve_vector_notation, normalize_array_shape_and_access,
    flatten_arrays
)
from loki.transformations.inline import (
    inline_constant_parameters, inline_elemental_functions
)
from loki.transformations.sanitise import do_resolve_associates
from loki.transformations.utilities import (
    convert_to_lower_case, replace_intrinsics, sanitise_imports
)


__all__ = ['FortranCTransformation']


class DeReferenceTrafo(Transformer):
    """
    Transformation to apply/insert Dereference = `*` and
    Reference/*address-of* = `&` operators.

    In expressions, every symbol accepted by :meth:`is_dereference` whose
    lower-case name is listed in ``vars2dereference`` is wrapped in a
    :any:`Dereference`. In calls to enriched routines, an actual argument is
    wrapped in a :any:`Reference` when it is bound to a dummy argument that is
    not ``intent(in)``. For arguments rejected by :meth:`is_dereference`
    (e.g. subscripted arrays), this also happens when the dummy argument is an
    array.

    Parameters
    ----------
    vars2dereference : list
        Lower-case names of the variables to be dereferenced. Usually the
        arguments except for scalars with `intent=in`.
    """
    # pylint: disable=unused-argument

    def __init__(self, vars2dereference):
        super().__init__()
        self.retriever = ExpressionRetriever(self.is_dereference)
        self.vars2dereference = vars2dereference

    @staticmethod
    def is_dereference(symbol):
        """
        Whether ``symbol`` is a candidate for (de)referencing.

        Accepts :any:`Scalar` and :any:`Array` symbols (and ``DerivedType``
        instances). Arrays are accepted only without subscripts or when every
        subscript is a full ``:`` range.
        """
        return isinstance(symbol, (DerivedType, Array, Scalar)) and not (
            isinstance(symbol, Array) and symbol.dimensions is not None
            and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions)
        )

    @staticmethod
    def _is_intent_in(var):
        """
        Whether the (dummy) argument ``var`` is declared ``intent(in)``.

        Fortran allows dummy arguments without an ``INTENT`` attribute, for
        which ``intent`` is unset. They are treated as not ``intent(in)``
        instead of failing on ``None.lower()``.
        """
        return (var.type.intent or '').lower() == 'in'

    def visit_Expression(self, o, **kwargs):
        symbol_map = {
            symbol: Dereference(symbol.clone()) for symbol in self.retriever.retrieve(o)
            if symbol.name.lower() in self.vars2dereference
        }
        return SubstituteExpressionsMapper(symbol_map)(o)

    def visit_CallStatement(self, o, **kwargs):
        new_args = ()
        if o.routine is BasicType.DEFERRED:
            debug(f'DeReferenceTrafo: Skipping call to {o.name!s} due to missing procedure enrichment')
            return o
        # Invert the dummy -> actual argument map to find the dummy argument
        # each actual argument is bound to
        call_arg_map = dict((v,k) for k,v in o.arg_map.items())
        for arg in o.arguments:
            if not self.is_dereference(arg) and (isinstance(call_arg_map[arg], Array)
                    or not self._is_intent_in(call_arg_map[arg])):
                new_args += (Reference(arg.clone()),)
            elif isinstance(arg, Scalar) and not self._is_intent_in(call_arg_map[arg]):
                new_args += (Reference(arg.clone()),)
            else:
                new_args += (arg,)
        o._update(arguments=new_args)
        return o


class FortranCTransformation(Transformation):
    """
    Fortran-to-C transformation that translates the given routine into C.

    With ``role='driver'`` the routine stays in Fortran. Any ``INTERFACE``
    block declaring one of the ``targets`` is replaced by an import of the
    corresponding ``<NAME>_FC_MOD`` wrapper module. With ``role='kernel'``
    a C-style copy of the routine is generated and written, together with a
    ``.h`` header, into the output directory.

    Parameters
    ----------
    inline_elementals : bool, optional
        Inline known elemental function via expression substitution. Default is ``True``.
    language : str
        C-style language to generate; should be one of ``['c', 'cpp', 'cuda']``.

    Raises
    ------
    ValueError
        If ``language`` is not one of the supported languages.
    """
    # pylint: disable=unused-argument

    # Set of standard module names that have no C equivalent
    __fortran_intrinsic_modules = ['ISO_FORTRAN_ENV', 'ISO_C_BINDING']

    def __init__(self, inline_elementals=True, language='c'):
        self.inline_elementals = inline_elementals
        self.language = language.lower()
        self._supported_languages = ['c', 'cpp', 'cuda']

        if self.language == 'c':
            self.codegen = cgen
        elif self.language == 'cpp':
            self.codegen = cppgen
        elif self.language == 'cuda':
            self.codegen = cudagen
        else:
            raise ValueError(f'language "{self.language}" is not supported!'
                             f' (supported languages: "{self._supported_languages}")')

    def file_suffix(self):
        """
        Suffix of the generated kernel source file: ``'.cpp'`` for C++,
        ``'.c'`` for every other language (including ``'cuda'``).
        """
        if self.language == 'cpp':
            return '.cpp'
        return '.c'

    def transform_subroutine(self, routine, **kwargs):
        """
        Transform ``routine`` according to its ``role``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine to transform.
        path : str or Path, optional
            Output directory. If not given, ``build_args['output_dir']`` is used.
        role : str, optional
            ``'driver'`` or ``'kernel'`` (default).
        item, depths, targets, sub_sgraph : optional
            Scheduler information. ``depths[item]`` gives the call depth. When
            ``depths`` is not given, drivers have depth 0 and kernels depth 1.
            Successors of ``item`` in ``sub_sgraph`` get their C header included.
        """
        # Normalise to Path so that ``path/name`` also works for str paths
        if 'path' in kwargs:
            path = Path(kwargs.get('path'))
        else:
            build_args = kwargs.get('build_args')
            path = Path(build_args.get('output_dir'))

        role = kwargs.get('role', 'kernel')
        item = kwargs.get('item', None)
        depths = kwargs.get('depths', None)
        targets = kwargs.get('targets', None)
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()

        depth = 0
        if depths is None:
            if role == 'driver':
                depth = 0
            elif role == 'kernel':
                depth = 1
        else:
            depth = depths[item]

        if role == 'driver':
            self.interface_to_import(routine, targets)
            return

        # for calls and inline calls: convert kwarguments to arguments
        self.convert_kwargs_to_args(routine, targets)

        if role == 'kernel':
            # Generate C source file from Loki IR
            c_kernel = self.generate_c_kernel(routine, targets=targets)

            for successor in successors:
                c_kernel.spec.prepend(Import(module=f'{successor.ir.name.lower()}_c.h', c_import=True))

            if depth == 1:
                if self.language != 'c':
                    c_kernel_launch = c_kernel.clone(name=f"{c_kernel.name}_launch", prefix="extern_c")
                    self.generate_c_kernel_launch(c_kernel_launch, c_kernel)
                    c_path = (path/c_kernel_launch.name.lower()).with_suffix('.h')
                    Sourcefile.to_file(source=self.codegen(c_kernel_launch, extern=True), path=c_path)

            # Drop launch-configuration assignments from the kernel body
            assignments = FindNodes(Assignment).visit(c_kernel.body)
            assignments2remove = ['griddim', 'blockdim']
            assignment_map = {assignment: None for assignment in assignments
                    if assignment.lhs.name.lower() in assignments2remove}
            c_kernel.body = Transformer(assignment_map).visit(c_kernel.body)

            if depth > 1:
                c_kernel.spec.prepend(Import(module=f'{c_kernel.name.lower()}.h', c_import=True))
            c_path = (path/c_kernel.name.lower()).with_suffix(self.file_suffix())
            Sourcefile.to_file(source=self.codegen(c_kernel, extern=self.language=='cpp'), path=c_path)
            header_path = (path/c_kernel.name.lower()).with_suffix('.h')
            Sourcefile.to_file(source=self.codegen(c_kernel, header=True), path=header_path)

    def convert_kwargs_to_args(self, routine, targets):
        """
        Convert keyword arguments to positional arguments in all calls and
        (enriched) inline calls to routines named in ``targets``.
        """
        # calls (to subroutines)
        for call in as_tuple(FindNodes(CallStatement).visit(routine.body)):
            if str(call.name).lower() in as_tuple(targets):
                call.convert_kwargs_to_args()
        # inline calls (to functions)
        inline_call_map = {}
        for inline_call in as_tuple(FindInlineCalls().visit(routine.body)):
            if str(inline_call.name).lower() in as_tuple(targets) and inline_call.routine is not BasicType.DEFERRED:
                inline_call_map[inline_call] = inline_call.clone_with_kwargs_as_args()
        if inline_call_map:
            routine.body = SubstituteExpressions(inline_call_map).visit(routine.body)

    def interface_to_import(self, routine, targets):
        """
        Convert interface to import.

        Every ``INTERFACE`` block declaring one of ``targets`` is removed and
        an import of ``<NAME>_FC`` from module ``<NAME>_FC_MOD`` is prepended
        to the spec. Keyword arguments in calls to ``targets`` are converted
        to positional arguments.
        """
        for call in FindNodes(CallStatement).visit(routine.body):
            if str(call.name).lower() in as_tuple(targets):
                call.convert_kwargs_to_args()
        intfs = FindNodes(Interface).visit(routine.spec)
        removal_map = {}
        for i in intfs:
            for s in i.symbols:
                if targets and s in targets:
                    # Create a new module import with explicitly qualified symbol
                    new_symbol = s.clone(name=f'{s.name}_FC', scope=routine)
                    modname = f'{new_symbol.name}_MOD'
                    new_import = Import(module=modname, c_import=False, symbols=(new_symbol,))
                    routine.spec.prepend(new_import)
                    # Mark current interface for removal
                    removal_map[i] = None
        # Apply any scheduled interface removals to spec
        if removal_map:
            routine.spec = Transformer(removal_map).visit(routine.spec)

    @staticmethod
    def apply_de_reference(routine):
        """
        Utility method to apply/insert Dereference = `*` and
        Reference/*address-of* = `&` operators.

        All arguments except non-optional ``intent(in)`` scalars are
        dereferenced. Arguments without an ``INTENT`` attribute count as not
        ``intent(in)``.
        """
        to_be_dereferenced = []
        for arg in routine.arguments:
            # ``intent`` is unset for dummy arguments declared without INTENT
            is_intent_in = (arg.type.intent or '').lower() == 'in'
            if not(is_intent_in and isinstance(arg, Scalar)) or arg.type.optional:
                to_be_dereferenced.append(arg.name.lower())

        routine.body = DeReferenceTrafo(to_be_dereferenced).visit(routine.body)

    def generate_c_kernel(self, routine, targets, **kwargs):
        """
        Re-generate the C kernel and insert wrapper-specific peculiarities,
        such as the explicit getter calls for imported module-level variables.

        Returns a renamed (``<name>_c``) clone of ``routine``; ``routine``
        itself is not modified.
        """

        # CAUTION! Work with a copy of the original routine to not break the
        #  dependency graph of the Scheduler through the rename
        kernel = routine.clone()
        kernel.name = f'{kernel.name.lower()}_c'

        # Clean up Fortran vector notation
        resolve_vector_notation(kernel)
        normalize_array_shape_and_access(kernel)

        # Convert array indexing to C conventions
        # TODO: Resolve reductions (eg. SUM(myvar(:)))
        invert_array_indices(kernel)
        shift_to_zero_indexing(kernel, ignore=() if self.language == 'c' else ('jl', 'ibl'))
        flatten_arrays(kernel, order='C', start_index=0)

        # Inline all known parameters, since they can be used in declarations,
        # and thus need to be known before we can fetch them via getters.
        inline_constant_parameters(kernel, external_only=True)

        if self.inline_elementals:
            # Inline known elemental function via expression substitution
            inline_elemental_functions(kernel)

        # Create declarations for module variables
        if self.language == 'c':
            # Group imported scalars by module. Several imports of the same
            # module are merged instead of the last one overwriting the others.
            module_variables = {}
            for im in kernel.imports:
                module_variables.setdefault(im.module.lower(), []).extend(
                    s.clone(scope=kernel, type=s.type.clone(imported=None, module=None)) for s in im.symbols
                    if isinstance(s, Scalar) and s.type.dtype is not BasicType.DEFERRED and not s.type.parameter
                )
            kernel.variables += as_tuple(flatten(list(module_variables.values())))

            # Create calls to getter routines for module variables
            getter_calls = []
            for module, variables in module_variables.items():
                for var in variables:
                    getter = f'{module}__get__{var.name.lower()}'
                    vget = Assignment(lhs=var, rhs=InlineCall(ProcedureSymbol(getter, scope=var.scope)))
                    getter_calls += [vget]
            kernel.body.prepend(getter_calls)

            # Change imports to C header includes
            import_map = {}
            for im in kernel.imports:
                if str(im.module).upper() in self.__fortran_intrinsic_modules:
                    # Remove imports of Fortran intrinsic modules
                    import_map[im] = None

                elif not im.c_import and im.symbols:
                    # Create a C-header import for any converted modules
                    import_map[im] = im.clone(module=f'{im.module.lower()}_c.h', c_import=True, symbols=())

                else:
                    # Remove other imports, as they might include untreated Fortran code
                    import_map[im] = None
            kernel.spec = Transformer(import_map).visit(kernel.spec)

        # Remove intrinsics from spec (eg. implicit none)
        intrinsic_map = {i: None for i in FindNodes(Intrinsic).visit(kernel.spec)
                         if 'implicit' in i.text.lower()}
        kernel.spec = Transformer(intrinsic_map).visit(kernel.spec)

        # Resolve implicit struct mappings through "associates"
        do_resolve_associates(kernel)

        # Force all variables to lower-caps, as C/C++ is case-sensitive
        convert_to_lower_case(kernel)

        # Force pointer on reference-passed arguments (and lower case type names for derived types)
        for arg in kernel.arguments:
            # ``intent`` is unset for dummy arguments declared without INTENT
            is_intent_in = (arg.type.intent or '').lower() == 'in'
            if not(is_intent_in and isinstance(arg, Scalar)):
                _type = arg.type.clone(pointer=True)
                if isinstance(arg.type.dtype, DerivedType):
                    # Lower case type names for derived types
                    typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower())
                    _type = _type.clone(dtype=typedef.dtype)
                kernel.symbol_attrs[arg.name] = _type

        # apply dereference and reference where necessary
        self.apply_de_reference(kernel)

        # adapt call and inline call names -> '_c'
        self.convert_call_names(kernel, targets)

        symbol_map = {'epsilon': 'DBL_EPSILON'}
        function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs',
                        'exp': 'exp', 'sqrt': 'sqrt', 'sign': 'copysign'}
        replace_intrinsics(kernel, symbol_map=symbol_map, function_map=function_map)

        # Remove redundant imports
        sanitise_imports(kernel)

        return kernel

    def convert_call_names(self, routine, targets):
        """
        Append ``_c`` to the names of called routines.

        Subroutine calls are renamed (and lower-cased) only if listed in
        ``targets``. Enriched inline calls are renamed if listed in
        ``targets``, or unconditionally when ``targets`` is ``None``.
        """
        # calls (to subroutines)
        calls = FindNodes(CallStatement).visit(routine.body)
        for call in calls:
            if call.name not in as_tuple(targets):
                continue
            call._update(name=Variable(name=f'{call.name}_c'.lower()))
        # inline calls (to functions)
        callmap = {}
        for call in FindInlineCalls(unique=False).visit(routine.body):
            if call.routine is not BasicType.DEFERRED and (targets is None or call.name in as_tuple(targets)):
                callmap[call.function] = call.function.clone(name=f'{call.name}_c')
        routine.body = SubstituteExpressions(callmap).visit(routine.body)

    def generate_c_kernel_launch(self, kernel_launch, kernel, **kwargs):
        """
        Turn ``kernel_launch`` into a launcher routine for ``kernel``.

        All imports are removed from ``kernel_launch``. Its body is replaced by
        a comment, the ``griddim``/``blockdim`` assignments found in it (``None``
        if absent), and a call to ``kernel`` with the chevron launch
        configuration ``<<<griddim, blockdim>>>``.
        """
        import_map = {}
        for im in FindNodes(Import).visit(kernel_launch.spec):
            import_map[im] = None
        kernel_launch.spec = Transformer(import_map).visit(kernel_launch.spec)

        kernel_call = kernel.clone()
        call_arguments = list(kernel_call.arguments)

        griddim = None
        blockdim = None
        if 'griddim' in kernel_launch.variable_map:
            griddim = kernel_launch.variable_map['griddim']
        if 'blockdim' in kernel_launch.variable_map:
            blockdim = kernel_launch.variable_map['blockdim']
        assignments = FindNodes(Assignment).visit(kernel_launch.body)
        griddim_assignment = None
        blockdim_assignment = None
        for assignment in assignments:
            if assignment.lhs == griddim:
                griddim_assignment = assignment.clone()
            if assignment.lhs == blockdim:
                blockdim_assignment = assignment.clone()
        kernel_launch.body = (Comment(text="! here should be the launcher ...."),
                griddim_assignment, blockdim_assignment, CallStatement(name=Variable(name=kernel.name),
                    arguments=call_arguments, chevron=(sym.Variable(name="griddim"),
                        sym.Variable(name="blockdim"))))
loki-ecmwf-0.3.6/loki/transformations/transpile/fortran_python.py0000664000175000017500000001034115167130205025556 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path

from loki.backend import pygen, dacegen
from loki.batch import Transformation
from loki.expression import symbols as sym
from loki.ir import (
    nodes as ir, FindNodes, Transformer, pragmas_attached,
    FindInlineCalls, SubstituteExpressions
)
from loki.sourcefile import Sourcefile

from loki.transformations.array_indexing import (
    shift_to_zero_indexing, invert_array_indices
)
from loki.transformations.sanitise import do_resolve_associates
from loki.transformations.utilities import (
    convert_to_lower_case, replace_intrinsics
)


__all__ = ['FortranPythonTransformation']


class FortranPythonTransformation(Transformation):
    """
    A transformer class to convert Fortran to Python or DaCe.

    This :any:`Transformation` will generate Python code from a
    given Fortran routine, and if configured, annotate it with DaCe
    dataflow pragmas.

    Parameters
    ----------
    with_dace : bool
        Generate DaCe-specific Python code via :any:`dacegen` backend.
        This option implies inverted array indexing; default: ``False``
    invert_indices : bool
        Switch to C-style indexing (row-major) with fastest loop
        indices being used rightmost; default: ``False``
    suffix : str
        Optional suffix to append to converted routine names.
    """

    def __init__(self, **kwargs):
        self.with_dace = kwargs.pop('with_dace', False)
        self.invert_indices = kwargs.pop('invert_indices', False)
        self.suffix = kwargs.pop('suffix', '')

    def transform_subroutine(self, routine, **kwargs):
        """
        Convert ``routine`` in-place into Python-compatible IR and write the
        generated Python (or DaCe) source to ``<path>/<routine name>.py``.

        The output directory is taken from the ``path`` keyword argument or,
        if that is not given, from ``build_args['output_dir']``. After the
        call, ``self.py_path`` and ``self.mod_name`` hold the path and the
        module name of the generated file.
        """
        path = kwargs.get('path')
        if path is None:
            # Same fallback as FortranISOCWrapperTransformation
            build_args = kwargs.get('build_args') or {}
            path = build_args.get('output_dir')
        path = Path(path)

        # Rename subroutine to generate Python kernel
        routine.name = f'{routine.name}{self.suffix}'.lower()

        # Remove all "IMPLICIT" intrinsic statements
        mapper = {
            i: None for i in FindNodes(ir.Intrinsic).visit(routine.spec)
            if 'implicit' in i.text.lower()
        }
        routine.spec = Transformer(mapper).visit(routine.spec)

        # Force all variables to lower-case, as Python is case-sensitive
        convert_to_lower_case(routine)

        # Resolve implicit struct mappings through "associates"
        do_resolve_associates(routine)

        # Do some vector and indexing transformations
        if self.with_dace or self.invert_indices:
            invert_array_indices(routine)
        shift_to_zero_indexing(routine)

        # We replace calls to intrinsic functions with their Python counterparts
        # Note that this substitution is case-insensitive, and therefore we have
        # this seemingly identity mapping to make sure Python function names are
        # lower-case
        intrinsic_map = {
            'min': 'min', 'max': 'max', 'abs': 'abs',
            'exp': 'np.exp', 'sqrt': 'np.sqrt',
        }
        replace_intrinsics(routine, function_map=intrinsic_map)

        # Fortran SIGN(A, B) returns the magnitude of A with the sign of B,
        # so it maps to abs(A) * np.sign(B). Multiplying A itself (rather than
        # abs(A)) would give the wrong sign whenever A is negative.
        # NOTE: np.sign(0) == 0, so B == 0 yields 0 here, whereas Fortran
        # returns |A| in that case.
        sign_map = {}
        for c in FindInlineCalls(unique=False).visit(routine.ir):
            if c.function == 'sign':
                assert len(c.parameters) == 2
                magnitude = sym.InlineCall(
                    function=sym.ProcedureSymbol(name='abs', scope=routine),
                    parameters=(c.parameters[0],)
                )
                sign = sym.InlineCall(
                    function=sym.ProcedureSymbol(name='np.sign', scope=routine),
                    parameters=(c.parameters[1],)
                )
                sign_map[c] = sym.Product((magnitude, sign))

        routine.spec = SubstituteExpressions(sign_map).visit(routine.spec)
        routine.body = SubstituteExpressions(sign_map).visit(routine.body)

        # Determine output file and module name of the generated Python kernel
        self.py_path = (path/routine.name.lower()).with_suffix('.py')
        self.mod_name = routine.name.lower()
        # Need to attach Loop pragmas to honour dataflow pragmas for loops
        with pragmas_attached(routine, ir.Loop):
            source = dacegen(routine) if self.with_dace else pygen(routine)
        Sourcefile.to_file(source=source, path=self.py_path)
loki-ecmwf-0.3.6/loki/transformations/transpile/fortran_iso_c_wrapper.py0000664000175000017500000005131615167130205027100 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
from collections import OrderedDict

from loki.backend import cgen, fgen, cudagen, cppgen
from loki.batch import Transformation, ProcedureItem, ModuleItem
from loki.expression import symbols as sym
from loki.ir import (
    nodes as ir, FindNodes, SubstituteExpressions, Transformer
)
from loki.function import Function
from loki.module import Module
from loki.sourcefile import Sourcefile
from loki.subroutine import Subroutine
from loki.types import BasicType, DerivedType, SymbolAttributes
from loki.tools import as_tuple

from loki.transformations.utilities import sanitise_imports


__all__ = [
    'c_intrinsic_kind', 'iso_c_intrinsic_import',
    'iso_c_intrinsic_kind', 'c_struct_typedef',
    'generate_iso_c_interface', 'generate_iso_c_wrapper_routine',
    'generate_iso_c_wrapper_module', 'generate_c_header',
    'FortranISOCWrapperTransformation'
]


class FortranISOCWrapperTransformation(Transformation):
    """
    Wrapper transformation that generates ISO-C Fortran wrappers and C
    headers for translated kernels or additional header modules.

    In addition to :any:`Subroutine` objects with the role
    ``'kernel'``, this transformation will process :any:`Module`
    objects with the role ``'header'``. This will generate ISO-C
    wrappers for derived types and the corresponding C-compatible
    structs in C header files.

    Parameters
    ----------
    use_c_ptr : bool, optional
        Use ``c_ptr`` for array declarations and ``c_loc(...)`` to
        pass the corresponding argument. Default is ``False``.
    language : string
        Actual C-style language to generate; must be one of ``'c'``,
        ``'cpp'`` or ``'cuda'`` for C, C++ and CUDA respectively.

    Raises
    ------
    ValueError
        If ``language`` is not one of the supported languages.
    """

    item_filter = (ProcedureItem, ModuleItem)

    _supported_languages = ['c', 'cpp', 'cuda']

    def __init__(self, use_c_ptr=False, language='c'):
        self.use_c_ptr = use_c_ptr
        self.language = language.lower()

        if self.language == 'c':
            self.codegen = cgen
        elif self.language == 'cpp':
            self.codegen = cppgen
        elif self.language == 'cuda':
            self.codegen = cudagen
        else:
            raise ValueError(f'language "{self.language}" is not supported!'
                             f' (supported languages: "{self._supported_languages}")')

    @staticmethod
    def _get_output_path(kwargs):
        """
        Determine the directory into which generated files are written.

        Uses the ``path`` keyword argument if present, otherwise the
        ``output_dir`` entry of ``build_args``. The result is always a
        :class:`pathlib.Path`, so that a plain string path also works with
        the ``/`` joins used by the transformation methods.
        """
        if 'path' in kwargs:
            return Path(kwargs.get('path'))
        build_args = kwargs.get('build_args')
        return Path(build_args.get('output_dir'))

    def transform_module(self, module, **kwargs):
        """
        For modules with role ``'header'``, write the ISO-C Fortran wrapper
        module (``.F90``) and the corresponding C header (``.h``).
        """
        path = self._get_output_path(kwargs)

        role = kwargs.get('role', 'kernel')

        if role == 'header':
            # Generate Fortran wrapper module
            wrapper = generate_iso_c_wrapper_module(
                module, use_c_ptr=self.use_c_ptr, language=self.language
            )
            wrapperpath = (path/wrapper.name.lower()).with_suffix('.F90')
            Sourcefile.to_file(source=fgen(wrapper), path=wrapperpath)

            # Generate C header file from module
            c_header = generate_c_header(module)
            c_path = (path/c_header.name.lower()).with_suffix('.h')
            Sourcefile.to_file(source=self.codegen(c_header), path=c_path)

    def transform_subroutine(self, routine, **kwargs):
        """
        For routines with role ``'kernel'``, write a Fortran module
        containing the ISO-C wrapper routine for the translated kernel.
        """
        path = self._get_output_path(kwargs)

        role = kwargs.get('role', 'kernel')

        if role == 'kernel':
            # C-compatible struct definitions for all derived-type arguments
            c_structs = {}
            for arg in routine.arguments:
                if isinstance(arg.type.dtype, DerivedType):
                    c_structs[arg.type.dtype.name.lower()] = c_struct_typedef(arg.type, use_c_ptr=self.use_c_ptr)

            # Generate Fortran wrapper module; non-C languages bind to a launcher
            bind_name = None if self.language in ['c', 'cpp'] else f'{routine.name.lower()}_c_launch'
            wrapper = generate_iso_c_wrapper_routine(
                routine, c_structs, bind_name=bind_name,
                use_c_ptr=self.use_c_ptr, language=self.language
            )
            contains = ir.Section(body=(ir.Intrinsic('CONTAINS'), wrapper))
            wrapperpath = (path/wrapper.name.lower()).with_suffix('.F90')
            module = Module(name=f'{wrapper.name.upper()}_MOD', contains=contains)
            module.spec = ir.Section(body=(ir.Import(module='iso_c_binding'),))
            Sourcefile.to_file(source=fgen(module), path=wrapperpath)


def c_intrinsic_kind(_type, scope):
    """
    Map a Fortran symbol type onto the name of the matching intrinsic C type.

    Parameters
    ----------
    _type : :any:`SymbolAttr`
        Type attributes of the symbol; ``dtype`` and, for reals, ``kind``
        are inspected
    scope : :any:`Scope`
        Scope in which the returned type symbol is created

    Returns
    -------
    :any:`Variable` or None
        ``int`` for logicals and integers, ``float``/``double`` for reals
        of a recognised kind, and ``None`` for everything else.
    """
    if _type.dtype == BasicType.LOGICAL:
        # Fortran logicals are exposed as plain C ints
        return sym.Variable(name='int', scope=scope)
    if _type.dtype == BasicType.INTEGER:
        return sym.Variable(name='int', scope=scope)
    if _type.dtype == BasicType.REAL:
        real_kind_to_c = {
            'real32': 'float', 'c_float': 'float',
            'real64': 'double', 'jprb': 'double',
            'selected_real_kind(13, 300)': 'double', 'c_double': 'double',
        }
        c_name = real_kind_to_c.get(str(_type.kind).lower())
        if c_name is not None:
            return sym.Variable(name=c_name, scope=scope)
    return None


def iso_c_intrinsic_import(scope, use_c_ptr=False):
    """
    Build a ``use iso_c_binding, only: ...`` :any:`Import` node for the
    intrinsic C base kinds.

    Parameters
    ----------
    scope : :any:`Scope`
        Scope in which the imported type symbols are created.
    use_c_ptr : bool, optional
        Additionally import ``c_ptr`` and ``c_loc`` (needed when arrays are
        passed via ``c_loc(...)``). Default is ``False``.
    """
    names = ('c_int', 'c_double', 'c_float')
    if use_c_ptr:
        names = names + ('c_ptr', 'c_loc')
    imported = tuple(sym.Variable(name=n, scope=scope) for n in names)
    return ir.Import(module='iso_c_binding', symbols=imported)


def iso_c_intrinsic_kind(_type, scope, is_array=False, use_c_ptr=False):
    """
    Map a Fortran symbol type onto the matching ``iso_c_binding`` kind.

    Parameters
    ----------
    _type : :any:`SymbolAttr`
        Type attributes of the symbol; ``dtype`` and, for reals, ``kind``
        are inspected
    scope : :any:`Scope`
        Scope in which the returned kind symbol is created
    is_array : bool
        Flag indicating if the passed type belongs to an array symbol.
    use_c_ptr : bool, optional
        Use ``c_ptr`` for array declarations and ``c_loc(...)`` to
        pass the corresponding argument. Default is ``False``.

    Returns
    -------
    :any:`Variable` or None
        ``c_int`` for integers, ``c_float`` or ``c_double`` (or ``c_ptr`` for
        double-precision arrays when ``use_c_ptr`` is set) for reals of a
        recognised kind, and ``None`` otherwise.
    """
    c_kind = None
    if _type.dtype == BasicType.INTEGER:
        c_kind = 'c_int'
    elif _type.dtype == BasicType.REAL:
        kind_name = str(_type.kind).lower()
        if kind_name in ('real32', 'c_float'):
            c_kind = 'c_float'
        elif kind_name in ('real64', 'jprb', 'selected_real_kind(13, 300)', 'c_double', 'c_ptr'):
            c_kind = 'c_ptr' if (use_c_ptr and is_array) else 'c_double'

    if c_kind is None:
        return None
    return sym.Variable(name=c_kind, scope=scope)


def c_struct_typedef(derived, use_c_ptr=False):
    """
    Create the ``bind(c)`` :class:`TypeDef` for the C-wrapped struct definition.

    The new type is named ``<typename>_c`` (lower-case). Each member is
    re-declared with a lower-case name and its kind replaced by the matching
    ``iso_c_binding`` kind.

    Parameters
    ----------
    derived : :any:`TypeDef` or :any:`SymbolAttr`
        Either the Fortran type definition itself, or a type attribute whose
        ``dtype`` refers to the derived type.
    use_c_ptr : bool, optional
        Use ``c_ptr`` for array declarations and ``c_loc(...)`` to
        pass the corresponding argument. Default is ``False``.
    """
    is_typedef = isinstance(derived, ir.TypeDef)
    source_name = derived.name if is_typedef else derived.dtype.name
    typedef = ir.TypeDef(name=f'{source_name}_c'.lower(), body=(), bind_c=True)  # pylint: disable=unexpected-keyword-arg
    members = derived.variables if is_typedef else derived.dtype.typedef.variables

    def _c_declaration(member):
        # Re-declare one member inside the new typedef with an ISO-C kind
        iso_kind = iso_c_intrinsic_kind(member.type, typedef, use_c_ptr=use_c_ptr)
        c_member = member.clone(
            name=member.basename.lower(), scope=typedef, type=member.type.clone(kind=iso_kind)
        )
        return ir.VariableDeclaration(symbols=(c_member,))

    typedef._update(body=tuple(_c_declaration(m) for m in members))
    return typedef


def generate_iso_c_interface(routine, bind_name, c_structs, scope, use_c_ptr=False, language='c'):
    """
    Generate the ISO-C subroutine :any:`Interface` object for a given :any:`Subroutine`.

    Scalar arguments with ``intent(in)`` that are not optional are passed by
    value; all other arguments (including those without a declared intent)
    are passed by reference.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine for which to generate the interface
    bind_name : str
        Name of the C-function to which this interface corresponds.
    c_structs : dict of str to :any:`TypeDef`
        Map from lower-case Fortran derived type name to the C-compatible
        struct type definition
    scope : :any:`Scope`
        Parent scope in which to create the :any:`Interface`
    use_c_ptr : bool, optional
        Use ``c_ptr`` for array declarations and ``c_loc(...)`` to
        pass the corresponding argument. Default is ``False``.
    language : string
        C-style language to generate; if this is ``'c'``, we resolve
        non-C imports.
    """
    intf_name = f'{routine.name}_iso_c'
    intf_routine = Subroutine(name=intf_name, body=None, args=(), parent=scope, bind=bind_name)
    intf_spec = ir.Section(
        body=as_tuple(iso_c_intrinsic_import(intf_routine, use_c_ptr=use_c_ptr))
    )
    if language == 'c':
        # Carry over the routine's non-C imports into the interface
        for im in FindNodes(ir.Import).visit(routine.spec):
            if not im.c_import:
                im_symbols = tuple(s.clone(scope=intf_routine) for s in im.symbols)
                intf_spec.append(im.clone(symbols=im_symbols))
    intf_spec.append(ir.Intrinsic(text='implicit none'))
    # Make the C-compatible struct definitions available in the interface
    intf_spec.append(c_structs.values())
    intf_routine.spec = intf_spec

    # Generate variables and types for argument declarations
    for arg in routine.arguments:
        is_array = isinstance(arg, sym.Array)
        if isinstance(arg.type.dtype, DerivedType):
            struct_name = c_structs[arg.type.dtype.name.lower()].name
            ctype = SymbolAttributes(DerivedType(name=struct_name), shape=arg.type.shape)
        else:
            kind = iso_c_intrinsic_kind(arg.type, intf_routine, is_array=is_array)
            if use_c_ptr and is_array:
                # Arrays are passed as the address obtained via c_loc(...)
                ctype = SymbolAttributes(DerivedType(name="c_ptr"), value=True, kind=None)
            else:
                # Only scalar, intent(in) arguments are pass by value; pass by
                # reference for array types. The intent may be undeclared
                # (None), which must not crash and counts as "not intent(in)".
                intent = (arg.type.intent or '').lower()
                value = isinstance(arg, sym.Scalar) and intent == 'in' and not arg.type.optional
                ctype = SymbolAttributes(arg.type.dtype, value=value, kind=kind)

        # With c_ptr, arrays are opaque pointers and carry no dimensions
        dimensions = arg.dimensions if (is_array and not use_c_ptr) else None
        var = sym.Variable(name=arg.name, dimensions=dimensions, type=ctype, scope=intf_routine)
        intf_routine.variables += (var,)
        intf_routine.arguments += (var,)

    sanitise_imports(intf_routine)

    return ir.Interface(body=(intf_routine, ))


def generate_iso_c_wrapper_routine(routine, c_structs, bind_name=None, use_c_ptr=False, language='c'):
    """
    Generate Fortran ISO-C wrapper :any:`Subroutine` that corresponds
    to a transpiled C method.

    The new wrapper subroutine will have the suffix ``'_fc'`` appended
    to the original subroutine name and, unless ``bind_name`` is given,
    bind to a C function with the suffix ``'_c'``.

    This method will call :meth:`generate_iso_c_interface` to generate
    the ISO-C compatible interface for the C function and generate a
    wrapper :any:`Subroutine` that converts the native Fortran arguments
    to a call to the C function with ISO-C compatible arguments.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine for which to generate the interface
    c_structs : dict of str to :any:`TypeDef`
        Map from lower-case Fortran derived type name to the C-compatible
        struct type definition
    bind_name : str
        Name of the C-function to which this interface corresponds.
    use_c_ptr : bool, optional
        Use ``c_ptr`` for array declarations and ``c_loc(...)`` to
        pass the corresponding argument. Default is ``False``.
    language : string
        C-style language to generate; if this is ``'c'``, we resolve
        non-C imports.

    Note
    ----
    With ``use_c_ptr=True`` this modifies ``routine.spec`` in place: array
    arguments are re-declared with ``:`` dimensions and the ``target``
    attribute, and the wrapper's arguments are cloned from them afterwards.
    """
    wrapper = Subroutine(name=f'{routine.name}_fc')

    if bind_name is None:
        bind_name = f'{routine.name.lower()}_c'
    interface = generate_iso_c_interface(
        routine, bind_name, c_structs, scope=wrapper, use_c_ptr=use_c_ptr, language=language
    )

    # Generate the wrapper function
    wrapper_spec = Transformer().visit(routine.spec)
    wrapper_spec.prepend(iso_c_intrinsic_import(wrapper, use_c_ptr=use_c_ptr))
    wrapper_spec.append(struct.clone(parent=wrapper) for struct in c_structs.values())
    wrapper_spec.append(interface)
    wrapper.spec = wrapper_spec

    # Create the wrapper function with casts and interface invocation:
    # derived-type arguments are copied into/out of their C struct twin via transfer()
    local_arg_map = OrderedDict()
    casts_in = []
    casts_out = []
    for arg in routine.arguments:
        if isinstance(arg.type.dtype, DerivedType):
            ctype = SymbolAttributes(DerivedType(name=c_structs[arg.type.dtype.name.lower()].name))
            cvar = sym.Variable(name=f'{arg.name}_c', type=ctype, scope=wrapper)
            cast_in = sym.InlineCall(sym.ProcedureSymbol('transfer', scope=wrapper),
                                     parameters=(arg,), kw_parameters={'mold': cvar})
            casts_in += [ir.Assignment(lhs=cvar, rhs=cast_in)]

            cast_out = sym.InlineCall(sym.ProcedureSymbol('transfer', scope=wrapper),
                                      parameters=(cvar,), kw_parameters={'mold': arg})
            casts_out += [ir.Assignment(lhs=arg, rhs=cast_out)]
            local_arg_map[arg.name] = cvar

    arguments = tuple(local_arg_map[a] if a in local_arg_map else sym.Variable(name=a)
                      for a in routine.argnames)
    use_device_addr = []
    if use_c_ptr:
        # Array arguments need the target attribute to be passed via c_loc
        arg_map = {}
        for arg in routine.arguments:
            if isinstance(arg, sym.Array):
                new_dims = tuple(sym.RangeIndex((None, None)) for _ in arg.dimensions)
                arg_map[arg] = arg.clone(dimensions=new_dims, type=arg.type.clone(target=True))
        routine.spec = SubstituteExpressions(arg_map).visit(routine.spec)

        call_arguments = []
        for arg in routine.arguments:
            if isinstance(arg, sym.Array):
                new_arg = arg.clone(dimensions=None)
                c_loc = sym.InlineCall(
                    function=sym.ProcedureSymbol(name="c_loc", scope=routine),
                    parameters=(new_arg,))
                call_arguments.append(c_loc)
                use_device_addr.append(arg.name)
            elif isinstance(arg.type.dtype, DerivedType):
                # Reuse the struct variable created above for *this* argument;
                # rebuilding it here from a leftover `ctype` would pick up the
                # struct type of the last derived-type argument instead.
                call_arguments.append(local_arg_map[arg.name])
            else:
                call_arguments.append(arg)
    else:
        call_arguments = arguments

    wrapper_body = casts_in
    if language in ['cuda', 'hip']:
        wrapper_body += [
            ir.Pragma(keyword='acc', content=f'host_data use_device({", ".join(use_device_addr)})')
        ]
    wrapper_body += [
        ir.CallStatement(name=sym.Variable(name=interface.body[0].name), arguments=call_arguments)
    ]
    if language in ['cuda', 'hip']:
        wrapper_body += [ir.Pragma(keyword='acc', content='end host_data')]
    wrapper_body += casts_out
    wrapper.body = ir.Section(body=as_tuple(wrapper_body))

    # Copy internal argument and declaration definitions
    wrapper.variables = tuple(arg.clone(scope=wrapper) for arg in routine.arguments) + tuple(local_arg_map.values())
    wrapper.arguments = tuple(arg.clone(scope=wrapper) for arg in routine.arguments)

    # Remove any unused imports
    sanitise_imports(wrapper)
    return wrapper


def _generate_module_variable_getter(module, var, parent, use_c_ptr):
    """
    Build the ``bind(c)`` getter :any:`Function` returning the value of the
    module variable ``var`` defined in ``module``.

    The getter is named and bound as ``<module>__get__<variable>``
    (lower-case). It imports ``var`` from ``module`` and, for ``c_int``,
    ``c_float`` or ``c_double`` result kinds, that kind from ``iso_c_binding``.
    """
    getter_name = f'{module.name.lower()}__get__{var.name.lower()}'
    getter = Function(name=getter_name, bind=getter_name, parent=parent)

    # Bring the module variable into the getter's scope
    getter.spec = ir.Section(
        body=(ir.Import(module=module.name, symbols=(var.clone(scope=getter), )), )
    )

    result_type = SymbolAttributes(
        var.type.dtype, kind=iso_c_intrinsic_kind(var.type, getter, use_c_ptr=use_c_ptr)
    )
    if result_type.kind in ('c_int', 'c_float', 'c_double'):
        getter.spec.append(ir.Import(module='iso_c_binding', symbols=(result_type.kind, )))

    # The function result (named like the function) receives the variable's value
    result_var = sym.Variable(name=getter_name, scope=getter)
    getter.body = ir.Section(body=(ir.Assignment(lhs=result_var, rhs=var),))
    getter.variables = as_tuple(sym.Variable(name=getter_name, type=result_type, scope=getter))
    return getter


def generate_iso_c_wrapper_module(module, use_c_ptr=False, language='c'):
    """
    Generate the ISO-C wrapper module for a raw Fortran module.

    The new wrapper module is named after the original module with the
    suffix ``'_fc'`` appended.

    Note
    ----
    Global Fortran variables are not accessible via ISO-C interfaces, so
    for ``language='c'`` a ``bind(c)`` getter function is generated for
    every module-level variable that is neither of derived type, nor a
    pointer, nor allocatable. For other languages the wrapper module
    contains no procedures.

    Parameters
    ----------
    module : :any:`Module`
        The module for which to generate the interface module
    use_c_ptr : bool, optional
        Use ``c_ptr`` for array declarations and ``c_loc(...)`` to
        pass the corresponding argument. Default is ``False``.
    language : string
        C-style language to generate; getters are only created for ``'c'``.
    """
    wrapper_module = Module(name=f'{module.name}_fc')

    if language == 'c':
        # Create getter methods for module-level variables (I know... :( )
        getters = [
            _generate_module_variable_getter(module, var, wrapper_module, use_c_ptr)
            for decl in FindNodes(ir.VariableDeclaration).visit(module.spec)
            for var in decl.symbols
            if not (isinstance(var.type.dtype, DerivedType) or var.type.pointer or var.type.allocatable)
        ]
        wrapper_module.contains = ir.Section(body=(ir.Intrinsic('CONTAINS'), *getters))

    # Remove any unused imports
    sanitise_imports(wrapper_module)
    return wrapper_module


def generate_c_header(module):
    """
    Re-generate the C header as a module with all pertinent nodes,
    but not Fortran-specific intrinsics (eg. implicit none or save).

    The new header module will have the suffix ``'_c'`` appended to
    the original module name. It contains prototypes of the getter
    functions for module-level variables and C-compatible re-definitions
    of the module's derived types.

    Variables of derived type, pointers, allocatables and variables
    without a known C equivalent type get no getter prototype. The first
    three are also skipped by :meth:`generate_iso_c_wrapper_module`.

    Parameters
    ----------
    module : :any:`Module`
        The module for which to generate the C header
    """
    header_module = Module(name=f'{module.name}_c')

    # Generate stubs for getter functions; declarations may list several symbols
    spec = []
    for decl in FindNodes(ir.VariableDeclaration).visit(module.spec):
        for v in decl.symbols:
            # Bail if not a basic type, or if no getter is generated for it
            if isinstance(v.type.dtype, DerivedType) or v.type.pointer or v.type.allocatable:
                continue
            ctype = c_intrinsic_kind(v.type, scope=module)
            if ctype is None:
                # No known C type; emitting "None name();" would be invalid C
                continue
            tmpl_function = f'{ctype} {module.name.lower()}__get__{v.name.lower()}();'
            spec += [ir.Intrinsic(text=tmpl_function)]

    # Re-create type definitions with range indices (``:``) replaced by pointers
    for td in FindNodes(ir.TypeDef).visit(module.spec):
        header_td = ir.TypeDef(name=td.name.lower(), body=(), parent=header_module)  # pylint: disable=unexpected-keyword-arg
        declarations = []
        for decl in td.declarations:
            variables = []
            for v in decl.symbols:
                # Note that we force lower-case on all struct variables
                if isinstance(v, sym.Array):
                    new_shape = as_tuple(d for d in v.shape if not isinstance(d, sym.RangeIndex))
                    new_type = v.type.clone(shape=new_shape)
                    variables += [v.clone(name=v.name.lower(), type=new_type, scope=header_td)]
                else:
                    variables += [v.clone(name=v.name.lower(), scope=header_td)]
            declarations += [ir.VariableDeclaration(
                symbols=as_tuple(variables), dimensions=decl.dimensions,
                comment=decl.comment, pragma=decl.pragma
            )]
        header_td._update(body=as_tuple(declarations))
        spec += [header_td]

    header_module.spec = spec
    header_module.rescope_symbols()
    return header_module
loki-ecmwf-0.3.6/loki/transformations/transform_derived_types.py0000664000175000017500000010763115167130205025453 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Transformations dealing with derived types in subroutines and
derived-type arguments in complex calling structures.

 * DerivedTypeArgumentsTransformation:
        Transformation to resolve array-of-structure (AOS) uses of derived-type
        variables to explicitly expose arrays from which to hoist dimensions.
"""

from collections import defaultdict

from loki.batch import Transformation
from loki.expression import (
    InlineCall, Variable, RangeIndex, ExpressionRetriever,
    SubstituteExpressionsMapper
)
from loki.ir import (
    Import, CallStatement, ProcedureDeclaration, Transformer,
    FindNodes, FindInlineCalls, FindVariables, SubstituteExpressions
)
from loki.logging import warning, debug
from loki.module import Module
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, OrderedSet
from loki.types import BasicType, DerivedType, ProcedureType

from loki.transformations.utilities import recursive_expression_map_update


__all__ = ['DerivedTypeArgumentsTransformation', 'TypeboundProcedureCallTransformation']


class DerivedTypeArgumentsTransformation(Transformation):
    """
    Remove derived types from procedure signatures by replacing the
    (relevant) derived type arguments by its member variables

    .. note::

       This transformation requires a Scheduler traversal that
       processes callees before callers.

    On the caller side, this updates calls to transformed subroutines
    and functions by passing the relevant derived type member variables
    instead of the original derived type argument. This uses information
    from previous application of this transformation to the called
    procedure.

    On the callee side, this identifies derived type member variable
    usage, builds an expansion mapping, updates the procedure's
    signature accordingly, and substitutes the variable's use inside
    the routine. The information about the expansion map is stored
    in the :any:`Item`'s ``trafo_data``.
    See :meth:`expand_derived_args_kernel` for more information.

    Parameters
    ----------
    all_derived_types : bool, optional
        Whether to remove all derived types from procedure signatures by
        replacing the derived type arguments using its member variables or
        only the "relevant" ones, referring to derived types with array
        members or nested derived types (default: `False`).
    key : str, optional
        Overwrite the key that is used to store analysis results in ``trafo_data``.
    """

    _key = 'DerivedTypeArgumentsTransformation'
    """Default identifier for trafo_data entry"""

    reverse_traversal = True
    """Traversal from the leaves upwards"""

    def __init__(self, all_derived_types=False, key=None, **kwargs):
        self.all_derived_types = all_derived_types
        if key is not None:
            self._key = key
        super().__init__(**kwargs)

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the derived type argument expansion to :data:`routine`

        Calls to successors that have already been processed by this
        transformation are updated first. If :data:`routine` has the role
        ``'kernel'``, its own derived type arguments are then expanded and the
        resulting expansion data is stored in the item's ``trafo_data``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine to transform
        role : str, optional
            The item's role; only ``'kernel'`` routines have their signature changed
        item : :any:`Item`, optional
            The scheduler item used to store and look up ``trafo_data``
        sub_sgraph : optional
            The (sub-)graph used to look up the successors of :data:`item`
        """
        role = kwargs.get('role')
        item = kwargs.get('item')

        # Initialize the transformation data dictionary
        if item:
            item.trafo_data[self._key] = {}

        # Extract expansion maps and argument re-mapping for successors.
        # Only successors with non-empty transformation data (i.e., processed kernels)
        # provide an expansion map; the empty placeholder dict of other roles would
        # otherwise lead to a KeyError when expanding call arguments
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()
        successors = [child for child in successors if child.trafo_data.get(self._key)]

        # Create a map that accounts for potential renaming of successors upon import,
        # which can lead to calls having a different name than the successor item they
        # correspond to. Renames from multiple import statements of the same module are
        # merged, so that a later import of the same module does not discard them
        renamed_import_map = defaultdict(dict)
        for import_ in routine.imports + getattr(routine.parent, 'imports', ()):
            renamed_import_map[import_.module.lower()].update({
                s.type.use_name.lower(): s.name.lower()
                for s in import_.symbols if s.type.use_name
            })
        successors_data = CaseInsensitiveDict(
            (
                renamed_import_map.get(child.scope_name, {}).get(child.local_name, child.local_name),
                child.trafo_data[self._key]
            )
            for child in successors
        )

        # Apply caller transformation first to update calls to successors...
        self.expand_derived_args_caller(routine, successors_data)

        # ...before updating the routine's signature and replacing
        # use of members in the body
        if role == 'kernel':
            # Expand derived type arguments in kernel...
            trafo_data = self.expand_derived_args_kernel(routine)
            if item:
                item.trafo_data[self._key] = trafo_data

            # ...and make sure missing symbols are imported...
            self.add_new_imports_kernel(routine, trafo_data)

            # For recursive routines, we have to update calls to itself
            if any('recursive' in prefix.lower() for prefix in routine.prefix or ()):
                self.expand_derived_args_recursion(routine, trafo_data)


    def expand_derived_args_caller(self, routine, successors_data):
        """
        For all active :any:`CallStatement` nodes and all :any:`InlineCall` nodes,
        apply the derived type argument expansion on the caller side.

        The convention used is: ``derived%var => derived_var``.

        The routine's body is updated in place; nothing is returned.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine in which to transform call statements
        successors_data : :any:`CaseInsensitiveDict` of (str, dict)
            Dictionary containing the expansion maps (key ``'expansion_map'``) and
            original argnames (key ``'orig_argnames'``) of every child routine
        """
        call_mapper = {}
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.not_active:
                continue
            call_name = str(call.name)
            if call_name in successors_data:
                # Set the new call signature on the IR node
                arguments, kwarguments = self.expand_call_arguments(call, successors_data[call_name])
                call_mapper[call] = call.clone(arguments=arguments, kwarguments=kwarguments)

        # Rebuild the routine's IR tree
        if call_mapper:
            routine.body = Transformer(call_mapper).visit(routine.body)

        call_mapper = {}
        for call in FindInlineCalls().visit(routine.body):
            if (call_name := str(call.name)) in successors_data:
                # Set the new call signature on the expression node
                arguments, kwarguments = self.expand_call_arguments(call, successors_data[call_name])
                call_mapper[call] = call.clone(parameters=arguments, kw_parameters=kwarguments)

        if call_mapper:
            routine.body = SubstituteExpressions(call_mapper).visit(routine.body)

    @staticmethod
    def _expand_relative_to_local_var(local_var, expansion_components):
        """
        Utility routine that returns an expanded (nested) derived type argument
        relative to the local derived type variable declared on caller side

        Example: A subroutine that previously accepted a derived type argument
        ``some_arg`` uses only a member variable ``some_arg%nested_thing%var``,
        which is now replaced in the procedure interface by
        ``some_arg_nested_thing_var``. On the caller side, ``local_var`` is the
        instance of the derived type argument that used to be passed to the
        subroutine call. This utility routine determines and returns
        the new call argument ``local_var%nested_thing%var`` that has to be
        passed instead.
        """
        # We build the name bit-by-bit to obtain nested derived type arguments
        # relative to the local derived type variable
        for child in expansion_components:
            local_var = child.clone(
                name=f'{local_var.name}%{child.name}',
                parent=local_var,
                scope=local_var.scope
            )
        return local_var

    @classmethod
    def _expand_call_argument(cls, caller_arg, expansion_list):
        """
        Utility routine to expand the caller-side argument :data:`caller_arg`
        according to the :data:`expansion_list` of the corresponding kernel argument

        Entries of :data:`expansion_list` without parents are the kernel argument
        itself (retained in the kernel's signature), for which :data:`caller_arg`
        is passed unchanged. All other entries are derived type members, which are
        expanded relative to :data:`caller_arg`.

        Returns
        -------
        list
            The new call arguments, one per entry in :data:`expansion_list`
        """
        arguments = []
        for member in expansion_list:
            if not member.parents:
                # The kernel keeps the (unexpanded) argument at this position
                arguments += [caller_arg]
            else:
                arguments += [cls._expand_relative_to_local_var(caller_arg, [*member.parents[1:], member])]
        return arguments

    @classmethod
    def expand_call_arguments(cls, call, successor_data):
        """
        Create the call's argument list with derived type arguments expanded

        Parameters
        ----------
        call : :any:`CallStatement` or :any:`InlineCall`
            The call to process
        successor_data : dict
            Dictionary containing the expansion map (key ``'expansion_map'``) and
            original argnames (key ``'orig_argnames'``) of the called routine

        Returns
        -------
        (tuple, tuple) :
            The argument and keyword argument list with derived type arguments expanded
        """
        expansion_map = successor_data['expansion_map']
        orig_argnames = successor_data['orig_argnames']

        # Positional arguments are matched to the kernel's original argument names by position
        arguments = []
        for kernel_argname, caller_arg in zip(orig_argnames, call.arguments):
            if kernel_argname in expansion_map:
                arguments += cls._expand_call_argument(caller_arg, expansion_map[kernel_argname])
            else:
                arguments += [caller_arg]

        # Keyword arguments are renamed to the expanded kernel variable names
        kwarguments = []
        for kernel_argname, caller_arg in call.kwarguments:
            if kernel_argname in expansion_map:
                expanded_arguments = cls._expand_call_argument(caller_arg, expansion_map[kernel_argname])
                kwarguments += [
                    (cls._expand_kernel_variable(kernel_arg).name, expanded_arg)
                    for kernel_arg, expanded_arg in zip(expansion_map[kernel_argname], expanded_arguments)
                ]
            else:
                kwarguments += [(kernel_argname, caller_arg)]

        return as_tuple(arguments), as_tuple(kwarguments)

    @staticmethod
    def _expand_kernel_variable(var, **kwargs):
        """
        Utility routine that yields the expanded variable in the
        kernel for a given derived type variable member use :data:`var`
        """
        new_name = var.name.replace('%', '_')
        return var.clone(name=new_name, parent=None, **kwargs)

    @staticmethod
    def _get_expanded_kernel_var_type(arg, var):
        """
        Utility routine that yields the variable type for an expanded kernel variable

        The intent is inherited from the original derived type argument :data:`arg`,
        as is the ``target`` attribute unless :data:`var` is a pointer.
        """
        return var.type.clone(
            intent=arg.type.intent, initial=None, allocatable=None,
            target=arg.type.target if not var.type.pointer else None
        )

    def expand_derived_args_kernel(self, routine):
        """
        Find the use of member variables for derived type arguments of
        :data:`routine`, update the call signature to directly pass the
        variable and substitute its use in the routine's body.

        Unless :attr:`all_derived_types` is set, this will only carry out
        replacements for derived type arguments that contain an allocatable,
        pointer, or nested derived type member.

        See :meth:`expand_derived_type_member` for more details on how
        the expansion is performed.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The kernel routine whose signature, spec and body are updated in place

        Returns
        -------
        dict
            The transformation data with the original argument names (key
            ``'orig_argnames'``) and the expansion map (key ``'expansion_map'``),
            which maps each expanded argument to the tuple of variables that
            replace it in the signature
        """
        trafo_data = {'orig_argnames': tuple(arg.lower() for arg in routine.argnames)}

        # All derived type arguments are candidates for expansion
        candidates = []
        for arg in routine.arguments:
            if isinstance(arg.type.dtype, DerivedType):
                if self.all_derived_types or any(v.type.pointer or v.type.allocatable or
                       isinstance(v.type.dtype, DerivedType) for v in as_tuple(arg.variables)):
                    # Only include derived types with array members or nested derived types
                    # unless self.all_derived_types is True
                    candidates += [arg]

        # Inspect all derived type member use and determine their expansion.
        # Uses that only appear as the parent of a deeper member use are dropped.
        vars_to_expand = [var for var in FindVariables(unique=False).visit(routine.ir) if var.parent]
        nested_parents = [var.parent for var in vars_to_expand if var.parent in vars_to_expand]
        vars_to_expand = [var for var in vars_to_expand if var not in nested_parents]

        expansion_map = defaultdict(OrderedSet)
        non_expansion_map = defaultdict(OrderedSet)
        vmap = {}
        for var in vars_to_expand:
            declared_var, expansion, local_use = self.expand_derived_type_member(var)
            if expansion and declared_var in candidates:
                # Mark this derived type member for expansion
                expansion_map[declared_var].add(expansion)
                vmap[var] = local_use
            elif declared_var in candidates:
                non_expansion_map[declared_var].add(var)

        # Update the expansion map by re-adding the derived type argument when
        # there are non-expanded members left
        # Here, we determine the ordering in the updated call signature
        expansion_map = dict(expansion_map)
        for arg in candidates:
            if arg in expansion_map:
                sorted_expansion = sorted(expansion_map[arg], key=lambda v: str(v).lower())
                if arg in non_expansion_map:
                    expansion_map[arg] = (arg, *sorted_expansion)
                else:
                    expansion_map[arg] = tuple(sorted_expansion)

        def assumed_dim_or_none(shape):
            # Expanded array arguments are declared with assumed shape, e.g. ``(:,:)``
            if not shape:
                return None
            return tuple(RangeIndex((None, None)) for _ in shape)

        # Build the arguments map to update the call signature
        arguments_map = {}
        for arg in routine.arguments:
            if arg in expansion_map:
                arguments_map[arg] = [
                    self._expand_kernel_variable(
                        var, type=self._get_expanded_kernel_var_type(arg, var),
                        dimensions=assumed_dim_or_none(var.type.shape), scope=routine
                    )
                    for var in expansion_map[arg]
                ]

        # Update arguments list
        routine.arguments = [a for arg in routine.arguments for a in arguments_map.get(arg, [arg])]

        # Update variable list, too, as this triggers declaration generation
        routine.variables = [v for var in routine.variables for v in arguments_map.get(var, [var])]

        # Substitute derived type member use in the spec and body
        vmap = recursive_expression_map_update(vmap)
        routine.spec = SubstituteExpressions(vmap).visit(routine.spec)
        routine.body = SubstituteExpressions(vmap).visit(routine.body)

        # Update procedure bindings by specifying NOPASS attribute
        for arg in arguments_map:
            for decl in arg.type.dtype.typedef.declarations:
                if isinstance(decl, ProcedureDeclaration) and not decl.generic:
                    for proc in decl.symbols:
                        if routine.name == proc or routine.name in as_tuple(proc.type.bind_names):
                            proc.type = proc.type.clone(pass_attr=False)

        trafo_data['expansion_map'] = expansion_map
        return trafo_data

    @classmethod
    def expand_derived_type_member(cls, var):
        """
        Determine the member expansion for a derived type member variable

        For a derived type member variable, provided as :data:`var`, this determines
        the name of the root parent and the member expansion.

        A few examples to illustrate the behaviour, with the Fortran variable use
        that :data:`var` represents in the left column and corresponding return value
        of this routine on the right:

        .. code-block::

            var name            | return value (parent_name, expansion, new use)   | remarks
           ---------------------+--------------------------------------------------+------------------------------------
            SOME_VAR            | ('some_var', None, None)                         | No expansion
            SOME%VAR            | ('some', 'some%var', 'some_var')                 |
            ARRAY(5)%VAR        | ('array', None, None)                            | Can't expand array of derived types
            SOME%NESTED%VAR     | ('some', 'some%nested%var', 'some_nested_var')   |
            NESTED%ARRAY(I)%VAR | ('nested', 'nested%array', 'nested_array(i)%var')| Partial expansion

        Parameters
        ----------
        var : :any:`MetaSymbol`
            The use of a derived type member

        Returns
        -------
        (:any:`Variable`, :any:`Variable` or None, :any:`Variable` or None)
        """
        parents = var.parents
        if not parents:
            return var, None, None

        # We unroll the derived type member as far as possible, stopping at
        # the occurence of an intermediate derived type array.
        # Note that we set scope=None, which detaches the symbol from the current
        # scope and stores the type information locally on the symbol. This makes
        # them available later on without risking losing this information due to
        # intermediate rescoping operations
        for idx, parent in enumerate(parents):
            if hasattr(parent, 'dimensions'):
                expansion = parent.clone(scope=None, dimensions=None)
                if parent is parents[0]:
                    debug(f'Array of derived types {var!s}. Cannot expand argument.')
                    local_use = var
                else:
                    debug(f'Array of derived types {var!s}. '
                        f'Can only partially expand argument as {expansion!s}.')
                    local_use = cls._expand_kernel_variable(parent)
                    local_use = cls._expand_relative_to_local_var(local_use, [*parents[idx+1:], var])
                return parents[0], expansion, local_use

        # None of the parents had a dimensions attribute, which means we can
        # completely expand
        expansion = var.clone(scope=None, dimensions=None)
        local_use = cls._expand_kernel_variable(var)

        return parents[0], expansion, local_use

    @staticmethod
    def _get_imports_for_expr(expr, scope, symbol_map):
        """
        Helper utility to build the list of symbols per module in an expression :data:`expr` that do
        not exist in the current :data:`scope`
        """
        def _warn(symbol):
            warning((
                '[Loki::DerivedTypeArgumentsTransformation] '
                f'Cannot insert import for symbol "{symbol.name}" in {scope.name}. No type information available.'
            ))

        new_imports = defaultdict(OrderedSet)
        for symbol in FindVariables().visit(expr):
            if symbol.name in symbol_map:
                continue

            if symbol.type.imported:
                # This new symbol had been imported for use in the typedef
                if not symbol.type.module:
                    _warn(symbol)
                else:
                    new_imports[symbol.type.module.name.lower()].add(symbol.clone(scope=scope))

            elif (symbol_scope := symbol.scope):
                # This new symbol has been declared in the symbol_scope we inherited it from
                while symbol_scope.parent:
                    symbol_scope = symbol_scope.parent
                new_imports[symbol_scope.name.lower()].add(
                    symbol.clone(scope=scope, type=symbol.type.clone(imported=True))
                )

            else:
                _warn(symbol)

        return new_imports

    @classmethod
    def add_new_imports_kernel(cls, routine, trafo_data):
        """
        Inspect the expansion map in :data:`trafo_data` for new symbols that need to be imported
        as a result of flattening a derived type and add the corresponding imports
        """
        new_arguments = flatten(trafo_data['expansion_map'].values())

        # Always use a case-insensitive map (also without a parent scope) so that
        # lookups of symbol names succeed independent of their capitalisation
        symbol_map = CaseInsensitiveDict()
        if routine.parent:
            symbol_map.update(routine.parent.symbol_map)
        symbol_map.update(routine.symbol_map)

        # Check for derived types, kind, or shape dimensions declared via parameters among new arguments
        new_imports = defaultdict(OrderedSet)
        for arg in new_arguments:
            type_ = arg.type
            if isinstance(type_.dtype, DerivedType) and type_.dtype.name not in symbol_map:
                typedef = type_.dtype.typedef
                if typedef is BasicType.DEFERRED:
                    warning((
                        '[Loki::DerivedTypeArgumentsTransformation] '
                        f'Cannot insert import for derived type "{type_.dtype.name}" in {routine.name}. '
                        'No type information available.'
                    ))
                elif typedef.parent is not routine.parent:
                    # Derived type needs to be imported
                    new_imports[typedef.parent.name.lower()].add(
                        Variable(name=type_.dtype.name, scope=routine, type=type_.clone(imported=True))
                    )

            if type_.kind:
                for module, symbols in cls._get_imports_for_expr(type_.kind, routine, symbol_map).items():
                    new_imports[module] |= symbols

            if getattr(arg, 'dimensions', None):
                for module, symbols in cls._get_imports_for_expr(arg.dimensions, routine, symbol_map).items():
                    new_imports[module] |= symbols

        if new_imports:
            new_imports = tuple(
                Import(module=module, symbols=as_tuple(symbols))
                for module, symbols in new_imports.items()
            )
            routine.spec.prepend(new_imports)

    @classmethod
    def expand_derived_args_recursion(cls, routine, trafo_data):
        """
        Find recursive calls to the routine itself and apply the derived args
        flattening to these calls
        """
        def _update_call(call):
            # Expand the call signature first
            arguments, kwarguments = cls.expand_call_arguments(call, trafo_data)
            # And expand the derived type members in the new call signature next
            expansion_map = {}
            vars_to_expand = {var for var in FindVariables().visit((arguments, kwarguments)) if var.parent}
            nested_parents = {var.parent for var in vars_to_expand if var.parent in vars_to_expand}
            vars_to_expand -= nested_parents
            for var in vars_to_expand:
                orig_arg = var.parents[0]
                expanded_var = cls._expand_kernel_variable(
                    var, type=cls._get_expanded_kernel_var_type(orig_arg, var), scope=routine, dimensions=None
                )
                expansion_map[var] = expanded_var
            expansion_mapper = SubstituteExpressionsMapper(recursive_expression_map_update(expansion_map))
            arguments = tuple(expansion_mapper(arg) for arg in arguments)
            kwarguments = tuple((k, expansion_mapper(v)) for k, v in kwarguments)
            return arguments, kwarguments

        # Deal with subroutine calls first
        call_mapper = {}
        for call in FindNodes(CallStatement).visit(routine.body):
            if str(call.name).lower() == routine.name.lower():
                arguments, kwarguments = _update_call(call)
                call_mapper[call] = call.clone(arguments=arguments, kwarguments=kwarguments)

        # Rebuild the routine's IR tree
        if call_mapper:
            routine.body = Transformer(call_mapper).visit(routine.body)

        # Deal with inline calls next
        call_mapper = {}
        for call in FindInlineCalls().visit(routine.body):
            if str(call.name).lower() == routine.name.lower():
                arguments, kwarguments = _update_call(call)
                call_mapper[call] = call.clone(parameters=arguments, kw_parameters=kwarguments)

        # Rebuild the routine's IR tree with expression substitution
        if call_mapper:
            routine.body = SubstituteExpressions(call_mapper).visit(routine.body)

def get_procedure_symbol_from_typebound_procedure_symbol(proc_symbol, routine_name):
    """
    Utility routine that returns the :any:`ProcedureSymbol` of the :any:`Subroutine`
    that a typebound procedure corresponds to.

    .. warning::
       Resolving generic bindings is currently not implemented

    This uses binding information (such as ``proc_symbol.type.bind_names``) or the
    :any:`TypeDef` to resolve the procedure binding. If the type information is
    incomplete or the resolution fails for other reasons, ``None`` is returned.

    Parameters
    ----------
    proc_symbol : :any:`ProcedureSymbol`
        The typebound procedure symbol that is to be resolved
    routine_name : str
        The name of the routine :data:`proc_symbol` appears in. This is used for
        logging purposes only

    Returns
    -------
    :any:`ProcedureSymbol` or None
        The procedure symbol of the :any:`Subroutine` or ``None`` if it fails to resolve
    """
    if proc_symbol.type.bind_names is not None:
        return proc_symbol.type.bind_names[0]

    parents = proc_symbol.parents
    if not parents:
        # Not a typebound procedure symbol: nothing to resolve
        return None

    # The parent's type may be unresolved altogether (e.g., a BasicType without a
    # typedef), or it may be a derived type without available type definition
    typedef = getattr(parents[0].type.dtype, 'typedef', None)
    if typedef is None or typedef is BasicType.DEFERRED:
        # We don't have any binding information available
        return None

    # Fiddle our way through derived type nesting until we obtain the symbol corresponding
    # to the procedure-binding in the TypeDef
    local_parent = None
    local_var = parents[0]
    try:
        for local_name in proc_symbol.name_parts[1:]:
            local_parent = local_var
            local_var = local_var.type.dtype.typedef.variable_map[local_name]
    except (AttributeError, KeyError):
        # Missing typedefs or bindings that are absent from an incomplete typedef
        warning('Type definitions incomplete for %s in %s', proc_symbol, routine_name)
        return None

    if local_var.type.dtype.is_generic:
        warning('Cannot resolve generic binding %s (not implemented) in %s', proc_symbol, routine_name)
        return None

    if local_var.type.bind_names is not None:
        # Although this should have been taken care of by the first if branch,
        # this may trigger here when the bind_names property hasn't been imported
        # into the local symbol table
        return local_var.type.bind_names[0]

    # If the binding doesn't have any specific bind_names, this means the
    # corresponding subroutine has the same name and should be declared
    # in the same module as the typedef
    return Variable(name=local_var.name, scope=local_parent.type.dtype.typedef.parent)


class TypeboundProcedureCallTransformer(Transformer):
    """
    Transformer that rewrites subroutine and inline function calls to typebound
    procedures into direct calls of the bound procedures

    While traversing, it also records the new dependencies introduced by the
    rewritten inline function calls, which the :any:`Scheduler` may currently
    be unable to discover by itself.

    Parameters
    ----------
    routine_name : str
        Name of the :any:`Subroutine` in which the replacement happens; only
        used in log messages.
    current_module : str
        Name of the enclosing module, used to decide whether a resolved
        procedure must be imported.

    Attributes
    ----------
    new_procedure_imports : dict
        Filled during the transformer pass with the mapping
        ``{module: {proc_name, proc_name, ...}}`` of imports that the
        replacement makes necessary.
    """

    def __init__(self, routine_name, current_module, **kwargs):
        super().__init__(inplace=True, **kwargs)
        self.routine_name = routine_name
        self.current_module = current_module
        self.new_procedure_imports = defaultdict(OrderedSet)
        # Match inline calls whose function symbol has a parent, i.e. typebound calls
        self._retriever = ExpressionRetriever(
            lambda expr: isinstance(expr, InlineCall) and expr.function.parent
        )

    def retrieve(self, o):
        return self._retriever.retrieve(o)

    def _register_import(self, proc_symbol):
        """
        Record :data:`proc_symbol` in :attr:`new_procedure_imports` unless it
        lives in :attr:`current_module`
        """
        owner = proc_symbol.scope
        if not isinstance(owner, Module):
            # Fall back to the scope of the procedure the symbol's type refers to
            owner = proc_symbol.type.dtype.procedure.procedure_symbol.scope
        module_name = owner.name.lower()
        if module_name != self.current_module:
            self.new_procedure_imports[module_name].add(proc_symbol.name.lower())

    def visit_CallStatement(self, o, **kwargs):
        """
        Rebuild a :any:`CallStatement`

        For a call to a typebound procedure, the binding is resolved and the
        derived type object is prepended to the call's arguments.
        """
        rebuilt = {
            key: self.visit(child, **kwargs)
            for key, child in zip(o._traversable, o.children)
        }
        call_name = rebuilt['name']
        if call_name.parent:
            resolved = get_procedure_symbol_from_typebound_procedure_symbol(call_name, self.routine_name)
            if resolved:
                # Pass the derived type object explicitly as the first argument
                rebuilt['arguments'] = (call_name.parent, ) + rebuilt['arguments']
                self._register_import(resolved)
                rebuilt['name'] = resolved
        return self._rebuild(o, [rebuilt[key] for key in o._traversable])

    def visit_Expression(self, o, **kwargs):
        """
        Leave the expression untouched unless it contains :any:`InlineCall` nodes
        to typebound procedures; those become direct function calls that take
        the derived type object as their first argument.
        """
        typebound_calls = self.retrieve(o)
        if not typebound_calls:
            return o

        replacements = {}
        for call in typebound_calls:
            resolved = get_procedure_symbol_from_typebound_procedure_symbol(call.function, self.routine_name)
            if not resolved:
                continue
            replacements[call] = call.clone(
                function=resolved.rescope(scope=kwargs['scope']),
                parameters=(call.function.parent,) + call.parameters
            )
            self._register_import(resolved)

        if not replacements:
            return o
        return SubstituteExpressionsMapper(recursive_expression_map_update(replacements))(o)


class TypeboundProcedureCallTransformation(Transformation):
    """
    Replace calls to type-bound procedures with direct calls to the
    corresponding subroutines/functions

    Instead of calling a type-bound procedure, e.g. ``CALL my_type%proc``,
    it is possible to import the bound procedure and call it directly, with
    the derived type as first argument, i.e. ``CALL proc(my_type)``.
    This transformation replaces all calls to type-bound procedures accordingly
    and inserts necessary imports.

    Also, for some compilers these direct calls seem to require an explicit
    ``INTENT`` specification on the polymorphic derived type dummy argument,
    which is set to `INOUT` by default, if missing. This behaviour can be switched
    off by setting :data:`fix_intent` to `False`.

    Parameters
    ----------
    duplicate_typebound_kernels : bool
        Optionally, create a copy of unchanged routines before flattening calls to
        typebound procedures, and update the procedure binding to point to the
        unchanged copy.
    fix_intent : bool
        Update intent on polymorphic dummy arguments missing an intent as ``INOUT``.
    """

    def __init__(self, duplicate_typebound_kernels=False, fix_intent=True, **kwargs):
        super().__init__(**kwargs)
        self.duplicate_typebound_kernels = duplicate_typebound_kernels
        self.fix_intent = fix_intent

    def apply_default_polymorphic_intent(self, routine):
        """
        Utility routine to set a default ``INTENT(INOUT)`` on polymorphic dummy
        arguments (i.e. declared via ``CLASS``) that don't have an explicit intent
        """
        for arg in routine.arguments:
            type_ = arg.type
            if type_.polymorphic and not type_.intent:
                arg.type = type_.clone(intent='inout')

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the transformation of calls to the given :data:`routine`

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine in which calls to typebound procedures are replaced by
            direct calls, with the required imports added to its spec.
        role : str, optional
            Role of the routine in the call tree. An unchanged copy of the routine
            (see :data:`duplicate_typebound_kernels`) is only created for ``'kernel'``
            routines that live in a :any:`Module`.
        """
        role = kwargs.get('role')

        # Fix any wrong intents on polymorphic arguments
        # (sadly, it's not uncommon to omit the intent specification on the CLASS declaration,
        # so we set them to `inout` here for any missing intents)
        if self.fix_intent:
            self.apply_default_polymorphic_intent(routine)

        if routine.parent:
            current_module = routine.parent.name.lower()
        else:
            current_module = None

        # Check if this routine is a typebound routine and, if it is, create a duplicate of
        # the original routine before applying the transformation
        is_duplicate_kernels = (
            self.duplicate_typebound_kernels and role == 'kernel' and isinstance(routine.parent, Module)
        )
        if is_duplicate_kernels:
            typedefs = routine.parent.typedefs
            proc_binding_update_maps = {}
            # The unchanged copy is created at most once, even if the routine is bound
            # under several names or in several derived types: creating one copy per
            # binding would append multiple procedures with the same name to the module
            new_routine = None
            for tdef in typedefs:
                proc_binding_update_maps[tdef.name] = {}
                for var in tdef.variables:
                    if not isinstance(var.type.dtype, ProcedureType):
                        continue
                    if (
                        (var.type.bind_names and routine.name in var.type.bind_names) or
                        (not var.type.bind_names and var == routine.name)
                    ):
                        if new_routine is None:
                            # Create a duplicate routine
                            new_routine = routine.clone(name=f'{routine.name}_', rescope_symbols=True)
                            # Update result name if this is a function: without an explicit
                            # RESULT clause the result variable carries the function's own
                            # name, which no longer matches once the copy is renamed
                            if routine.is_function and not routine.result_name:
                                new_routine.result_name = routine.name
                            routine.parent.contains.append(new_routine)
                        # Update the procedure binding
                        new_type = var.type.clone(bind_names=(new_routine.procedure_symbol,))
                        proc_binding_update_maps[tdef.name][var.name] = new_type

        # Traverse the routine's body and replace all calls to typebound procedures by
        # direct calls to the procedures they refer to
        transformer = TypeboundProcedureCallTransformer(routine.name, current_module)
        routine.body = transformer.visit(routine.body, scope=routine)
        new_procedure_imports = transformer.new_procedure_imports

        # Add missing imports
        imported_symbols = routine.imported_symbols
        new_imports = []
        for module_name, proc_symbols in new_procedure_imports.items():
            new_symbols = tuple(Variable(name=s, scope=routine) for s in proc_symbols if s not in imported_symbols)
            if new_symbols:
                new_imports += [Import(module=module_name, symbols=new_symbols)]

        if new_imports:
            routine.spec.prepend(as_tuple(new_imports))

        # Update the procedure bindings in the typedefs
        if is_duplicate_kernels:
            for tdef in typedefs:
                for var_name, new_type in proc_binding_update_maps[tdef.name].items():
                    tdef.symbol_attrs[var_name] = new_type
loki-ecmwf-0.3.6/loki/transformations/single_column/0000775000175000017500000000000015167130205022766 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/single_column/demote.py0000664000175000017500000001461715167130205024626 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.expression import is_dimension_constant, Array
from loki.ir import nodes as ir, FindNodes, FindInlineCalls
from loki.tools import flatten, as_tuple, OrderedSet

from loki.transformations.array_indexing import demote_variables
from loki.transformations.utilities import get_local_arrays


__all__ = ['SCCDemoteTransformation']


class SCCDemoteTransformation(Transformation):
    """
    A set of utilities to determine which local arrays can be safely demoted in a
    :any:`Subroutine` as part of a transformation pass.

    Unless the option `demote_local_arrays` is set to `False`, this transformation will demote
    local arrays that do not buffer values between vector loops. Specific arrays in individual
    routines can also be marked for preservation by assigning them to the `preserve_arrays` list
    in the :any:`SchedulerConfig`.

    Parameters
    ----------
    horizontal : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the horizontal data dimension and iteration space.
    demote_local_arrays : bool, optional
        Whether suitable local arrays are demoted (default: `True`). This can be
        overridden per routine via the ``demote_locals`` entry of the item config.
    """

    def __init__(self, horizontal, demote_local_arrays=True):
        self.horizontal = horizontal

        self.demote_local_arrays = demote_local_arrays

    @classmethod
    def get_locals_to_demote(cls, routine, sections, horizontal):
        """
        Create a list of local temporary arrays after checking that
        demotion is safe.

        Demotion is considered safe if the temporary is only used
        within one coherent vector-section (see
        :any:`extract_vector_sections`), or not used in any of them.

        Local temporaries are candidates for demotion if they have:
        * one of the ``horizontal`` sizes as their first dimension, and
        * only constant sizes in all remaining dimensions (if any).

        Arrays appearing by name as arguments of subroutine calls, or as arguments of
        inline calls that reference the horizontal size, are excluded.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine whose local arrays are inspected.
        sections : list of :any:`Section`
            The vector sections of ``routine``.
        horizontal : :any:`Dimension`
            The horizontal dimension.

        Returns
        -------
        set
            The local array symbols that can be demoted.
        """
        # Create a list of local temporary arrays to filter down
        candidates = get_local_arrays(routine, routine.spec)

        # Only demote local arrays with the horizontal as fast dimension
        candidates = [
            v for v in candidates if v.shape and
            v.shape[0] in horizontal.sizes
        ]
        # Also demote arrays whose remaining dimensions are known constants
        candidates = [
            v for v in candidates
            if all(is_dimension_constant(d) for d in v.shape[1:])
        ]

        # Create an index into all variable uses per vector-level section
        vars_per_section = {
            s: OrderedSet(
                v.name.lower() for v in get_local_arrays(routine, s, unique=False)
            ) for s in sections
        }

        # Count in how many sections each temporary is used
        counts = {}
        for arr in candidates:
            counts[arr] = sum(
                1 if arr.name.lower() in v else 0
                for v in vars_per_section.values()
            )

        # Demote temporaries that are only used in one section or not at all
        to_demote = [k for k, v in counts.items() if v <= 1]

        # Get InlineCall args containing a horizontal array section
        icalls = FindInlineCalls().visit(routine.body)
        _params = flatten([call.parameters + as_tuple(call.kw_parameters.values()) for call in icalls])
        _params = [p for p in _params if isinstance(p, Array)]

        call_args = [
            p.clone(dimensions=None) for p in _params
            if any(s in (p.dimensions or p.shape) for s in horizontal.size_expressions)
        ]

        # Filter out variables that we will pass down the call tree
        calls = FindNodes(ir.CallStatement).visit(routine.body)
        call_args += flatten(call.arguments for call in calls)
        call_args += flatten(list(dict(call.kwarguments).values()) for call in calls)
        to_demote = [v for v in to_demote if v.name not in call_args]

        return set(to_demote)

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply SCCDemote utilities to a :any:`Subroutine`.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        role : string
            Role of the subroutine in the call tree; should be ``"kernel"``
        item : optional
            Scheduler item; its config may provide ``demote_locals`` and
            ``preserve_arrays`` overrides for this routine.
        """
        role = kwargs['role']
        item = kwargs.get('item', None)

        if role == 'kernel':
            demote_locals = self.demote_local_arrays
            preserve_arrays = []
            if item:
                demote_locals = item.config.get('demote_locals', self.demote_local_arrays)
                preserve_arrays = item.config.get('preserve_arrays', [])
            self.process_kernel(routine, demote_locals=demote_locals, preserve_arrays=preserve_arrays)

    def process_kernel(self, routine, demote_locals=True, preserve_arrays=None):
        """
        Applies the SCCDemote utilities to a "kernel" and demotes all suitable local arrays.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        demote_locals : bool, optional
            If `False`, no arrays are demoted and the routine is left untouched.
        preserve_arrays : str or list of str, optional
            Names of arrays that must not be demoted (matched case-insensitively,
            as Fortran names are case-insensitive).
        """
        # Nothing will be demoted, so skip the (read-only) analysis altogether
        if not demote_locals:
            return

        # Find vector sections marked in the SCCDevectorTransformation
        sections = [
            s for s in FindNodes(ir.Section).visit(routine.body)
            if s.label == 'vector_section'
        ]

        # Extract the local variables to demote after we wrap the sections in vector loops.
        # We do this, because need the section blocks to determine which local arrays
        # may carry buffered values between them, so that we may not demote those!
        to_demote = self.get_locals_to_demote(routine, sections, self.horizontal)

        # Filter out arrays marked explicitly for preservation; ``as_tuple`` guards against
        # a single string (where ``in`` would perform a substring match) and names are
        # compared lower-cased because Fortran identifiers are case-insensitive
        if preserve_arrays:
            preserved = {name.lower() for name in as_tuple(preserve_arrays)}
            to_demote = [v for v in to_demote if v.name.lower() not in preserved]

        # Demote all private local variables that do not buffer values between sections
        variables = tuple(v.name for v in to_demote)
        if variables:
            demote_variables(
                routine, variable_names=variables,
                dimensions=self.horizontal.sizes
            )
loki-ecmwf-0.3.6/loki/transformations/single_column/__init__.py0000664000175000017500000000177115167130205025105 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.transformations.single_column.annotate import * # noqa
from loki.transformations.single_column.base import * # noqa
from loki.transformations.single_column.demote import * # noqa
from loki.transformations.single_column.devector import * # noqa
from loki.transformations.single_column.hoist import * # noqa
from loki.transformations.single_column.revector import * # noqa
from loki.transformations.single_column.scc import * # noqa
from loki.transformations.single_column.scc_cuf import * # noqa
from loki.transformations.single_column.scc_low_level import * # noqa
from loki.transformations.single_column.vertical import * # noqa
loki-ecmwf-0.3.6/loki/transformations/single_column/tests/0000775000175000017500000000000015167130205024130 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/single_column/tests/__init__.py0000664000175000017500000000057015167130205026243 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/single_column/tests/test_scc_vector.py0000664000175000017500000007272315167130205027706 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Sourcefile, Dimension, fgen, Module
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    nodes as ir, FindNodes, pragmas_attached, is_loki_pragma
)
from loki.transformations.single_column import (
    SCCDevectorTransformation, SCCRevectorTransformation, SCCVectorPipeline,
    SCCVecRevectorTransformation, SCCSeqRevectorTransformation,
    SCCVVectorPipeline, SCCSVectorPipeline
)


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """Horizontal dimension (``nlon``/``jl``) bounded by ``start:end``, aliased by ``nproma``."""
    horizontal_dim = Dimension(
        name='horizontal',
        index='jl',
        size='nlon',
        aliases=('nproma',),
        bounds=('start', 'end'),
    )
    return horizontal_dim

@pytest.fixture(scope='module', name='horizontal_bounds_aliases')
def fixture_horizontal_bounds_aliases():
    """Horizontal dimension whose loop bounds may also appear as ``bnds%start:bnds%end``."""
    aliased_dim = Dimension(
        name='horizontal_bounds_aliases',
        index='jl',
        size='nlon',
        aliases=('nproma',),
        bounds=('start', 'end'),
        bounds_aliases=('bnds%start', 'bnds%end'),
    )
    return aliased_dim

@pytest.fixture(scope='module', name='vertical')
def fixture_vertical():
    """Vertical dimension (``nz``/``jk``), aliased by ``nlev``."""
    vertical_dim = Dimension(name='vertical', index='jk', size='nz', aliases=('nlev',))
    return vertical_dim

@pytest.fixture(scope='module', name='blocking')
def fixture_blocking():
    """Blocking dimension (``nb``/``b``)."""
    blocking_dim = Dimension(name='blocking', index='b', size='nb')
    return blocking_dim


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('revector_trafo', [SCCSeqRevectorTransformation, SCCVecRevectorTransformation])
@pytest.mark.parametrize('ignore_nested_kernel', [False, True])
def test_scc_revector_transformation(frontend, horizontal, revector_trafo, ignore_nested_kernel, tmp_path):
    """
    Test removal of vector loops in kernel and re-insertion of a single
    hoisted horizontal loop in the kernel.
    """

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nz, q, t, nb)
    use compute_mod, only: compute_column
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    REAL, INTENT(INOUT)   :: t(nlon,nz,nb)
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    do b=1, nb
      call compute_column(start, end, nlon, nz, q(:,:,b), t(:,:,b))
    end do
  END SUBROUTINE column_driver
"""

    fcode_kernel = """
  MODULE compute_mod
  use compute_ctl_mod, only: compute_ctl
  use compute_ctl2_mod, only: compute_ctl2
  contains
  SUBROUTINE compute_column(start, end, nlon, nz, q, t)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t(nlon,nz)
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO

    CALL COMPUTE_CTL(start, end, nlon, nz, q, t)
    CALL COMPUTE_CTL2(start, end, nlon, nz, q, t)

  END SUBROUTINE compute_column
  END MODULE compute_mod
"""

    # The intermediate kernel forwards its own dummy arguments (q1, t1) to the nested kernel
    fcode_intermediate_kernel = """
  MODULE compute_ctl_mod
  use compute2_mod, only: compute_another_column
  contains
  SUBROUTINE compute_ctl(start, end, nlon, nz, q1, t1)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t1(nlon,nz)
    REAL, INTENT(INOUT) :: q1(nlon,nz)
    CALL COMPUTE_ANOTHER_COLUMN(start, end, nlon, nz, q1, t1)
  END SUBROUTINE compute_ctl
  END MODULE compute_ctl_mod
"""

    fcode_intermediate2_kernel = """
  MODULE compute_ctl2_mod
  contains
  SUBROUTINE compute_ctl2(start, end, nlon, nz, q1, t1)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t1(nlon,nz)
    REAL, INTENT(INOUT) :: q1(nlon,nz)
  END SUBROUTINE compute_ctl2
  END MODULE compute_ctl2_mod
"""

    fcode_nested_kernel = """
  MODULE compute2_mod
  contains
  SUBROUTINE compute_another_column(start, end, nlon, nz, q1, t1)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t1(nlon,nz)
    REAL, INTENT(INOUT) :: q1(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t1(jl, jk) = c * jk
        q1(jl, jk) = q1(jl, jk-1) + t1(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      Q1(JL, NZ) = Q1(JL, NZ) * C
    END DO
  END SUBROUTINE compute_another_column
  END MODULE compute2_mod
"""

    nested_kernel_mod = Module.from_source(
        fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]
    )
    intermediate_kernel_mod = Module.from_source(
        fcode_intermediate_kernel, frontend=frontend, xmods=[tmp_path],
        definitions=nested_kernel_mod
    )
    intermediate_kernel2_mod = Module.from_source(
        fcode_intermediate2_kernel, frontend=frontend, xmods=[tmp_path]
    )
    kernel_mod = Module.from_source(
        fcode_kernel, frontend=frontend, xmods=[tmp_path],
        definitions=[intermediate_kernel_mod, intermediate_kernel2_mod]
    )
    driver = Subroutine.from_source(
        fcode_driver, frontend=frontend, xmods=[tmp_path],
        definitions=kernel_mod
    )
    kernel = kernel_mod.subroutines[0]
    intermediate_kernel = intermediate_kernel_mod.subroutines[0]
    nested_kernel = nested_kernel_mod.subroutines[0]

    # Ensure we have three loops in the kernel prior to transformation
    kernel_loops = FindNodes(ir.Loop).visit(kernel.body)
    assert len(kernel_loops) == 3

    scc_transform = (SCCDevectorTransformation(horizontal=horizontal),)
    scc_transform += (revector_trafo(horizontal=horizontal),)
    for transform in scc_transform:
        transform.apply(driver, role='driver', targets=('compute_column',))
        transform.apply(kernel, role='kernel', targets=('compute_Ctl',), ignore=('compute_Ctl2',))
        if ignore_nested_kernel:
            transform.apply(intermediate_kernel, role='kernel', ignore=('compute_Another_column',))
        else:
            transform.apply(intermediate_kernel, role='kernel', targets=('compute_Another_column',))
            transform.apply(nested_kernel, role='kernel')

    # Ensure the kernel retains only the sequential vertical loop (sequential revector),
    # or a hoisted horizontal loop wrapping the native vertical loop (vector revector)
    with pragmas_attached(kernel, node_type=ir.Loop):
        kernel_loops = FindNodes(ir.Loop).visit(kernel.body)
        calls = FindNodes(ir.CallStatement).visit(kernel.body)
        if revector_trafo == SCCSeqRevectorTransformation:
            assert len(kernel_loops) == 1
            assert kernel_loops[0].variable == 'jk'
            assert kernel_loops[0].bounds == '2:nz'
            assert kernel_loops[0].pragma
            assert is_loki_pragma(kernel_loops[0].pragma, starts_with='loop seq')
            for call in calls:
                assert 'jl' in call.arg_map
                assert call.routine.variable_map['jl'].type.intent.lower() == 'in'
        else:
            assert len(kernel_loops) == 2
            assert kernel_loops[1] in FindNodes(ir.Loop).visit(kernel_loops[0].body)
            assert kernel_loops[0].variable == 'jl'
            assert kernel_loops[0].bounds == 'start:end'
            assert kernel_loops[1].variable == 'jk'
            assert kernel_loops[1].bounds == '2:nz'

            # Check internal loop pragma annotations
            assert kernel_loops[0].pragma
            assert is_loki_pragma(kernel_loops[0].pragma, starts_with='loop vector')
            assert kernel_loops[1].pragma
            assert is_loki_pragma(kernel_loops[1].pragma, starts_with='loop seq')
            for call in calls:
                assert 'jl' not in call.arg_map

    # Ensure all expressions and array indices are unchanged
    assigns = FindNodes(ir.Assignment).visit(kernel.body)
    assert fgen(assigns[1]).lower() == 't(jl, jk) = c*jk'
    assert fgen(assigns[2]).lower() == 'q(jl, jk) = q(jl, jk - 1) + t(jl, jk)*c'
    assert fgen(assigns[3]).lower() == 'q(jl, nz) = q(jl, nz)*c'

    # Ensure that vector-section labels have been removed
    sections = FindNodes(ir.Section).visit(kernel.body)
    assert all(not s.label for s in sections)

    # Ensure driver remains unaffected and is marked
    with pragmas_attached(driver, node_type=ir.Loop):
        driver_loops = FindNodes(ir.Loop).visit(driver.body)
        if revector_trafo == SCCSeqRevectorTransformation:
            assert len(driver_loops) == 2
            assert driver_loops[1].variable == 'jl'
            assert driver_loops[1].bounds == 'start:end'
            assert driver_loops[1].pragma and len(driver_loops[1].pragma) == 1
            assert is_loki_pragma(driver_loops[1].pragma, starts_with='loop vector')
        else:
            assert len(driver_loops) == 1
        assert driver_loops[0].variable == 'b'
        assert driver_loops[0].bounds == '1:nb'
        assert driver_loops[0].pragma and len(driver_loops[0].pragma) == 1
        assert is_loki_pragma(driver_loops[0].pragma[0], starts_with='loop driver')
        assert 'vector_length(nlon)' in driver_loops[0].pragma[0].content

    kernel_calls = FindNodes(ir.CallStatement).visit(driver_loops[0])
    assert len(kernel_calls) == 1
    if revector_trafo == SCCSeqRevectorTransformation:
        assert 'jl' in kernel_calls[0].arg_map
        assert 'jl' in kernel_calls[0].routine.arguments
    else:
        assert 'jl' not in kernel_calls[0].arg_map

    assert kernel_calls[0].name == 'compute_column'
    if revector_trafo == SCCSeqRevectorTransformation:
        # make sure call to nested kernel gets horizontal.index as argument
        #  no matter whether it is a target or within ignore
        nested_kernel_call = FindNodes(ir.CallStatement).visit(kernel.body)[0]
        assert 'jl' in nested_kernel_call.arg_map

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('revector_trafo', [SCCSeqRevectorTransformation, SCCVecRevectorTransformation])
def test_scc_revector_transformation_aliased_bounds(frontend, horizontal_bounds_aliases, revector_trafo, tmp_path):
    """
    Test removal of vector loops in kernel and re-insertion of a single
    hoisted horizontal loop in the kernel with aliased loop bounds.
    """

    fcode_bnds_type_mod = """
module bnds_type_mod
implicit none
    type bnds_type
        integer :: start
        integer :: end
    end type bnds_type
end module bnds_type_mod
    """.strip()

    fcode_driver = """
SUBROUTINE column_driver(nlon, nz, q, t, nb)
    USE bnds_type_mod, only : bnds_type
    USE compute_mod, only : compute_column
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    REAL, INTENT(INOUT)   :: t(nlon,nz,nb)
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    INTEGER :: b, start, end
    TYPE(bnds_type) :: bnds

    bnds%start = 1
    bnds%end = nlon
    do b=1, nb
      call compute_column(bnds, nlon, nz, q(:,:,b), t(:,:,b))
    end do
END SUBROUTINE column_driver
    """.strip()

    # The kernel lives in a module so that the driver can import it via ``compute_mod``
    fcode_kernel = """
MODULE compute_mod
contains
SUBROUTINE compute_column(bnds, nlon, nz, q, t)
    USE bnds_type_mod, only : bnds_type
    TYPE(bnds_type), INTENT(IN) :: bnds
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t(nlon,nz)
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = bnds%start, bnds%end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = BNDS%START, BNDS%END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
END SUBROUTINE compute_column
END MODULE compute_mod
    """.strip()

    bnds_type_mod = Module.from_source(fcode_bnds_type_mod, frontend=frontend, xmods=[tmp_path])
    kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, xmods=[tmp_path],
                                    definitions=bnds_type_mod.definitions)
    kernel = kernel_mod.subroutines[0]
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, xmods=[tmp_path],
                                    definitions=(bnds_type_mod, kernel_mod))

    # Ensure we have three loops in the kernel prior to transformation
    kernel_loops = FindNodes(ir.Loop).visit(kernel.body)
    assert len(kernel_loops) == 3

    scc_transform = (SCCDevectorTransformation(horizontal=horizontal_bounds_aliases),)
    scc_transform += (revector_trafo(horizontal=horizontal_bounds_aliases),)
    for transform in scc_transform:
        transform.apply(driver, role='driver', targets=('compute_column',))
        transform.apply(kernel, role='kernel')

    # Ensure the kernel retains only the sequential vertical loop (sequential revector),
    # or a hoisted horizontal loop wrapping the native vertical loop (vector revector)
    with pragmas_attached(kernel, node_type=ir.Loop):
        kernel_loops = FindNodes(ir.Loop).visit(kernel.body)
        if revector_trafo == SCCSeqRevectorTransformation:
            assert len(kernel_loops) == 1
            assert kernel_loops[0].variable == 'jk'
            assert kernel_loops[0].bounds == '2:nz'
            assert kernel_loops[0].pragma
            assert is_loki_pragma(kernel_loops[0].pragma, starts_with='loop seq')
        else:
            assert len(kernel_loops) == 2
            assert kernel_loops[1] in FindNodes(ir.Loop).visit(kernel_loops[0].body)
            assert kernel_loops[0].variable == 'jl'
            assert kernel_loops[0].bounds == 'bnds%start:bnds%end'
            assert kernel_loops[1].variable == 'jk'
            assert kernel_loops[1].bounds == '2:nz'

            # Check internal loop pragma annotations
            assert kernel_loops[0].pragma
            assert is_loki_pragma(kernel_loops[0].pragma, starts_with='loop vector')
            assert kernel_loops[1].pragma
            assert is_loki_pragma(kernel_loops[1].pragma, starts_with='loop seq')

    # Ensure all expressions and array indices are unchanged
    assigns = FindNodes(ir.Assignment).visit(kernel.body)
    assert fgen(assigns[1]).lower() == 't(jl, jk) = c*jk'
    assert fgen(assigns[2]).lower() == 'q(jl, jk) = q(jl, jk - 1) + t(jl, jk)*c'
    assert fgen(assigns[3]).lower() == 'q(jl, nz) = q(jl, nz)*c'

    # Ensure that vector-section labels have been removed
    sections = FindNodes(ir.Section).visit(kernel.body)
    assert all(not s.label for s in sections)

    # Ensure driver remains unaffected and is marked
    with pragmas_attached(driver, node_type=ir.Loop):
        driver_loops = FindNodes(ir.Loop).visit(driver.body)
        if revector_trafo == SCCSeqRevectorTransformation:
            assert len(driver_loops) == 2
            assert driver_loops[1].variable == 'jl'
            assert driver_loops[1].bounds == 'bnds%start:bnds%end'
            assert driver_loops[1].pragma and len(driver_loops[1].pragma) == 1
            assert is_loki_pragma(driver_loops[1].pragma, starts_with='loop vector')
        else:
            assert len(driver_loops) == 1
        assert driver_loops[0].variable == 'b'
        assert driver_loops[0].bounds == '1:nb'
        assert driver_loops[0].pragma and len(driver_loops[0].pragma) == 1
        assert is_loki_pragma(driver_loops[0].pragma[0], starts_with='loop driver')
        assert 'vector_length(nlon)' in driver_loops[0].pragma[0].content

    kernel_calls = FindNodes(ir.CallStatement).visit(driver_loops[0])
    assert len(kernel_calls) == 1
    if revector_trafo == SCCSeqRevectorTransformation:
        assert 'jl' in kernel_calls[0].arg_map
        assert 'jl' in kernel_calls[0].routine.arguments
    else:
        assert 'jl' not in kernel_calls[0].arg_map

    assert kernel_calls[0].name == 'compute_column'


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_devector_transformation(frontend, horizontal):
    """
    Check that SCCDevector removes all horizontal loops and splits the kernel into
    the expected vector sections, with the outer ``niter`` loop acting as a separator.
    """

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk, niter
    LOGICAL :: maybe
    REAL :: c

    if (maybe)  call logger()

    c = 5.345
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + 3.0
    END DO

    DO niter = 1, 3

      DO JL = START, END
        Q(JL, NZ) = Q(JL, NZ) + 1.0
      END DO

      call update_q(start, end, nlon, nz, q, c)

    END DO

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO

    IF (.not. maybe) THEN
      call update_q(start, end, nlon, nz, q, c)
    END IF

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + C * 3.
    END DO

    IF (maybe)  call logger()

  END SUBROUTINE compute_column
"""
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    def horizontal_loops():
        return [loop for loop in FindNodes(ir.Loop).visit(kernel.body) if loop.variable == 'jl']

    # The kernel starts out with four horizontal loops
    assert len(horizontal_loops()) == 4

    # Devectorise the kernel, whose outer `niter` loop splits the vector scopes
    SCCDevectorTransformation(horizontal=horizontal).apply(kernel, role='kernel')

    # Every horizontal loop has been stripped
    assert not horizontal_loops()

    # Four labelled vector sections remain, holding 2, 1, 1 and 1 assignments respectively
    vector_sections = [
        section for section in FindNodes(ir.Section).visit(kernel.body)
        if section.label == 'vector_section'
    ]
    assert len(vector_sections) == 4
    assign_counts = [len(FindNodes(ir.Assignment).visit(section)) for section in vector_sections]
    assert assign_counts == [2, 1, 1, 1]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('revector_trafo', [SCCSeqRevectorTransformation, SCCVecRevectorTransformation])
def test_scc_vector_inlined_call(frontend, horizontal, revector_trafo):
    """
    Test that calls targeted for inlining inside a vector region are not treated as separators.

    ``some_inlined_kernel`` is called from inside the horizontal loop of
    ``some_kernel`` and is listed in ``targets``; ``some_other_kernel`` is called
    after the loop. With the vector revectorisation the inlined call has to stay
    inside the single re-created vector loop. With the sequential variant the
    kernel is marked ``routine seq`` and receives the horizontal index as an
    ``intent(in)`` argument.
    """

    fcode = """
    subroutine some_inlined_kernel(work)
    !$loki routine seq
       real, intent(inout) :: work

       work = work*2.
    end subroutine some_inlined_kernel

    subroutine some_kernel(start, end, nlon, work, cond)
       logical, intent(in) :: cond
       integer, intent(in) :: nlon, start, end
       real, dimension(nlon), intent(inout) :: work

       integer :: jl

       do jl=start,end
          if(cond)then
             call some_inlined_kernel(work(jl))
          endif
          work(jl) = work(jl) + 1.
       enddo

       call some_other_kernel()

    end subroutine some_kernel
    """

    source = Sourcefile.from_source(fcode, frontend=frontend)
    kernel = source['some_kernel']
    kernel.enrich((source['some_inlined_kernel'],))

    pipeline = [
        SCCDevectorTransformation(horizontal=horizontal),
        revector_trafo(horizontal=horizontal),
    ]
    for trafo in pipeline:
        trafo.apply(kernel, role='kernel', targets=['some_kernel', 'some_inlined_kernel'])

    pragmas = FindNodes(ir.Pragma).visit(kernel.ir)
    all_calls = FindNodes(ir.CallStatement).visit(kernel.body)

    if revector_trafo is SCCVecRevectorTransformation:
        # Only the routine-level and the loop-level vector pragmas are present
        assert len(pragmas) == 2
        routine_pragma, loop_pragma = pragmas
        assert is_loki_pragma(routine_pragma, starts_with='routine vector')
        assert is_loki_pragma(loop_pragma, starts_with='loop vector')

        # The call to 'some_inlined_kernel' must not split the vector-parallel region
        vector_loops = FindNodes(ir.Loop).visit(kernel.body)
        assert len(vector_loops) == 1
        assert len(FindNodes(ir.CallStatement).visit(vector_loops[0].body)) == 1
    else:
        # Sequential variant: the horizontal index becomes an input argument
        assert horizontal.index in kernel.arguments
        assert kernel.variable_map[horizontal.index].type.intent == 'in'
        assert len(pragmas) == 1
        assert is_loki_pragma(pragmas[0], starts_with='routine seq')

    # Both call statements survive in either variant
    assert len(all_calls) == 2


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('trim_vector_sections', [False, True])
@pytest.mark.parametrize('revector_trafo', [SCCSeqRevectorTransformation, SCCVecRevectorTransformation])
def test_scc_vector_section_trim_simple(frontend, horizontal, trim_vector_sections, revector_trafo):
    """
    Test the trimming of vector-sections to exclude scalar assignments.

    The kernel begins with the scalar assignment ``flag0 = .true.`` and has a
    comment after the horizontal loop. With the vector revectorisation and
    ``trim_vector_sections`` enabled, both must sit directly in the routine body
    rather than inside the re-created loop. Without trimming, both must be
    inside the loop. The sequential revectorisation re-creates no loop.
    """

    fcode_kernel = """
    subroutine some_kernel(start, end, nlon)
       implicit none

       integer, intent(in) :: nlon, start, end
       logical :: flag0
       real, dimension(nlon) :: work
       integer :: jl

       flag0 = .true.

       do jl=start,end
          work(jl) = 1.
       enddo
       ! random comment
    end subroutine some_kernel
    """

    routine = Subroutine.from_source(fcode_kernel, frontend=frontend)

    scc_transform = (SCCDevectorTransformation(horizontal=horizontal, trim_vector_sections=trim_vector_sections),)
    scc_transform += (revector_trafo(horizontal=horizontal),)

    for transform in scc_transform:
        transform.apply(routine, role='kernel', targets=['some_kernel',])

    # Sanity check that we picked up the scalar assignment (for both revector variants)
    assign = FindNodes(ir.Assignment).visit(routine.body)[0]
    assert assign.lhs.name.lower() == 'flag0'

    loops = FindNodes(ir.Loop).visit(routine.body)
    if revector_trafo == SCCSeqRevectorTransformation:
        # Sequential revectorisation does not re-create the horizontal loop
        assert len(loops) == 0
    else:
        loop = loops[0]

        # The trailing comment must survive the transformation
        comments = [
            c for c in FindNodes(ir.Comment).visit(routine.body)
            if c.text == '! random comment'
        ]
        assert comments, 'Expected the "! random comment" comment to be preserved'
        comment = comments[0]

        if trim_vector_sections:
            assert assign not in loop.body
            assert assign in routine.body.body

            assert comment not in loop.body
            assert comment in routine.body.body
        else:
            assert assign in loop.body
            assert assign not in routine.body.body

            assert comment in loop.body
            assert comment not in routine.body.body


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('trim_vector_sections', [False, True])
def test_scc_vector_section_trim_nested(frontend, horizontal, trim_vector_sections):
    """
    Test the trimming of vector-sections to exclude nested scalar assignments.

    The leading ``IF``/``ELSE`` block only assigns the scalar flags ``flag1``
    and ``flag2``. With trimming it must end up directly in the routine body.
    Without trimming it must end up inside the re-created horizontal loop.
    """

    fcode_kernel = """
    subroutine some_kernel(start, end, nlon, flag0)
       implicit none

       integer, intent(in) :: nlon, start, end
       logical, intent(in) :: flag0
       logical :: flag1, flag2
       real, dimension(nlon) :: work

       integer :: jl

       if(flag0)then
         flag1 = .true.
         flag2 = .false.
       else
         flag1 = .false.
         flag2 = .true.
       endif

       do jl=start,end
          work(jl) = 1.
       enddo
    end subroutine some_kernel
    """

    routine = Subroutine.from_source(fcode_kernel, frontend=frontend)

    devector = SCCDevectorTransformation(horizontal=horizontal, trim_vector_sections=trim_vector_sections)
    revector = SCCRevectorTransformation(horizontal=horizontal)
    for trafo in (devector, revector):
        trafo.apply(routine, role='kernel', targets=['some_kernel',])

    cond = FindNodes(ir.Conditional).visit(routine.body)[0]
    loop = FindNodes(ir.Loop).visit(routine.body)[0]

    # The scalar-only conditional lives inside the loop exactly when trimming is off
    cond_in_loop = cond in loop.body
    cond_at_top_level = cond in routine.body.body
    assert cond_in_loop == (not trim_vector_sections)
    assert cond_at_top_level == trim_vector_sections


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('trim_vector_sections', [False, True])
def test_scc_vector_section_trim_complex(
        frontend, horizontal, vertical, blocking, trim_vector_sections
):
    """
    Test to highlight the limitations of vector-section trimming.

    Only the leading scalar assignment ``flag1 = .true.`` may be trimmed out of
    the vector section. The following ``IF``/``ELSE`` block is expected to stay
    inside the vector loop regardless of ``trim_vector_sections``. This is
    presumably because its ``ELSE`` branch holds the array-notation assignment
    ``work1(start:end) = 1.``.
    """

    fcode_kernel = """
    subroutine some_kernel(start, end, nlon, flag0)
       implicit none

       integer, intent(in) :: nlon, start, end
       logical, intent(in) :: flag0
       logical :: flag1, flag2
       real, dimension(nlon) :: work, work1

       integer :: jl

       flag1 = .true.
       if(flag0)then
         flag2 = .false.
       else
         work1(start:end) = 1.
       endif

       do jl=start,end
          work(jl) = 1.
       enddo
    end subroutine some_kernel
    """

    routine = Subroutine.from_source(fcode_kernel, frontend=frontend)

    scc_pipeline = SCCVectorPipeline(
        horizontal=horizontal, vertical=vertical, block_dim=blocking,
        directive='openacc', trim_vector_sections=trim_vector_sections
    )
    scc_pipeline.apply(routine, role='kernel', targets=['some_kernel',])

    assign = FindNodes(ir.Assignment).visit(routine.body)[0]

    # check we found the right assignment
    assert assign.lhs.name.lower() == 'flag1'

    cond = FindNodes(ir.Conditional).visit(routine.body)[0]
    loop = FindNodes(ir.Loop).visit(routine.body)[0]

    # The conditional is never trimmed out of the vector loop
    assert cond in loop.body
    assert cond not in routine.body.body

    # Trimming only moves ``flag1 = .true.`` out, leaving one assignment fewer in the loop
    loop_assigns = FindNodes(ir.Assignment).visit(loop.body)
    if trim_vector_sections:
        assert assign not in loop.body
        assert len(loop_assigns) == 3
    else:
        assert assign in loop.body
        assert len(loop_assigns) == 4

@pytest.mark.parametrize('frontend', available_frontends(
    skip={OMNI: 'OMNI automatically expands ELSEIF into a nested ELSE=>IF.'}
))
@pytest.mark.parametrize('trim_vector_sections', [False, True])
@pytest.mark.parametrize('vector_pipeline', [SCCVVectorPipeline, SCCSVectorPipeline])
def test_scc_devector_section_special_case(
        frontend, horizontal, vertical, blocking, trim_vector_sections, vector_pipeline
):
    """
    Test devectorisation of an ``IF``/``ELSEIF``/``ELSE`` chain.

    The first branch only calls ``some_other_kernel``. The two ``ELSEIF``
    branches and the ``ELSE`` branch each contain their own horizontal loop.

    After the pipeline, the call-only branch must be untouched. The ``ELSEIF``
    chain must appear as a single conditional nested in the outer ``else_body``.
    Each of the three entries of ``else_bodies`` (presumably the two ``ELSEIF``
    bodies and the ``ELSE`` body) must then start with:

    * for ``SCCVVectorPipeline``: a comment followed by a loop whose pragma
      reads ``loop vector``;
    * for ``SCCSVectorPipeline``: an assignment.
    """

    fcode_kernel = """
    subroutine some_kernel(start, end, nlon, flag0, flag1, flag2)
       implicit none

       integer, intent(in) :: nlon, start, end
       logical, intent(in) :: flag0, flag1, flag2
       real, dimension(nlon) :: work

       integer :: jl

       if (flag0) then
         call some_other_kernel()
       elseif (flag1) then
         do jl=start,end
            work(jl) = 1.
         enddo
       elseif (flag2) then
         do jl=start,end
            work(jl) = 1.
            work(jl) = 2.
         enddo
       else
         do jl=start,end
            work(jl) = 41.
            work(jl) = 42.
         enddo
       endif

    end subroutine some_kernel
    """

    routine = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # check whether pipeline can be applied and works as expected
    scc_pipeline = vector_pipeline(
        horizontal=horizontal, vertical=vertical, block_dim=blocking,
        directive='openacc', trim_vector_sections=trim_vector_sections
    )
    scc_pipeline.apply(routine, role='kernel', targets=['some_kernel',])

    with pragmas_attached(routine, node_type=ir.Loop):
        conditional = FindNodes(ir.Conditional).visit(routine.body)[0]

        # The call-only IF branch is left as a single call statement
        assert isinstance(conditional.body[0], ir.CallStatement)
        assert len(conditional.body) == 1

        # The ELSEIF chain is kept as one nested conditional
        assert isinstance(conditional.else_body[0], ir.Conditional)
        assert len(conditional.else_body) == 1

        # Check that all else-bodies have been wrapped
        else_bodies = conditional.else_bodies
        assert len(else_bodies) == 3

        if vector_pipeline == SCCVVectorPipeline:
            assert isinstance(conditional.else_body[0].body[0], ir.Comment)
            assert isinstance(conditional.else_body[0].body[1], ir.Loop)
            assert conditional.else_body[0].body[1].pragma[0].content.lower() == 'loop vector'

            for body in else_bodies:
                assert isinstance(body[0], ir.Comment)
                assert isinstance(body[1], ir.Loop)
                assert body[1].pragma[0].content.lower() == 'loop vector'
        else:
            assert isinstance(conditional.else_body[0].body[0], ir.Assignment)

            for body in else_bodies:
                assert isinstance(body[0], ir.Assignment)
loki-ecmwf-0.3.6/loki/transformations/single_column/tests/test_scc.py0000664000175000017500000015671015167130205026323 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# pylint: disable=too-many-lines

import pytest

from loki import Subroutine, Sourcefile, Dimension, fgen
from loki.batch import ProcedureItem, TransformationError
from loki.expression import Scalar, Array, IntLiteral
from loki.frontend import available_frontends, OMNI, HAVE_FP
from loki.ir import (
    FindNodes, Assignment, CallStatement, Conditional, Loop,
    Pragma, PragmaRegion, pragmas_attached, is_loki_pragma,
    pragma_regions_attached, FindInlineCalls
)
from loki.logging import WARNING
from loki.transformations import (
    DataOffloadTransformation, SanitiseTransformation,
    InlineTransformation, get_loop_bounds, PragmaModelTransformation
)
from loki.transformations.single_column import (
    SCCBaseTransformation, SCCDevectorTransformation,
    SCCDemoteTransformation, SCCRevectorTransformation,
    SCCAnnotateTransformation, SCCVectorPipeline,
    SCCVVectorPipeline, SCCSVectorPipeline, SCCSeqRevectorTransformation
)


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """
    Module-scoped horizontal ``Dimension`` shared by the SCC tests.

    Index ``jl``. Size given as ``dims%klon`` or ``nlon`` (alias ``nproma``).
    Lower bound ``START``/``dims%ist``, upper bound ``end``/``dims%iend``.
    """
    dim_kwargs = {
        'name': 'horizontal',
        'size': ['dims%klon', 'nlon'],
        'index': 'jl',
        'aliases': ('nproma',),
        'lower': ('START', 'dims%ist'),
        'upper': ('end', 'dims%iend'),
    }
    return Dimension(**dim_kwargs)

@pytest.fixture(scope='module', name='horizontal_bounds_aliases')
def fixture_horizontal_bounds_aliases():
    """
    Module-scoped horizontal ``Dimension`` declared via ``bounds``/``bounds_aliases``.

    Index ``jl``, size ``nlon`` (alias ``nproma``), bounds ``start``/``end``
    with the derived-type aliases ``bnds%start``/``bnds%end``.
    """
    bounds_dim = Dimension(
        name='horizontal_bounds_aliases',
        size='nlon',
        index='jl',
        bounds=('start', 'end'),
        aliases=('nproma',),
        bounds_aliases=('bnds%start', 'bnds%end'),
    )
    return bounds_dim

@pytest.fixture(scope='module', name='blocking')
def fixture_blocking():
    """
    Module-scoped blocking ``Dimension`` with size ``nb`` and index ``b``.
    """
    blocking_dim = Dimension(name='blocking', size='nb', index='b')
    return blocking_dim


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_base_resolve_vector_notation(frontend, horizontal):
    """
    Test resolving of vector notation in kernel.

    The slice assignments to ``t(start:end, jk)`` and ``q(start:end, jk)`` inside
    the ``jk`` loop must become two explicit ``jl`` loops over ``start:end``.
    Both loops must be nested in the ``jk`` loop, and ``jl`` must be declared in
    the kernel.
    """

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q, t)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t(nlon,nz)
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      t(start:end, jk) = c * jk
      q(start:end, jk) = q(start:end, jk-1) + t(start:end, jk) * c
    END DO
  END SUBROUTINE compute_column
"""

    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)
    SCCBaseTransformation(horizontal=horizontal).apply(kernel, role='kernel')

    # The horizontal loop index has to be declared in the kernel
    assert 'jl' in kernel.variables

    # One vertical loop enclosing two horizontal loops
    loops = FindNodes(Loop).visit(kernel.body)
    assert len(loops) == 3
    vertical_loop, horizontal_loops = loops[0], loops[1:]
    assert vertical_loop.variable == 'jk'
    assert vertical_loop.bounds == '2:nz'

    nested_loops = FindNodes(Loop).visit(vertical_loop.body)
    for h_loop in horizontal_loops:
        assert h_loop in nested_loops
        assert h_loop.variable == 'jl'
        assert h_loop.bounds == 'start:end'

    # The slice subscripts have been replaced by the scalar horizontal index
    assigns = FindNodes(Assignment).visit(kernel.body)
    assert fgen(assigns[1]).lower() == 't(jl, jk) = c*jk'
    assert fgen(assigns[2]).lower() == 'q(jl, jk) = q(jl, jk - 1) + t(jl, jk)*c'


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('rel_index', ('jl', 'jcol', 'ji'))
@pytest.mark.parametrize('indices', (('jl', 'jcol', 'jlll'), ('jcol','jcol', 'jcol'),
    ('jl', 'jl', 'jl'), ('jcol', 'jlll', 'jlll')))
def test_scc_base_rename_index_aliases(frontend, rel_index, indices):
    """
    Test rename index aliases in kernel.

    The three horizontal loops use the loop indices given by ``indices``. The
    horizontal dimension lists ``rel_index`` first and the other indices as
    aliases. With ``rename_index_aliases=True``, every horizontal loop must
    iterate over ``rel_index`` and no alias may remain among the kernel
    variables.
    """

    fcode_kernel = f"""
  SUBROUTINE kernel_rename_index_aliases(start, end, nlon, nz, q, t)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t(nlon,nz)
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jk {', jl' if 'jl' in indices else ''}
    {'INTEGER :: jcol' if 'jcol' in indices else ''}
    {'INTEGER :: jlll' if 'jlll' in indices else ''}
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO {indices[0]} = start, end
        t({indices[0]}, jk) = c * jk
        q({indices[0]}, jk) = q({indices[0]}, jk-1) + t({indices[0]}, jk) * c
      END DO
    END DO
    DO jk = 2, nz
      DO {indices[1]} = start, end
        t({indices[1]}, jk) = c * jk
        q({indices[1]}, jk) = q({indices[1]}, jk-1) + t({indices[1]}, jk) * c
      END DO
    END DO
    DO jk = 2, nz
      DO {indices[2]} = start, end
        t({indices[2]}, jk) = c * jk
        q({indices[2]}, jk) = q({indices[2]}, jk-1) + t({indices[2]}, jk) * c
      END DO
    END DO
  END SUBROUTINE kernel_rename_index_aliases
"""

    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # The primary index comes first, followed by every other (alias) index in order
    h_indices = [rel_index, *(idx for idx in indices if idx != rel_index)]
    horizontal = Dimension(
        name='horizontal', size=['dims%klon', 'nlon'], index=h_indices,
        aliases=('nproma',), lower=('START', 'dims%ist'),
        upper=('end', 'dims%iend')
    )
    SCCBaseTransformation(horizontal=horizontal).apply(kernel, role='kernel', rename_index_aliases=True)

    # Only the primary index is left among the kernel variables
    assert rel_index in kernel.variables
    assert not any(alias in kernel.variables for alias in h_indices[1:])

    # All horizontal loops (recognised by their lower bound) now use the primary index
    h_loops = [
        loop for loop in FindNodes(Loop).visit(kernel.body)
        if loop.bounds.lower == 'start'
    ]
    assert len(h_loops) == 3
    assert all(h_loop.variable == rel_index for h_loop in h_loops)


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_demote_transformation(frontend, horizontal):
    """
    Test that local array variables that do not buffer values
    between vector sections and whose size is known at compile-time
    are demoted.

    ``d`` is listed in the item's ``preserve_arrays`` config and therefore keeps
    its horizontal index. ``b`` loses its horizontal dimension but keeps the
    ``psize`` one. ``t`` and ``q`` remain full arrays.
    """

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nproma, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    INTEGER, INTENT(IN) :: nproma      ! Horizontal size alias
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL :: t(nlon,nz)
    REAL :: a(nproma)
    REAL :: b(nlon,psize)
    REAL :: unused(nlon)
    REAL :: d(nlon,psize)
    INTEGER, PARAMETER :: psize = 3
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      a(jl) = Q(JL, 1)
      b(jl, 1) = Q(JL, 2)
      b(jl, 2) = Q(JL, 3)
      b(jl, 3) = a(jl) * (b(jl, 1) + b(jl, 2))

      d(jl, 1) = b(jl, 1)
      d(jl, 2) = b(jl, 2)
      d(jl, 3) = b(jl, 3)

      Q(JL, NZ) = Q(JL, NZ) * C + b(jl, 3)
    END DO
  END SUBROUTINE compute_column
"""
    kernel_source = Sourcefile.from_source(fcode_kernel, frontend=frontend)
    kernel_item = ProcedureItem(name='#compute_column', source=kernel_source, config={'preserve_arrays': ['d',]})
    kernel = kernel_source.subroutines[0]

    # SCCDevector has to run first: demotion relies on knowledge of the vector sections
    for transform in (
        SCCDevectorTransformation(horizontal=horizontal),
        SCCDemoteTransformation(horizontal=horizontal),
    ):
        transform.apply(kernel, role='kernel', item=kernel_item)

    # Which variables were demoted to scalars and which are still arrays
    expected_kinds = {
        'a': Scalar, 'b': Array, 'c': Scalar, 't': Array,
        'q': Array, 'unused': Scalar, 'd': Array,
    }
    for var_name, var_kind in expected_kinds.items():
        assert isinstance(kernel.variable_map[var_name], var_kind)

    # Only the parameter-sized array b lost its horizontal dimension
    assert kernel.variable_map['b'].shape == ((3,) if frontend is OMNI else ('psize',))
    for full_array in ('t', 'q'):
        assert kernel.variable_map[full_array].shape == ('nlon', 'nz')

    # Demoted variables lost their ``jl`` subscript; preserved/full arrays kept it
    expected_assigns = (
        't(jl, jk) = c*jk',
        'q(jl, jk) = q(jl, jk - 1) + t(jl, jk)*c',
        'a = q(jl, 1)',
        'b(1) = q(jl, 2)',
        'b(2) = q(jl, 3)',
        'b(3) = a*(b(1) + b(2))',
        'd(jl, 1) = b(1)',
        'd(jl, 2) = b(2)',
        'd(jl, 3) = b(3)',
        'q(jl, nz) = q(jl, nz)*c + b(3)',
    )
    assigns = FindNodes(Assignment).visit(kernel.body)
    for position, expected in enumerate(expected_assigns, start=1):
        assert fgen(assigns[position]).lower() == expected


@pytest.mark.xfail(not HAVE_FP, reason="Identification of array reduction intrinsics requires fparser.")
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('acc_data', ['default', 'copyin', None])
def test_scc_annotate_openacc(frontend, horizontal, blocking, acc_data):
    """
    Test the correct addition of OpenACC pragmas to SCC format code (no hoisting).

    ``acc_data`` selects whether the driver loop already sits inside an
    ``!$acc data`` region (``'default'`` or ``'copyin'``) or not (``None``).
    Inside such a region the driver's gang loop must carry no ``private`` clause.
    Without one, the driver-local arrays ``other_var``/``more_var`` must be
    privatised.
    """

    fcode_driver = f"""
  SUBROUTINE column_driver(nlon, nproma, nlev, nz, q, nb)
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    INTEGER, INTENT(IN)   :: nproma, nlev  ! Aliases of horizontal and vertical sizes
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    REAL :: other_var(nlon), more_var(nlon)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    {'!$acc data default(present)' if acc_data == 'default' else ''}
    {'!$acc data copyin(more_var) copyin(other_var)' if acc_data == 'copyin' else ''}
    !
    do b=1, nb
      call compute_column(start, end, nlon, nproma, nz, q(:,:,b), other_var, more_var)
    end do
    !
    {'!$acc end data' if acc_data else ''}
  END SUBROUTINE column_driver
"""

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nproma, nlev, nz, q, other_var, more_var)
    INTEGER, INTENT(IN) :: start, end   ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz     ! Size of the horizontal and vertical
    INTEGER, INTENT(IN) :: nproma, nlev ! Aliases of horizontal and vertical sizes
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL, INTENT(IN) :: other_var(nlon), more_var(nlon)
    REAL :: t(nlon,nz)
    REAL :: a(nlon)
    REAL :: d(nproma)
    REAL :: tmp(nproma)
    REAL :: e(nlev)
    REAL :: b(nlon,psize)
    REAL :: f(psize)
    INTEGER, PARAMETER :: psize = 3
    INTEGER :: jl, jk
    REAL :: c
    REAL :: tmp_sum

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    DO jl = start, end
      tmp(jl) = other_var(jl)
    END DO

    tmp_sum = sum(tmp(start:end))

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      a(jl) = Q(JL, 1)
      b(jl, 1) = Q(JL, 2)
      b(jl, 2) = Q(JL, 3)
      b(jl, 3) = a(jl) * (b(jl, 1) + b(jl, 2))

      Q(JL, NZ) = Q(JL, NZ) * C + tmp
    END DO
  END SUBROUTINE compute_column
"""
    # NOTE(review): the driver calls compute_column with 8 arguments while the kernel
    # declares 9 (``nlev`` is not passed). The sources are only parsed here, not
    # compiled, so this appears harmless -- confirm before relying on argument mapping.
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    driver.enrich(kernel)  # Attach kernel source to driver call

    # Test OpenACC annotations on non-hoisted version
    scc_transform = (SCCDevectorTransformation(horizontal=horizontal),)
    scc_transform += (SCCDemoteTransformation(horizontal=horizontal),)
    scc_transform += (SCCRevectorTransformation(horizontal=horizontal),)
    scc_transform += (SCCAnnotateTransformation(block_dim=blocking),)
    scc_transform += (PragmaModelTransformation(directive='openacc'),)
    for transform in scc_transform:
        transform.apply(driver, role='driver', targets=['compute_column'])
        transform.apply(kernel, role='kernel')

    # Ensure routine is annotated at vector level
    pragmas = FindNodes(Pragma).visit(kernel.ir)
    assert len(pragmas) == 6
    assert pragmas[0].keyword == 'acc'
    assert pragmas[0].content == 'routine vector'
    assert pragmas[1].keyword == 'acc'
    assert pragmas[1].content == 'data present(q, other_var, more_var)'
    assert pragmas[-1].keyword == 'acc'
    assert pragmas[-1].content == 'end data'

    # Ensure vector and seq loops are annotated, including privatized variable `b`
    with pragmas_attached(kernel, Loop):
        kernel_loops = FindNodes(Loop).visit(kernel.ir)
        assert len(kernel_loops) == 3
        assert kernel_loops[0].pragma[0].keyword == 'acc'
        assert kernel_loops[0].pragma[0].content == 'loop vector'
        assert kernel_loops[1].pragma[0].keyword == 'acc'
        assert kernel_loops[1].pragma[0].content == 'loop seq'
        assert kernel_loops[2].pragma[0].keyword == 'acc'
        assert kernel_loops[2].pragma[0].content == 'loop vector private(b)'

    # Ensure array reduction intrinsic is still in the correct place. Select the
    # ``sum`` call by name instead of taking an arbitrary element of the
    # (unordered) collection returned by FindInlineCalls.
    assert 'nproma' in kernel.variable_map['tmp'].dimensions
    sum_calls = [
        call for call in FindInlineCalls().visit(kernel.body)
        if call.name.lower() == 'sum'
    ]
    assert len(sum_calls) == 1
    assert 'tmp(start:end)' in sum_calls[0].parameters

    # Ensure a single outer parallel loop in driver
    with pragmas_attached(driver, Loop):
        driver_loops = FindNodes(Loop).visit(driver.body)
        assert len(driver_loops) == 1
        assert driver_loops[0].pragma[0].keyword.lower() == 'acc'
        if acc_data:
            # Inside an existing data region nothing needs privatising
            assert driver_loops[0].pragma[0].content == 'parallel loop gang vector_length(nlon)'
        else:
            assert driver_loops[0].pragma[0].content in (
                'parallel loop gang private(other_var, more_var) vector_length(nlon)',
                'parallel loop gang private(more_var, other_var) vector_length(nlon)'
            )


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('directive', [False, 'openacc', 'omp-gpu'])
def test_scc_annotate_directive(frontend, horizontal, blocking, directive):
    """
    Test the correct addition of SCC pragmas for each directive target (no hoisting).

    The kernel is processed with ``SCCSeqRevectorTransformation`` and then by
    ``PragmaModelTransformation`` with the given ``directive``:

    * a falsy ``directive`` keeps the generic ``!$loki`` pragmas;
    * ``'openacc'`` produces the corresponding ``!$acc`` pragmas;
    * ``'omp-gpu'`` produces ``!$omp`` pragmas for the routine and the driver
      loops.
    """

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nproma, nlev, nz, q, nb)
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    INTEGER, INTENT(IN)   :: nproma, nlev  ! Aliases of horizontal and vertical sizes
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    REAL :: other_var(nlon), more_var(nlon)
    INTEGER :: b, start, end

    start = 1
    end = nlon

    do b=1, nb
      call compute_column(start, end, nlon, nproma, nz, q(:,:,b), other_var, more_var)
    end do

  END SUBROUTINE column_driver
"""

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nproma, nlev, nz, q, other_var, more_var)
    INTEGER, INTENT(IN) :: start, end   ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz     ! Size of the horizontal and vertical
    INTEGER, INTENT(IN) :: nproma, nlev ! Aliases of horizontal and vertical sizes
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL, INTENT(IN) :: other_var(nlon), more_var(nlon)
    REAL :: t(nlon,nz)
    REAL :: a(nlon)
    REAL :: d(nproma)
    REAL :: e(nlev)
    REAL :: b(nlon,psize)
    INTEGER, PARAMETER :: psize = 3
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      a(jl) = Q(JL, 1)
      b(jl, 1) = Q(JL, 2)
      b(jl, 2) = Q(JL, 3)
      b(jl, 3) = a(jl) * (b(jl, 1) + b(jl, 2))

      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    driver.enrich(kernel)  # Attach kernel source to driver call

    # Test directive annotations on non-hoisted version
    scc_transform = (SCCDevectorTransformation(horizontal=horizontal),)
    scc_transform += (SCCDemoteTransformation(horizontal=horizontal),)
    scc_transform += (SCCSeqRevectorTransformation(horizontal=horizontal),)
    scc_transform += (SCCAnnotateTransformation(block_dim=blocking),)
    scc_transform += (PragmaModelTransformation(directive=directive),)
    for transform in scc_transform:
        transform.apply(driver, role='driver', targets=['compute_column'])
        transform.apply(kernel, role='kernel')

    # A falsiness check (rather than ``is None``) is required here: the parametrisation
    # passes ``False`` for the no-directive case, which ``is None`` never matched.
    if not directive:
        # Ensure routine is annotated at sequential level
        pragmas = FindNodes(Pragma).visit(kernel.ir)
        assert len(pragmas) == 4
        assert pragmas[0].keyword == 'loki'
        assert pragmas[0].content == 'routine seq'

        # Only the vertical loop remains in the kernel and it is marked sequential
        with pragmas_attached(kernel, Loop):
            kernel_loops = FindNodes(Loop).visit(kernel.ir)
            assert len(kernel_loops) == 1
            assert kernel_loops[0].pragma[0].keyword == 'loki'
            assert kernel_loops[0].pragma[0].content == 'loop seq'

        # The driver gets an outer gang loop and an inner vector loop
        with pragmas_attached(driver, Loop):
            driver_loops = FindNodes(Loop).visit(driver.body)
            assert len(driver_loops) == 2
            assert driver_loops[0].pragma[0].keyword.lower() == 'loki'
            assert driver_loops[0].pragma[0].content in (
                'loop gang private(other_var, more_var) vlength(nlon)',
                'loop gang private(more_var, other_var) vlength(nlon)'
            )
            assert driver_loops[1].pragma[0].keyword.lower() == 'loki'
            assert driver_loops[1].pragma[0].content == 'loop vector'
    elif directive == 'openacc':
        # Ensure routine is annotated at sequential level
        pragmas = FindNodes(Pragma).visit(kernel.ir)
        assert len(pragmas) == 4
        assert pragmas[0].keyword == 'acc'
        assert pragmas[0].content == 'routine seq'
        assert pragmas[1].keyword == 'acc'
        assert pragmas[1].content == 'data present(q, other_var, more_var)'
        assert pragmas[-1].keyword == 'acc'
        assert pragmas[-1].content == 'end data'

        # Only the vertical loop remains in the kernel and it is marked sequential
        with pragmas_attached(kernel, Loop):
            kernel_loops = FindNodes(Loop).visit(kernel.ir)
            assert len(kernel_loops) == 1
            assert kernel_loops[0].pragma[0].keyword == 'acc'
            assert kernel_loops[0].pragma[0].content == 'loop seq'

        # The driver gets an outer gang-parallel loop and an inner vector loop
        with pragmas_attached(driver, Loop):
            driver_loops = FindNodes(Loop).visit(driver.body)
            assert len(driver_loops) == 2
            assert driver_loops[0].pragma[0].keyword.lower() == 'acc'
            assert driver_loops[0].pragma[0].content in (
                'parallel loop gang private(other_var, more_var) vector_length(nlon)',
                'parallel loop gang private(more_var, other_var) vector_length(nlon)'
            )
            assert driver_loops[1].pragma[0].keyword.lower() == 'acc'
            assert driver_loops[1].pragma[0].content == 'loop vector'
    elif directive == 'omp-gpu':
        # Ensure routine is declared as a device target
        pragmas = FindNodes(Pragma).visit(kernel.ir)
        assert len(pragmas) == 4
        assert pragmas[0].keyword == 'omp'
        assert pragmas[0].content == 'declare target'
        assert pragmas[1].keyword == 'loki'
        assert pragmas[1].content == 'device-present vars(q, other_var, more_var)'
        assert pragmas[-1].keyword == 'loki'
        assert pragmas[-1].content == 'end device-present'

        # Only the vertical loop remains in the kernel and it is marked sequential
        with pragmas_attached(kernel, Loop):
            kernel_loops = FindNodes(Loop).visit(kernel.ir)
            assert len(kernel_loops) == 1
            assert kernel_loops[0].pragma[0].keyword == 'loki'
            assert kernel_loops[0].pragma[0].content == 'loop seq'

        # The driver gets an outer teams-distribute loop and an inner parallel loop
        with pragmas_attached(driver, Loop):
            driver_loops = FindNodes(Loop).visit(driver.body)
            assert len(driver_loops) == 2
            assert driver_loops[0].pragma[0].keyword.lower() == 'omp'
            assert driver_loops[0].pragma[0].content == (
                'target teams distribute thread_limit(nlon)'
            )
            assert driver_loops[1].pragma[0].keyword.lower() == 'omp'
            assert driver_loops[1].pragma[0].content == 'parallel do'


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_nested(frontend, horizontal, blocking):
    """
    Test the correct handling of nested vector-level routines in SCC.

    The driver ``column_driver`` calls the outer kernel ``compute_column``,
    which in turn calls the inner kernel ``update_q``. After applying the SCC
    vector pipeline (and :any:`SCCAnnotateTransformation` a second time to
    exercise its bail-out mechanism) this checks that:

    * the driver has a single block loop, annotated as an OpenACC gang loop,
      that contains the kernel call;
    * the outer kernel keeps its call outside of its two vector loops and is
      annotated as ``routine vector`` with a ``data present(q)`` region;
    * the inner kernel has its vector loop wrapped around the vertical ``seq``
      loop and carries the same routine/data annotations.
    """

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nz, q, nb)
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    associate(x => q)
    do b=1, nb
      call compute_column(start, end, nlon, nz, x(:,:,b))
    end do
    end associate
  END SUBROUTINE column_driver
"""

    fcode_outer_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + 1.0
    END DO

    call update_q(start, end, nlon, nz, q, c)

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""

    fcode_inner_kernel = """
  SUBROUTINE update_q(start, end, nlon, nz, q, c)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL, INTENT(IN)    :: c
    REAL :: t(nlon,nz)
    INTEGER :: jl, jk

    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO
  END SUBROUTINE update_q
"""

    outer_kernel = Subroutine.from_source(fcode_outer_kernel, frontend=frontend)
    inner_kernel = Subroutine.from_source(fcode_inner_kernel, frontend=frontend)
    driver = Subroutine.from_source(fcode_driver, frontend=frontend)
    outer_kernel.enrich(inner_kernel)  # Attach inner kernel source to the outer kernel's call
    driver.enrich(outer_kernel)  # Attach outer kernel source to driver call

    # Instantiate SCCVector pipeline and apply
    scc_pipeline = SCCVectorPipeline(
        horizontal=horizontal, block_dim=blocking, directive='openacc'
    )
    scc_pipeline.apply(driver, role='driver', targets=['compute_column'])
    # NOTE(review): the outer kernel calls `update_q`, not `compute_q`; confirm
    # whether this target name is intentional.
    scc_pipeline.apply(outer_kernel, role='kernel', targets=['compute_q'])
    scc_pipeline.apply(inner_kernel, role='kernel')

    # Apply annotate twice to test bailing out mechanism
    scc_annotate = SCCAnnotateTransformation(block_dim=blocking)
    scc_annotate.apply(driver, role='driver', targets=['compute_column'])
    scc_annotate.apply(outer_kernel, role='kernel', targets=['compute_q'])
    scc_annotate.apply(inner_kernel, role='kernel')

    # Ensure a single outer parallel loop in driver
    with pragmas_attached(driver, Loop):
        driver_loops = FindNodes(Loop).visit(driver.body)
        assert len(driver_loops) == 1
        assert driver_loops[0].variable == 'b'
        assert driver_loops[0].bounds == '1:nb'
        assert driver_loops[0].pragma[0].keyword == 'acc'
        assert driver_loops[0].pragma[0].content == 'parallel loop gang vector_length(nlon)'

        # Ensure we have a kernel call in the driver loop
        kernel_calls = FindNodes(CallStatement).visit(driver_loops[0])
        assert len(kernel_calls) == 1
        assert kernel_calls[0].name == 'compute_column'

    # Ensure that the intermediate kernel contains two wrapped loops and an unwrapped call statement
    with pragmas_attached(outer_kernel, Loop):
        outer_kernel_loops = FindNodes(Loop).visit(outer_kernel.body)
        assert len(outer_kernel_loops) == 2
        assert outer_kernel_loops[0].variable == 'jl'
        assert outer_kernel_loops[0].bounds == 'start:end'
        assert outer_kernel_loops[0].pragma[0].keyword == 'acc'
        assert outer_kernel_loops[0].pragma[0].content == 'loop vector'
        assert outer_kernel_loops[1].variable == 'jl'
        assert outer_kernel_loops[1].bounds == 'start:end'
        assert outer_kernel_loops[1].pragma[0].keyword == 'acc'
        assert outer_kernel_loops[1].pragma[0].content == 'loop vector'

        # Ensure we still have a call, but not in the loops
        assert len(FindNodes(CallStatement).visit(outer_kernel_loops[0])) == 0
        assert len(FindNodes(CallStatement).visit(outer_kernel_loops[1])) == 0
        assert len(FindNodes(CallStatement).visit(outer_kernel.body)) == 1

        # Ensure the routine has been marked properly
        outer_kernel_pragmas = FindNodes(Pragma).visit(outer_kernel.ir)
        assert len(outer_kernel_pragmas) == 3
        assert outer_kernel_pragmas[0].keyword == 'acc'
        assert outer_kernel_pragmas[0].content == 'routine vector'
        assert outer_kernel_pragmas[1].keyword == 'acc'
        assert outer_kernel_pragmas[1].content == 'data present(q)'
        assert outer_kernel_pragmas[2].keyword == 'acc'
        assert outer_kernel_pragmas[2].content == 'end data'

    # Ensure that the leaf kernel contains two nested loops
    with pragmas_attached(inner_kernel, Loop):
        inner_kernel_loops = FindNodes(Loop).visit(inner_kernel.body)
        assert len(inner_kernel_loops) == 2
        assert inner_kernel_loops[1] in FindNodes(Loop).visit(inner_kernel_loops[0].body)
        assert inner_kernel_loops[0].variable == 'jl'
        assert inner_kernel_loops[0].bounds == 'start:end'
        assert inner_kernel_loops[0].pragma[0].keyword == 'acc'
        assert inner_kernel_loops[0].pragma[0].content == 'loop vector'
        assert inner_kernel_loops[1].variable == 'jk'
        assert inner_kernel_loops[1].bounds == '2:nz'
        assert inner_kernel_loops[1].pragma[0].keyword == 'acc'
        assert inner_kernel_loops[1].pragma[0].content == 'loop seq'

        # Ensure the routine has been marked properly (check the inner kernel's
        # own pragmas, not the outer kernel's again)
        inner_kernel_pragmas = FindNodes(Pragma).visit(inner_kernel.ir)
        assert len(inner_kernel_pragmas) == 3
        assert inner_kernel_pragmas[0].keyword == 'acc'
        assert inner_kernel_pragmas[0].content == 'routine vector'
        assert inner_kernel_pragmas[1].keyword == 'acc'
        assert inner_kernel_pragmas[1].content == 'data present(q)'
        assert inner_kernel_pragmas[2].keyword == 'acc'
        assert inner_kernel_pragmas[2].content == 'end data'


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_outer_loop(frontend, horizontal, blocking):
    """
    Test the correct handling of an outer loop that breaks scoping.

    The ``niter`` loop encloses both a horizontal loop and a subroutine call.
    It must therefore stay a sequential loop with its own vector loop inside.
    Every other horizontal loop becomes a stand-alone vector loop.
    """

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk, niter
    LOGICAL :: maybe
    REAL :: c

    if (maybe)  call logger()

    c = 5.345
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + 3.0
    END DO

    DO niter = 1, 3

      DO JL = START, END
        Q(JL, NZ) = Q(JL, NZ) + 1.0
      END DO

      call update_q(start, end, nlon, nz, q, c)

    END DO

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO

    IF (.not. maybe) THEN
      call update_q(start, end, nlon, nz, q, c)
    END IF

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + C * 3.
    END DO

    IF (maybe)  call logger()

  END SUBROUTINE compute_column
"""
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # Run the SCC vector pipeline on the kernel with the scope-splitting outer loop
    SCCVectorPipeline(
        horizontal=horizontal, block_dim=blocking, directive='openacc'
    ).apply(kernel, role='kernel')

    with pragmas_attached(kernel, Loop):
        kernel_loops = FindNodes(Loop).visit(kernel.body)

        # Five loops: four horizontal ones plus the counter loop, which holds one of them
        assert len(kernel_loops) == 5
        assert kernel_loops[2] in kernel_loops[1].body

        seq_index = 1  # position of the `niter` counter loop
        for position, loop in enumerate(kernel_loops):
            assert loop.pragma[0].keyword == 'acc'
            if position == seq_index:
                assert loop.variable == 'niter'
                assert loop.bounds == '1:3'
                assert loop.pragma[0].content == 'loop seq'
            else:
                assert loop.variable == 'jl'
                assert loop.bounds == 'start:end'
                assert loop.pragma[0].content == 'loop vector'

        # Only the counter loop contains a call; the body holds four calls in total
        calls_per_loop = (0, 1, 0, 0, 0)
        for loop, expected_calls in zip(kernel_loops, calls_per_loop):
            assert len(FindNodes(CallStatement).visit(loop)) == expected_calls
        assert len(FindNodes(CallStatement).visit(kernel.body)) == 4


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_variable_demotion(frontend, horizontal):
    """
    Test demotion of local horizontal arrays in a kernel with a scope-splitting outer loop.

    Arrays used on both sides of a subroutine call must keep their horizontal
    dimension. Arrays confined to one side of the call are demoted to scalars.
    """

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL :: a(nlon), b(nlon), c(nlon)
    INTEGER :: jl, jk, niter

    DO JL = START, END
      A(JL) = A(JL) + 3.0
      B(JL) = B(JL) + 1.0
    END DO

    DO niter = 1, 3

      DO JL = START, END
        B(JL) = B(JL) + 1.0
      END DO

    END DO

    call update_q(start, end, nlon, nz)

    DO JL = START, END
      A(JL) = A(JL) + 3.0
      C(JL) = C(JL) + 1.0
    END DO

  END SUBROUTINE compute_column
"""
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # Devectorise first, then demote
    transformations = [
        SCCDevectorTransformation(horizontal=horizontal),
        SCCDemoteTransformation(horizontal=horizontal),
    ]
    for transformation in transformations:
        transformation.apply(kernel, role='kernel')

    # Only `a` is used both before and after the `update_q` call, so it keeps its
    # array shape; `b` (only before the call) and `c` (only after) are demoted.
    variables = kernel.variable_map
    assert isinstance(variables['a'], Array)
    for demoted_name in ('b', 'c'):
        assert isinstance(variables[demoted_name], Scalar)


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_multicond(frontend, horizontal, blocking):
    """
    Test if horizontal loops in multiconditionals with CallStatements are
    correctly transformed.

    The ``select case`` construct contains two explicit horizontal loops and two
    array-notation assignments. One branch also contains a call. After the SCC
    vector pipeline, every resulting horizontal loop must carry an
    ``!$acc loop vector`` annotation.
    """

    fcode = """
    subroutine test(icase, start, end, work)
    implicit none

      integer, intent(in) :: icase, start, end
      real, dimension(start:end), intent(inout) :: work
      integer :: jl

      select case(icase)
      case(1)
        work(start:end) = 1.
      case(2)
        do jl = start,end
           work(jl) = work(jl) + 2.
        enddo
      case(3)
        do jl = start,end
           work(jl) = work(jl) + 3.
        enddo
        call some_kernel(start, end, work)
      case default
        work(start:end) = 0.
      end select

    end subroutine test
    """

    kernel = Subroutine.from_source(fcode, frontend=frontend)

    scc_pipeline = SCCVectorPipeline(
        horizontal=horizontal, block_dim=blocking, directive='openacc'
    )
    scc_pipeline.apply(kernel, role='kernel')

    # Ensure we have four horizontal loops in the kernel: the two explicit loops
    # plus, presumably, the two array-notation assignments resolved into loops
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    assert len(kernel_loops) == 4
    assert all(loop.variable == 'jl' for loop in kernel_loops)

    # Check acc pragmas of the vector loops, which occupy pragmas[2:6];
    # the first two and the last pragma are not inspected here
    pragmas = FindNodes(Pragma).visit(kernel.ir)
    assert len(pragmas) == 7
    for pragma in pragmas[2:6]:
        assert pragma.keyword == 'acc'
        assert pragma.content == 'loop vector'


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('dims_type', ['pointer', 'allocatable', 'static'])
@pytest.mark.parametrize('data_offload', [True, False])
def test_scc_multiple_acc_pragmas(frontend, horizontal, blocking, dims_type, tmp_path, caplog,
                                  data_offload):
    """
    Test that both '!$acc data' and '!$acc parallel loop gang' pragmas are created at the
    driver layer.

    The local derived-type variable ``local_dims`` is declared as a pointer, an
    allocatable or a static variable (``dims_type``). For the first two,
    privatising it must log a warning. With ``data_offload`` the driver's
    ``!$loki data`` region is lowered by :any:`DataOffloadTransformation`
    (with ``remove_openmp=True``), and all inner pragmas become OpenACC. Without
    it, the Loki and OpenMP pragmas are left in place. Finally, the kernel must
    treat the pointer association ``tmp => work`` as a separator, keeping its
    two loops.
    """

    fcode = f"""
    module test_mod

    type dims_type
      integer :: klon
      integer :: ist
      integer :: iend
      integer :: kbl
      integer :: wrk(100)
    end type

    contains

    subroutine test(work, nb, dims)
    implicit none

      integer, intent(in) :: nb
      type(dims_type), intent(in) :: dims
      real, dimension(nlon, nb), intent(inout) :: work
      {'type(dims_type), pointer :: local_dims' if dims_type == 'pointer' else ''}
      {'type(dims_type), allocatable :: local_dims' if dims_type == 'allocatable' else ''}
      {'type(dims_type) :: local_dims' if dims_type == 'static' else ''}
      integer :: b

      !$acc data present(dims)
      !$loki data
      !$omp parallel do private(b) shared(work, nproma)
        do b=1, nb
           local_dims%ist = 1
           local_dims%iend = dims%klon
           local_dims%kbl = b
           local_dims%wrk = 0
           call some_kernel(local_dims, local_dims%klon, work(:,b))
        enddo
      !$omp end parallel do
      !$loki end data
      !$acc end data

    end subroutine test

    subroutine some_kernel(dims, nlon, work)
    implicit none

      type(dims_type), intent(in) :: dims
      integer, intent(in) :: nlon
      real, dimension(nlon), target, intent(inout) :: work
      real, pointer :: tmp(:) => null()
      integer :: jl

      do jl=dims%ist,dims%iend
         work(jl) = work(jl) + 1.
      enddo

      tmp => work

      do jl=dims%ist,dims%iend
         tmp(jl) = tmp(jl) + 1.
      enddo

    end subroutine some_kernel
    end module test_mod
    """

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = source['test']
    routine.enrich(source.all_subroutines)

    if data_offload:
        # Use a distinct name so the boolean ``data_offload`` parameter is not
        # shadowed by the transformation object; it is tested again below.
        offload_transform = DataOffloadTransformation(remove_openmp=True)
        offload_transform.transform_subroutine(routine, role='driver', targets=['some_kernel',])
    pragma_model = PragmaModelTransformation(directive='openacc')
    pragma_model.transform_subroutine(routine, role='driver', targets=['some_kernel',])

    scc_pipeline = SCCVectorPipeline(
        horizontal=horizontal, block_dim=blocking, directive='openacc',
        privatise_derived_types=True
    )

    if dims_type in ['pointer', 'allocatable']:
        # Privatising a dynamically allocated struct must be reported as a warning
        with caplog.at_level(WARNING):
            scc_pipeline.apply(routine, role='driver', targets=['some_kernel',])
        if frontend == OMNI:
            assert len(caplog.records) == 3
            message = caplog.records[2].message
        else:
            assert len(caplog.records) == 2
            message = caplog.records[1].message
        assert "[Loki-SCC::Annotate] dynamically allocated structs are being privatised: ['local_dims']" in message
    else:
        scc_pipeline.apply(routine, role='driver', targets=['some_kernel',])

    pragmas = FindNodes(Pragma).visit(routine.ir)
    assert len(pragmas) == 6

    # The user-provided outer OpenACC data region is always preserved
    assert pragmas[0].keyword.lower() == 'acc'
    assert pragmas[0].content == 'data present(dims)'
    assert pragmas[5].content == 'end data'
    assert pragmas[5].keyword.lower() == 'acc'

    if data_offload:
        assert all(p.keyword.lower() == 'acc' for p in pragmas[1:5])
        assert pragmas[2].content == 'parallel loop gang private(local_dims) vector_length(dims%klon)'
        assert pragmas[3].content == 'end parallel loop'
        assert pragmas[4].content == 'end data'

        assert 'data copy(work)' in pragmas[1].content
    else:
        assert pragmas[1].keyword == 'loki'
        assert pragmas[2].keyword.lower() == 'omp'
        assert pragmas[3].keyword.lower() == 'omp'
        assert pragmas[4].keyword == 'loki'
        assert pragmas[2].content == 'parallel do private(b) shared(work, nproma)'
        assert pragmas[3].content == 'end parallel do'

    # check that pointer association was correctly identified as a separator node
    routine = source['some_kernel']
    scc_pipeline.apply(routine, role='kernel')

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 2


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_driver_loop_async(frontend, horizontal, blocking):
    """
    Check that ``async(queue)`` on a ``!$loki driver-loop`` (and on its enclosing
    ``!$loki data`` region) is carried over to the generated OpenACC directives.
    """

    fcode = """
    subroutine test(work, nlon, nb)
    implicit none

      integer, intent(in) :: nb, nlon
      real, dimension(nlon, nb), intent(inout) :: work
      integer :: b, queue

      queue = 0

      !$loki data async(queue)
      !$loki driver-loop async(queue)
        do b=1, nb
           call some_kernel(nlon, work(:,b))
        enddo
      !$loki end data

    end subroutine test

    subroutine some_kernel(nlon, work)
    implicit none

      integer, intent(in) :: nlon
      real, dimension(nlon), intent(inout) :: work
      integer :: jl

      do jl=1,nlon
         work(jl) = work(jl) + 1.
      enddo

    end subroutine some_kernel
    """

    source = Sourcefile.from_source(fcode, frontend=frontend)
    routine = source['test']
    routine.enrich(source.all_subroutines)

    # Lower the Loki data region and then map generic pragmas to OpenACC
    for driver_transform in (
        DataOffloadTransformation(remove_openmp=True),
        PragmaModelTransformation(directive='openacc'),
    ):
        driver_transform.transform_subroutine(routine, role='driver', targets=['some_kernel',])

    SCCVectorPipeline(
        horizontal=horizontal, block_dim=blocking, directive='openacc'
    ).apply(routine, role='driver', targets=['some_kernel',])

    # Expect exactly a data region wrapping a parallel loop, all as OpenACC
    pragmas = FindNodes(Pragma).visit(routine.ir)
    assert len(pragmas) == 4
    assert all(pragma.keyword == 'acc' for pragma in pragmas)

    data_pragma, loop_pragma, end_loop_pragma, end_data_pragma = pragmas
    for fragment in ('data', 'copy', '(work)', 'async(queue)'):
        assert fragment in data_pragma.content

    assert loop_pragma.content == 'parallel loop gang vector_length(nlon) async(queue)'
    assert end_loop_pragma.content == 'end parallel loop'
    assert end_data_pragma.content == 'end data'


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_annotate_routine_seq_pragma(frontend, blocking):
    """
    Check that the generic ``!$loki routine seq`` marker in a kernel ends up as an
    ``!$acc routine seq`` directive in the routine's specification part.
    """

    fcode = """
    subroutine some_kernel(work, nang)
       implicit none

       integer, intent(in) :: nang
       real, dimension(nang), intent(inout) :: work
       integer :: k
!$loki routine seq

       do k=1,nang
          work(k) = 1.
       enddo

    end subroutine some_kernel
    """

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Before transformation only the generic Loki marker is present
    loki_pragmas = FindNodes(Pragma).visit(routine.ir)
    assert len(loki_pragmas) == 1
    assert loki_pragmas[0].keyword == 'loki'
    assert loki_pragmas[0].content == 'routine seq'

    SCCAnnotateTransformation(block_dim=blocking).transform_subroutine(
        routine, role='kernel', targets=['some_kernel',]
    )
    PragmaModelTransformation(directive='openacc').transform_subroutine(
        routine, role='driver', targets=['some_kernel',]
    )

    # Afterwards the spec holds a single OpenACC routine directive
    acc_pragmas = FindNodes(Pragma).visit(routine.spec)
    assert len(acc_pragmas) == 1
    assert acc_pragmas[0].keyword == 'acc'
    assert acc_pragmas[0].content == 'routine seq'


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_annotate_empty_data_clause(frontend, blocking):
    """
    Test that we do not generate empty `!$acc data` clauses.

    The kernel's only argument is the scalar ``n``, so annotation must not add a
    data region. After translation, the only pragma left is ``!$acc routine seq``.
    """

    fcode = """
    subroutine some_kernel(n)
       implicit none
       ! Scalars should not show up in `!$acc data` clause
       integer, intent(inout) :: n
!$loki routine seq
       integer :: k

       k = n
       do k=1, 3
          n = k + 1
       enddo
    end subroutine some_kernel
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    pragmas = FindNodes(Pragma).visit(routine.ir)
    assert len(pragmas) == 1
    assert pragmas[0].keyword == 'loki'
    assert pragmas[0].content == 'routine seq'

    transformation = SCCAnnotateTransformation(block_dim=blocking)
    transformation.transform_subroutine(routine, role='kernel', targets=['some_kernel',])
    pragma_model = PragmaModelTransformation(directive='openacc')
    pragma_model.transform_subroutine(routine, role='driver', targets=['some_kernel',])

    # Ensure no (empty) data region was added anywhere in the routine: the
    # translated routine pragma must be the only pragma left
    pragmas = FindNodes(Pragma).visit(routine.ir)
    assert len(pragmas) == 1
    assert pragmas[0].keyword == 'acc'
    assert pragmas[0].content == 'routine seq'


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('pipeline', [SCCVVectorPipeline, SCCSVectorPipeline])
def test_scc_vector_reduction(frontend, pipeline, horizontal, blocking):
    """
    Check that ``!$loki vector-reduction`` regions become OpenACC ``loop vector
    reduction(...)`` clauses (with the user's spelling kept). The seq-vector
    pipeline is expected to raise a :any:`TransformationError` on this kernel.
    """

    fcode = """
    subroutine some_kernel(start, end, nlon, mij, lcond)
       integer, intent(in) :: nlon, start, end
       integer, dimension(nlon), intent(in) :: mij
       logical, intent(in) :: lcond

       integer :: jl, maxij, sumij, sum0

       maxij = -1
       !$loki vector-reduction( mAx:maXij )
       do jl=start,end
          maxij = max(maxij, mij(jl))
       enddo
       !$loki end vector-reduction( mAx:maXij )

       do jl=start,end
          mij(jl) = jl
       enddo

       if (lcond) then
          sumij = 0
          !$loki vector-reduction( +: sUmij, sUm0 )
          do jl=start,end
             sumij = sumij + mij(jl)
          enddo
          !$loki end vector-reduction( +: sUmij, sUm0 )
       endif

       do jl=start,end
          mij(jl) = 0
       enddo

    end subroutine some_kernel
    """

    scc_pipeline = pipeline(
        horizontal=horizontal, block_dim=blocking, directive='openacc'
    )

    source = Sourcefile.from_source(fcode, frontend=frontend)
    routine = source['some_kernel']

    # Every pragma region in the kernel is a Loki vector-reduction region
    with pragma_regions_attached(routine):
        for region in FindNodes(PragmaRegion).visit(routine.body):
            assert is_loki_pragma(region.pragma, starts_with='vector-reduction')

    if pipeline == SCCSVectorPipeline:
        # The seq-vector pipeline must reject these reduction regions
        with pytest.raises(TransformationError):
            scc_pipeline.apply(routine, role='kernel', targets=['some_kernel',])
    else:
        scc_pipeline.apply(routine, role='kernel', targets=['some_kernel',])

    if pipeline == SCCVVectorPipeline:
        acc_pragmas = FindNodes(Pragma).visit(routine.body)
        assert len(acc_pragmas) == 6
        assert all(p.keyword == 'acc' for p in acc_pragmas)

        # Check OpenACC reduction clauses are attached to the right loops
        with pragmas_attached(routine, Loop):
            all_loops = FindNodes(Loop).visit(routine.body)
            assert len(all_loops) == 4
            expected_content = {
                0: 'loop vector reduction( mAx:maXij )',
                2: 'loop vector reduction( +: sUmij, sUm0 )',
            }
            for index, content in expected_content.items():
                assert all_loops[index].pragma[0].content == content

            # The sum reduction is the single loop inside the conditional
            conditionals = FindNodes(Conditional).visit(routine.body)
            assert len(conditionals) == 1
            guarded_loops = FindNodes(Loop).visit(conditionals[0].body)
            assert len(guarded_loops) == 1
            assert guarded_loops[0].pragma[0].content == 'loop vector reduction( +: sUmij, sUm0 )'


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_demotion_parameter(frontend, horizontal, tmp_path):
    """
    Check that a local array whose non-horizontal extents are compile-time
    constants (here a module parameter) gets its horizontal dimension demoted.
    """

    fcode_mod = """
    module YOWPARAM
       integer, parameter :: nang_param = 36
    end module YOWPARAM
    """

    fcode_kernel = """
    subroutine some_kernel(start, end, nlon, nang)
       use yowparam, only: nang_param
       implicit none

       integer, intent(in) :: nlon, start, end, nang
       real, dimension(nlon, nang_param, 2) :: work

       integer :: jl, k

       do jl=start,end
          do k=1,nang
             work(jl,k,1) = 1.
             work(jl,k,2) = 1.
          enddo
       enddo

    end subroutine some_kernel
    """

    module_source = Sourcefile.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode_kernel, definitions=module_source.definitions,
                                     frontend=frontend, xmods=[tmp_path])

    for transform in (
        SCCDevectorTransformation(horizontal=horizontal),
        SCCDemoteTransformation(horizontal=horizontal),
    ):
        transform.apply(routine, role='kernel', targets=['some_kernel',])

    # `work(nlon, nang_param, 2)` loses its leading horizontal dimension;
    # OMNI reports the parameter by its value rather than by name
    work_shape = routine.symbol_map['work'].shape
    assert len(work_shape) == 2
    leading_extent = IntLiteral(36) if frontend == OMNI else 'nang_param'
    assert work_shape == (leading_extent, IntLiteral(2))


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_base_horizontal_bounds_checks(frontend, horizontal, horizontal_bounds_aliases, tmp_path):
    """
    Check how :any:`SCCBaseTransformation` validates horizontal loop bounds.

    Kernels missing the start or end bound must be rejected. Bounds given via
    derived-type members are resolved when a dimension with bounds aliases is used.
    """

    fcode = """
subroutine kernel(start, end, work)
    real, intent(inout) :: work
    integer, intent(in) :: start, end

end subroutine kernel
    """.strip()

    fcode_no_start = """
subroutine kernel(end, work)
    real, intent(inout) :: work
    integer, intent(in) :: end

end subroutine kernel
    """.strip()

    fcode_no_end = """
subroutine kernel(start, work)
    real, intent(inout) :: work
    integer, intent(in) :: start

end subroutine kernel
    """.strip()

    fcode_alias = """
module bnds_type_mod
    implicit none
    type bnds_type
        integer :: start
        integer :: end
    end type bnds_type
end module bnds_type_mod

subroutine kernel(bnds, work)
    use bnds_type_mod, only : bnds_type
    type(bnds_type), intent(in) :: bnds
    real, intent(inout) :: work

end subroutine kernel
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    no_start = Subroutine.from_source(fcode_no_start, frontend=frontend, xmods=[tmp_path])
    no_end = Subroutine.from_source(fcode_no_end, frontend=frontend, xmods=[tmp_path])
    alias = Sourcefile.from_source(fcode_alias, frontend=frontend, xmods=[tmp_path]).subroutines[0]

    # A kernel lacking either horizontal bound must be rejected
    plain_transform = SCCBaseTransformation(horizontal=horizontal)
    for incomplete_kernel in (no_start, no_end):
        with pytest.raises(TransformationError):
            plain_transform.apply(incomplete_kernel, role='kernel')

    # With the aliased dimension, the kernel receiving only `bnds` is accepted
    alias_transform = SCCBaseTransformation(horizontal=horizontal_bounds_aliases)
    alias_transform.apply(alias, role='kernel')

    expectations = (
        (routine, 'start', 'end'),
        (alias, 'bnds%start', 'bnds%end'),
    )
    for kernel, lower, upper in expectations:
        bounds = get_loop_bounds(kernel, dimension=horizontal_bounds_aliases)
        assert bounds[0] == lower
        assert bounds[1] == upper


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('inline_internals', [False, True])
@pytest.mark.parametrize('resolve_sequence_association', [False, True])
def test_scc_inline_and_sequence_association(
        frontend, horizontal, inline_internals, resolve_sequence_association
):
    """
    Test the combinations of routine inlining and sequence association

    ``some_kernel`` passes the array element ``work(1)`` to the contained routine
    ``contained_kernel``, whose dummy argument is an array (sequence association).
    The four flag combinations are expected to behave as follows:

    * no inlining, no resolution: the member routine is kept and the parent body
      has no loop;
    * inlining without resolution: inlining raises a :any:`TransformationError`
      because the call cannot be resolved;
    * resolution without inlining: the call argument becomes ``work(1:nlon)``;
    * inlining with resolution: the member is inlined and its horizontal loop
      appears in the parent body.
    """

    fcode_kernel = """
    subroutine some_kernel(nlon, start, end)
       implicit none

       integer, intent(in) :: nlon, start, end
       real, dimension(nlon) :: work

       call contained_kernel(work(1))

     contains

       subroutine contained_kernel(work)
          implicit none

          real, dimension(nlon) :: work
          integer :: jl

          do jl = start, end
             work(jl) = 1.
          enddo

       end subroutine contained_kernel
    end subroutine some_kernel
    """

    routine = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # Optionally remove sequence association via SanitiseTransformation
    sanitise_transform = SanitiseTransformation(
        resolve_sequence_association=resolve_sequence_association
    )
    sanitise_transform.apply(routine, role='kernel')

    # Create member inlining transformation to go along SCC
    inline_transform = InlineTransformation(inline_internals=inline_internals)

    scc_transform = SCCBaseTransformation(horizontal=horizontal)

    if not inline_internals and not resolve_sequence_association:
        # Not really doing anything for contained routines
        inline_transform.apply(routine, role='kernel')
        scc_transform.apply(routine, role='kernel')

        assert len(routine.members) == 1
        assert not FindNodes(Loop).visit(routine.body)

    elif inline_internals and not resolve_sequence_association:
        # Should fail because it can't resolve sequence association. Only the
        # raising call goes inside `pytest.raises`: any statement after it in the
        # block would never execute.
        with pytest.raises(TransformationError) as e_info:
            inline_transform.apply(routine, role='kernel')
        assert (
            '[Loki::TransformInline] Cannot resolve procedure call to contained_kernel'
            in e_info.exconly()
        )

    elif not inline_internals and resolve_sequence_association:
        # Check that the call is properly modified
        inline_transform.apply(routine, role='kernel')
        scc_transform.apply(routine, role='kernel')

        assert len(routine.members) == 1
        call = FindNodes(CallStatement).visit(routine.body)[0]
        assert fgen(call).lower() == 'call contained_kernel(work(1:nlon))'

    else:
        # Check that the contained subroutine has been inlined
        inline_transform.apply(routine, role='kernel')
        scc_transform.apply(routine, role='kernel')

        assert len(routine.members) == 0

        loop = FindNodes(Loop).visit(routine.body)[0]
        assert loop.variable == 'jl'
        assert loop.bounds == 'start:end'

        assign = FindNodes(Assignment).visit(loop.body)[0]
        assert fgen(assign).lower() == 'work(jl) = 1.'


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('pipeline,mode', [
    (SCCVVectorPipeline, 'vector'),
    (SCCSVectorPipeline, 'seq')
])
def test_scc_annotate_driver_loop(frontend, tmp_path, pipeline, mode):
    """
    Test for issue #246 to ensure the correct loop in the loop nest
    is identified and annotated as driver loop.

    The driver contains a ``jiter`` loop enclosing the ``jblk`` block loop
    that calls the kernel; only the inner ``jblk`` loop must receive the
    ``parallel loop gang`` annotation. The test runs for both the vector
    (``SCCVVectorPipeline``) and sequential (``SCCSVectorPipeline``) variants,
    where ``mode`` is the expected OpenACC routine/loop level.
    """
    fcode = """
module mod
    implicit none
    integer, parameter :: ngpblk = 2
    integer, parameter :: klon = 10
    integer, parameter :: klev = 5
    contains
    subroutine wrapper(arr)
        integer, parameter :: niter = 10
        integer, intent(in) :: arr(klon, klev, ngpblk)
        integer :: jblk, jiter

        ! Begin relevant part.
        DO jiter=1,niter
            DO jblk=1,ngpblk
                CALL kernel(klon, klev, 1, 1, arr(:, :, jblk))
            END DO
        END DO
        ! End relevant part.

    end subroutine wrapper

    subroutine kernel(klon, klev, jlev_lower, jlon_lower, arr)
        integer, intent(in) :: klon
        integer, intent(in) :: klev
        integer, intent(in) :: jlev_lower, jlon_lower
        integer, intent(inout) :: arr(klon, klev)
        integer :: jlon, jlev

        DO jlev=jlev_lower,klev
            DO jlon=jlon_lower,klon
                arr(jlon, jlev) = jlon + jlev
            END DO
        END DO

    end subroutine kernel

end module mod
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=(tmp_path,))

    horizontal = Dimension('horizontal', index='JLON', bounds=['jlon_lower', 'klon'], size='KLON')
    block_dim = Dimension('block_dim', index='JBLK', size='NGPBLK')

    # Apply each transformation of the (non-hoisting) pipeline to driver and kernel in turn
    for transform in pipeline(horizontal=horizontal, block_dim=block_dim, directive='openacc').transformations:
        transform.apply(source['wrapper'], role='driver', targets=['kernel'])
        transform.apply(source['kernel'], role='kernel')

    # Ensure the kernel is annotated as `acc routine <mode>` and wrapped in a data region.
    # The vector variant has one more pragma, belonging to the second kernel loop checked below.
    pragmas = FindNodes(Pragma).visit(source['kernel'].ir)
    if mode == 'vector':
        assert len(pragmas) == 5
    else:
        assert len(pragmas) == 4
    assert pragmas[0].keyword == 'acc'
    assert pragmas[0].content == f'routine {mode}'
    assert pragmas[1].keyword == 'acc'
    assert pragmas[1].content == 'data present(arr)'
    assert pragmas[-1].keyword == 'acc'
    assert pragmas[-1].content == 'end data'

    # Ensure the kernel loops are annotated: in vector mode the outer loop is
    # `loop vector` and the nested one `loop seq`; in seq mode a single `loop seq` remains
    with pragmas_attached(source['kernel'], Loop):
        kernel_loops = FindNodes(Loop).visit(source['kernel'].ir)
        assert kernel_loops[0].pragma[-1].keyword == 'acc'
        assert kernel_loops[0].pragma[-1].content == f'loop {mode}'
        if mode == 'vector':
            assert len(kernel_loops) == 2
            assert kernel_loops[1].pragma[-1].keyword == 'acc'
            assert kernel_loops[1].pragma[-1].content == 'loop seq'
        else:
            assert len(kernel_loops) == 1

    # Ensure the outer `jiter` loop is left unannotated and the nested `jblk`
    # loop is the one marked as parallel driver loop (issue #246); the seq
    # variant has one additional driver loop
    with pragmas_attached(source['wrapper'], Loop):
        driver_loops = FindNodes(Loop).visit(source['wrapper'].body)
        if mode == 'vector':
            assert len(driver_loops) == 2
        else:
            assert len(driver_loops) == 3
        assert not driver_loops[0].pragma
        assert driver_loops[1].pragma[-1].keyword == 'acc'
        assert driver_loops[1].pragma[-1].content == 'parallel loop gang'
loki-ecmwf-0.3.6/loki/transformations/single_column/tests/test_scc_hoist.py0000664000175000017500000010116215167130205027520 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Sourcefile, Dimension, fgen, SGraph
from loki.batch import (
    Scheduler, SchedulerConfig, ProcedureItem
)
from loki.frontend import available_frontends
from loki.ir import (
    FindNodes, Assignment, CallStatement, Loop, Pragma,
    pragmas_attached, Import
)

from loki.transformations import (
    InlineTransformation, PragmaModelTransformation
)
from loki.transformations.single_column import (
    SCCBaseTransformation, SCCDevectorTransformation,
    SCCDemoteTransformation, SCCRevectorTransformation,
    SCCAnnotateTransformation, SCCHoistPipeline,
    SCCVHoistPipeline, SCCSHoistPipeline
)


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """
    Module-scoped horizontal :any:`Dimension` with size ``nlon``, index ``jl``,
    loop bounds ``start``/``end`` and the size alias ``nproma``.
    """
    horizontal_dim = Dimension(
        name='horizontal',
        size='nlon',
        index='jl',
        bounds=('start', 'end'),
        aliases=('nproma',),
    )
    return horizontal_dim

@pytest.fixture(scope='module', name='vertical')
def fixture_vertical():
    """
    Module-scoped vertical :any:`Dimension` with size ``nz``, index ``jk``
    and the size alias ``nlev``.
    """
    vertical_dim = Dimension(
        name='vertical',
        size='nz',
        index='jk',
        aliases=('nlev',),
    )
    return vertical_dim


@pytest.fixture(scope='module', name='blocking')
def fixture_blocking():
    """
    Module-scoped block :any:`Dimension` with plain size ``nb``, index ``b``
    and the derived-type size alias ``block_var%nb``.
    """
    block_dim = Dimension(
        name='blocking',
        size='nb',
        index='b',
        aliases=('block_var%nb',),
    )
    return block_dim


@pytest.fixture(scope='module', name='blocking_alt')
def fixture_blocking_alt():
    """
    Module-scoped block :any:`Dimension` whose size is the derived-type
    member ``dims%nb`` (index ``b``, alias ``block_var%nb``).
    """
    block_dim = Dimension(
        name='blocking',
        size='dims%nb',
        index='b',
        aliases=('block_var%nb',),
    )
    return block_dim


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_pipeline', [SCCVHoistPipeline, SCCSHoistPipeline])
def test_scc_hoist_multiple_kernels(frontend, horizontal, blocking_alt, hoist_pipeline, tmp_path):
    """
    Test hoisting of column temporaries to "driver" level.

    The driver calls the same kernel twice inside a block loop whose size is
    the derived-type member ``dims%nb``. The kernel-local temporary ``t`` must
    become a kernel argument and be backed by a single driver-level array
    ``compute_column_t`` passed to both calls.
    """

    fcode_dims_mod = """
  MODULE dims_mod
    TYPE dims_type
      integer :: nb
    END TYPE
  END MODULE dims_mod
"""

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nz, q, dims)
    USE dims_mod, ONLY : dims_type
    INTEGER, INTENT(IN)   :: nlon, nz  ! Size of the horizontal and vertical
    TYPE(DIMS_TYPE), INTENT(IN) :: dims
    REAL, INTENT(INOUT)   :: q(nlon,nz,dims%nb)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    do b=1, dims%nb
      call compute_column(start, end, nlon, nz, q(:,:,b))

      ! A second call, to check multiple calls are honored
      call compute_column(start, end, nlon, nz, q(:,:,b))
    end do
  END SUBROUTINE column_driver
"""

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL :: t(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""
    dims_mod_source = Sourcefile.from_source(fcode_dims_mod, frontend=frontend, xmods=[tmp_path])
    kernel_source = Sourcefile.from_source(fcode_kernel, frontend=frontend)
    dims_mod = dims_mod_source['dims_mod']
    driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend,
                                           definitions=[dims_mod], xmods=[tmp_path])
    driver = driver_source['column_driver']
    kernel = kernel_source['compute_column']
    driver.enrich(kernel)  # Attach kernel source to driver call

    driver_item = ProcedureItem(name='#column_driver', source=driver_source)
    kernel_item = ProcedureItem(name='#compute_column', source=kernel_source)

    scc_hoist = hoist_pipeline(
        horizontal=horizontal, block_dim=blocking_alt, directive='openacc'
    )

    graph_dic = {driver_item: [kernel_item]}
    graph = SGraph.from_dict(graph_dic)
    # Apply pipeline in reverse order to ensure analysis runs before hoisting
    scc_hoist.apply(kernel, role='kernel', item=kernel_item)
    scc_hoist.apply(
        driver, role='driver', item=driver_item,
        sub_sgraph=graph,
        targets=['compute_column']
    )

    # Ensure the expected loops are left in the kernel: the seq pipeline keeps
    # only the vertical loop, the vector pipeline keeps the horizontal loop
    # enclosing the vertical one
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    if hoist_pipeline == SCCSHoistPipeline:
        assert len(kernel_loops) == 1
        assert kernel_loops[0].variable == 'jk'
        assert kernel_loops[0].bounds == '2:nz'
    else:
        assert len(kernel_loops) == 2
        assert kernel_loops[0].variable == 'jl'
        assert kernel_loops[0].bounds == 'start:end'
        assert kernel_loops[1].variable == 'jk'
        assert kernel_loops[1].bounds == '2:nz'

    # Ensure all expressions and array indices are unchanged
    assigns = FindNodes(Assignment).visit(kernel.body)
    assert fgen(assigns[1]).lower() == 't(jl, jk) = c*jk'
    assert fgen(assigns[2]).lower() == 'q(jl, jk) = q(jl, jk - 1) + t(jl, jk)*c'
    assert fgen(assigns[3]).lower() == 'q(jl, nz) = q(jl, nz)*c'

    # Ensure we have a single driver block loop; with the seq pipeline each
    # kernel call is additionally wrapped in its own horizontal `jl` loop
    driver_loops = FindNodes(Loop).visit(driver.body)
    if hoist_pipeline == SCCSHoistPipeline:
        assert len(driver_loops) == 3
        assert driver_loops[1].variable == 'jl'
        assert driver_loops[1].bounds == 'start:end'
        calls = FindNodes(CallStatement).visit(driver_loops[1])
        assert len(calls) == 1
        assert calls[0].name == 'compute_column'
        assert driver_loops[2].variable == 'jl'
        assert driver_loops[2].bounds == 'start:end'
        calls = FindNodes(CallStatement).visit(driver_loops[2])
        assert len(calls) == 1
        assert calls[0].name == 'compute_column'
    else:
        assert len(driver_loops) == 1
    assert driver_loops[0].variable == 'b'
    assert driver_loops[0].bounds == '1:dims%nb'

    # Ensure we have two kernel calls in the driver loop, both receiving the hoisted temporary
    kernel_calls = FindNodes(CallStatement).visit(driver_loops[0])
    assert len(kernel_calls) == 2
    assert kernel_calls[0].name == 'compute_column'
    assert kernel_calls[1].name == 'compute_column'
    assert 'compute_column_t(:,:,b)' in kernel_calls[0].arguments
    assert 'compute_column_t(:,:,b)' in kernel_calls[1].arguments

    # Ensure that column local `t(nlon,nz)` has been hoisted
    assert 't' in kernel.argnames
    assert kernel.variable_map['t'].type.intent.lower() == 'inout'
    assert kernel.variable_map['t'].type.shape == ('nlon', 'nz')
    assert driver.variable_map['compute_column_t'].dimensions == ('nlon', 'nz', 'dims%nb')


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('trim_vector_sections', [True, False])
def test_scc_hoist_multiple_kernels_loops(tmp_path, frontend, trim_vector_sections, horizontal, blocking):
    """
    Test SCC processing and OpenACC annotation of a driver that contains
    vector loops directly, alongside kernel calls.

    The driver has three block loops:
    - a `!$loki driver-loop` block loop containing only a vector loop nest;
    - a block loop mixing two kernel calls with vector loop nests;
    - a `!$loki driver-loop` block loop with a `!$loki separator` inside a conditional.

    The individual SCC transformations (base, devector, demote, revector,
    annotate, pragma model) are applied through the :any:`Scheduler`; no
    hoisting transformation is involved. The test also checks where the
    loop-bound assignment ``end = nlon - nb`` ends up with and without
    ``trim_vector_sections``.
    """

    fcode_driver = """
  SUBROUTINE driver(nlon, nz, q, nb, lflag)
    use kernel_mod, only: kernel
    implicit none
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    LOGICAL, INTENT(IN)   :: lflag
    REAL                  :: c, tmp(nlon,nz,nb)
    INTEGER :: b, jk, jl, start, end

    tmp = 0.0

    !$loki driver-loop
    do b=1, nb
      end = nlon - nb
      do jk = 2, nz
        do jl = start, end
          q(jl, jk, b) = 2.0 * jk * jl
        end do
      end do
    end do

    do b=2, nb
      end = nlon - nb
      call kernel(start, end, nlon, nz, q(:,:,b))

      DO jk = 2, nz
        DO jl = start, end
          tmp(jl, jk, b) = 2.0 * jk * jl
          q(jl, jk, b) = q(jl, jk-1, b) * c + tmp(jl, jk, b)
        END DO
      END DO

      ! A second call, to check multiple calls are honored
      call kernel(start, end, nlon, nz, q(:,:,b))

      DO jk = 2, nz
        DO jl = start, end
          q(jl, jk, b) = (-1.0) * q(jl, jk, b)
        END DO
      END DO
    end do

    !$loki driver-loop
    do b=3, nb
      if (lflag) then
        end = nlon - nb
        !$loki separator
        do jk = 2, nz
          do jl = start, end
            q(jl, jk, b) = 2.0 * jk * jl
          end do
        end do
      endif
    end do
  END SUBROUTINE driver
""".strip()

    # Note: the kernel uses `implicit none`, so the vertical loop index `jk`
    # (not an undeclared `k`) is used in the assignment to `t`
    fcode_kernel = """
MODULE kernel_mod
implicit none
CONTAINS
  SUBROUTINE kernel(start, end, nlon, nz, q)
    implicit none
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL :: t(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE kernel
END MODULE kernel_mod
""".strip()

    (tmp_path / 'driver.F90').write_text(fcode_driver)
    (tmp_path / 'kernel.F90').write_text(fcode_kernel)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True
        },
        'routines': {
            'driver': {'role': 'driver'}
        }
    }
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
    )

    driver = scheduler["#driver"].ir
    kernel = scheduler["kernel_mod#kernel"].ir

    transformation = (SCCBaseTransformation(horizontal=horizontal),)
    transformation += (SCCDevectorTransformation(horizontal=horizontal, trim_vector_sections=trim_vector_sections),)
    transformation += (SCCDemoteTransformation(horizontal=horizontal),)
    transformation += (SCCRevectorTransformation(horizontal=horizontal),)
    transformation += (SCCAnnotateTransformation(block_dim=blocking),)
    transformation += (PragmaModelTransformation(directive='openacc'),)
    for transform in transformation:
        scheduler.process(transformation=transform)

    # Kernel: horizontal loop now encloses the vertical loop
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    assert len(kernel_loops) == 2
    assert kernel_loops[0].variable == 'jl'
    assert kernel_loops[0].bounds == 'start:end'
    assert kernel_loops[1].variable == 'jk'
    assert kernel_loops[1].bounds == '2:nz'

    # Driver: one `parallel loop gang` region per block loop, with vector/seq
    # annotations on the revectorised loop nests inside
    driver_loops = FindNodes(Loop).visit(driver.body)
    driver_loop_pragmas = [pragma for pragma in FindNodes(Pragma).visit(driver.body) if pragma.keyword.lower() == 'acc']
    assert len(driver_loops) == 11
    assert len(driver_loop_pragmas) == 14
    assert "parallel loop gang vector_length" in driver_loop_pragmas[0].content.lower()
    assert "loop vector" in driver_loop_pragmas[1].content.lower()
    assert "loop seq" in driver_loop_pragmas[2].content.lower()
    assert "end parallel loop" in driver_loop_pragmas[3].content.lower()
    assert "parallel loop gang vector_length" in driver_loop_pragmas[4].content.lower()
    assert "loop vector" in driver_loop_pragmas[5].content.lower()
    assert "loop seq" in driver_loop_pragmas[6].content.lower()
    assert "loop vector" in driver_loop_pragmas[7].content.lower()
    assert "loop seq" in driver_loop_pragmas[8].content.lower()
    assert "end parallel loop" in driver_loop_pragmas[9].content.lower()
    assert "parallel loop gang vector_length" in driver_loop_pragmas[10].content.lower()
    assert "loop vector" in driver_loop_pragmas[11].content.lower()
    assert "loop seq" in driver_loop_pragmas[12].content.lower()
    assert "end parallel loop" in driver_loop_pragmas[13].content.lower()

    # First block loop (b=1:nb): a single revectorised jl/jk nest
    assert driver_loops[1] in FindNodes(Loop).visit(driver_loops[0].body)
    assert driver_loops[2] in FindNodes(Loop).visit(driver_loops[0].body)
    assert driver_loops[0].variable == 'b'
    assert driver_loops[0].bounds == '1:nb'
    assert driver_loops[1].variable == 'jl'
    assert driver_loops[1].bounds == 'start:end'
    assert driver_loops[2].variable == 'jk'
    assert driver_loops[2].bounds == '2:nz'

    # check location of loop-bound assignment: it is only kept outside the
    # vector loop when vector sections are trimmed
    assign = FindNodes(Assignment).visit(driver_loops[0])[0]
    assert assign.lhs == 'end'
    assert assign.rhs == 'nlon-nb'
    assigns = FindNodes(Assignment).visit(driver_loops[1])
    if trim_vector_sections:
        assert not assign in assigns
    else:
        assert assign in assigns

    # Second block loop (b=2:nb): two jl/jk nests and both kernel calls
    assert driver_loops[4] in FindNodes(Loop).visit(driver_loops[3].body)
    assert driver_loops[5] in FindNodes(Loop).visit(driver_loops[3].body)
    assert driver_loops[6] in FindNodes(Loop).visit(driver_loops[3].body)
    assert driver_loops[7] in FindNodes(Loop).visit(driver_loops[3].body)
    kernel_calls = FindNodes(CallStatement).visit(driver_loops[3])
    assert len(kernel_calls) == 2
    assert kernel_calls[0].name == 'kernel'
    assert kernel_calls[1].name == 'kernel'

    assert driver_loops[3].variable == 'b'
    assert driver_loops[3].bounds == '2:nb'
    assert driver_loops[4].variable == 'jl'
    assert driver_loops[4].bounds == 'start:end'
    assert driver_loops[5].variable == 'jk'
    assert driver_loops[5].bounds == '2:nz'
    assert driver_loops[6].variable == 'jl'
    assert driver_loops[6].bounds == 'start:end'
    assert driver_loops[7].variable == 'jk'
    assert driver_loops[7].bounds == '2:nz'

    # Third block loop (b=3:nb): a jl/jk nest behind the conditional separator
    assert driver_loops[9] in FindNodes(Loop).visit(driver_loops[8].body)
    assert driver_loops[10] in FindNodes(Loop).visit(driver_loops[8].body)
    assert driver_loops[8].variable == 'b'
    assert driver_loops[8].bounds == '3:nb'
    assert driver_loops[9].variable == 'jl'
    assert driver_loops[9].bounds == 'start:end'
    assert driver_loops[10].variable == 'jk'
    assert driver_loops[10].bounds == '2:nz'

    # check location of loop-bound assignment: the explicit separator keeps it
    # outside the vector loop regardless of `trim_vector_sections`
    assign = FindNodes(Assignment).visit(driver_loops[8])[0]
    assert assign.lhs == 'end'
    assert assign.rhs == 'nlon-nb'
    assigns = FindNodes(Assignment).visit(driver_loops[9])
    assert not assign in assigns


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('insert_pragma', [False, True])
def test_scc_hoist_openacc(insert_pragma, frontend, horizontal, blocking, tmp_path):
    """
    Test the correct addition of OpenACC pragmas to SCC format code
    when hoisting array temporaries to driver.

    With ``insert_pragma`` the driver contains a ``!$loki stack-insert``
    marker, which changes the position of the ``enter data`` pragma in the
    driver body.
    """

    fcode_mod = """
MODULE BLOCK_DIM_MOD
    type block_type
      INTEGER :: nb
    end type block_type
END MODULE BLOCK_DIM_MOD
    """.strip()

    # NOTE(review): `q` is declared with `block_var%nb` before `block_var`
    # itself is declared; the frontends presumably accept this order.
    fcode_driver = f"""
SUBROUTINE column_driver(nlon, nz, q)
    USE BLOCK_DIM_MOD, ONLY : block_type
    INTEGER, INTENT(IN)   :: nlon, nz  ! Size of the horizontal and vertical
    REAL, INTENT(INOUT)   :: q(nlon,nz,block_var%nb)
    INTEGER :: b, start, end
    type(block_type) :: block_var
    INTEGER :: nb

    nb = block_var%nb

    !$loki helper

    {'!$loki stack-insert' if insert_pragma else ''}

    start = 1
    end = nlon
    do b=1, nb
      call compute_column(start, end, nlon, nz, q(:,:,b))
    end do
END SUBROUTINE column_driver
    """.strip()

    # NOTE(review): `psize` is referenced in the declaration of `b` before it
    # is declared as a PARAMETER, which standard Fortran does not allow; the
    # frontends presumably accept this order.
    fcode_kernel = """
SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, TARGET, INTENT(INOUT) :: q(nlon,nz)
    REAL :: t(nlon,nz)
    REAL :: a(nlon)
    REAL :: b(nlon,psize)
    REAL, POINTER :: b_ptr(:,:)
    INTEGER, PARAMETER :: psize = 3
    INTEGER :: jl, jk
    REAL :: c

    b_ptr => q

    c = 5.345
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      a(jl) = Q(JL, 1)
      b(jl, 1) = Q(JL, 2)
      b(jl, 2) = Q(JL, 3)
      b(jl, 3) = a(jl) * (b(jl, 1) + b(jl, 2))

      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
END SUBROUTINE compute_column
    """.strip()

    # Note: `compute_column` declares its own local `c` and does not USE this module
    fcode_module = """
module my_scaling_value_mod
    implicit none
    REAL :: c = 5.345
end module my_scaling_value_mod
    """.strip()

    # Mimic the scheduler internal mechanism to apply the transformation cascade
    mod_source = Sourcefile.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    kernel_source = Sourcefile.from_source(fcode_kernel, frontend=frontend, xmods=[tmp_path])
    driver_source = Sourcefile.from_source(
        fcode_driver, frontend=frontend, definitions=mod_source.modules, xmods=[tmp_path]
    )
    module_source = Sourcefile.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    driver = driver_source['column_driver']
    kernel = kernel_source['compute_column']
    module = module_source['my_scaling_value_mod']
    kernel.enrich(module)
    driver.enrich(kernel)  # Attach kernel source to driver call

    driver_item = ProcedureItem(name='#column_driver', source=driver_source)
    kernel_item = ProcedureItem(name='#compute_column', source=kernel_source)

    scc_hoist = SCCHoistPipeline(
        horizontal=horizontal, block_dim=blocking,
        directive='openacc'
    )

    graph_dic = {driver_item: [kernel_item]}
    graph = SGraph.from_dict(graph_dic)
    # Apply in reverse order to ensure hoisting analysis gets run on kernel first
    scc_hoist.apply(kernel, role='kernel', item=kernel_item)
    scc_hoist.apply(
        driver, role='driver', item=driver_item, targets=['compute_column'],
        sub_sgraph=graph
    )

    with pragmas_attached(kernel, Loop):
        # Ensure kernel routine is annotated at vector level
        kernel_pragmas = FindNodes(Pragma).visit(kernel.ir)
        assert len(kernel_pragmas) == 3
        assert kernel_pragmas[0].keyword == 'acc'
        assert kernel_pragmas[0].content == 'routine vector'
        assert kernel_pragmas[1].keyword == 'acc'
        assert kernel_pragmas[1].content == 'data present(q, t)'
        assert kernel_pragmas[2].keyword == 'acc'
        assert kernel_pragmas[2].content == 'end data'

        # Ensure `seq` and `vector` loops in kernel; the demoted local `b` is privatized
        kernel_loops = FindNodes(Loop).visit(kernel.body)
        assert len(kernel_loops) == 2
        assert kernel_loops[0].pragma[0].keyword == 'acc'
        assert kernel_loops[0].pragma[0].content == 'loop vector private(b)'
        assert kernel_loops[1].pragma[0].keyword == 'acc'
        assert kernel_loops[1].pragma[0].content == 'loop seq'

    # Ensure a single blocked parallel loop in driver
    with pragmas_attached(driver, Loop):
        driver_loops = FindNodes(Loop).visit(driver.body)
        assert len(driver_loops) == 1
        assert driver_loops[0].pragma[0].keyword == 'acc'
        assert driver_loops[0].pragma[0].content == 'parallel loop gang vector_length(nlon)'

        # Ensure device allocation and teardown via `!$acc enter/exit data`;
        # the position of `enter data` depends on the `!$loki stack-insert` marker
        driver_pragmas = FindNodes(Pragma).visit(driver.body)
        assert len(driver_pragmas) == 3
        enter_data_pragma = 0 if not insert_pragma else 1
        assert driver_pragmas[enter_data_pragma].keyword == 'acc'
        assert driver_pragmas[enter_data_pragma].content == 'enter data create(compute_column_t)'
        assert driver_pragmas[2].keyword == 'acc'
        assert driver_pragmas[2].content == 'exit data delete(compute_column_t) finalize'

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('as_kwarguments', [False, True])
def test_scc_hoist_nested_openacc(frontend, horizontal, vertical, blocking,
        as_kwarguments):
    """
    Test the correct addition of OpenACC pragmas to SCC format code
    when hoisting array temporaries to driver.

    The call chain is ``column_driver -> compute_column -> update_q``. The
    temporary ``t`` of the leaf kernel ``update_q`` is hoisted through the
    intermediate kernel to the driver. It is passed positionally or as a
    keyword argument depending on ``as_kwarguments``. Its ``real64`` kind
    import must also be added to the driver.
    """

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nz, q, nb)
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    do b=1, nb
      call compute_column(start, end, nlon, nz, q(:,:,b))
    end do
  END SUBROUTINE column_driver
"""

    fcode_outer_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + 1.0
    END DO

    call update_q(start, end, nlon, nz, q, c)

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""

    fcode_inner_kernel = """
  SUBROUTINE update_q(start, end, nlon, nz, q, c)
    use, intrinsic :: iso_fortran_env, only : real64
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL, INTENT(IN)    :: c
    REAL(kind=real64)   :: t(nlon,nz)
    INTEGER :: jl, jk

    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO
  END SUBROUTINE update_q
"""

    # Mimic the scheduler internal mechanism to apply the transformation cascade
    outer_kernel_source = Sourcefile.from_source(fcode_outer_kernel, frontend=frontend)
    inner_kernel_source = Sourcefile.from_source(fcode_inner_kernel, frontend=frontend)
    driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend)
    driver = driver_source['column_driver']
    outer_kernel = outer_kernel_source['compute_column']
    inner_kernel = inner_kernel_source['update_q']
    outer_kernel.enrich(inner_kernel)  # Attach inner kernel source to the outer kernel's call
    driver.enrich(outer_kernel)  # Attach outer kernel source to driver call

    driver_item = ProcedureItem(name='#column_driver', source=driver)
    outer_kernel_item = ProcedureItem(name='#compute_column', source=outer_kernel)
    inner_kernel_item = ProcedureItem(name='#update_q', source=inner_kernel)

    scc_hoist = SCCHoistPipeline(
        horizontal=horizontal, block_dim=blocking,
        dim_vars=vertical.sizes, as_kwarguments=as_kwarguments, directive='openacc'
    )

    graph_dic = {driver_item: [outer_kernel_item], outer_kernel_item: [inner_kernel_item]}
    graph = SGraph.from_dict(graph_dic)
    # Apply in reverse order to ensure hoisting analysis gets run on kernel first
    scc_hoist.apply(inner_kernel, role='kernel', item=inner_kernel_item)
    scc_hoist.apply(
        outer_kernel, role='kernel', item=outer_kernel_item,
        targets=['update_q'], sub_sgraph=graph.get_sub_sgraph(outer_kernel_item)
    )
    scc_hoist.apply(
        driver, role='driver', item=driver_item,
        targets=['compute_column'], sub_sgraph=graph
    )

    # Ensure calls have correct arguments
    # driver
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 1
    if not as_kwarguments:
        assert calls[0].arguments == ('start', 'end', 'nlon', 'nz', 'q(:, :, b)',
                'update_q_t(:, :, b)')
        assert calls[0].kwarguments == ()
    else:
        assert calls[0].arguments == ('start', 'end', 'nlon', 'nz', 'q(:, :, b)')
        assert calls[0].kwarguments == (('update_q_t', 'update_q_t(:, :, b)'),)
    # outer kernel
    calls = FindNodes(CallStatement).visit(outer_kernel.body)
    assert len(calls) == 1
    if not as_kwarguments:
        assert calls[0].arguments == ('start', 'end', 'nlon', 'nz', 'q', 'c', 'update_q_t')
        assert calls[0].kwarguments == ()
    else:
        assert calls[0].arguments == ('start', 'end', 'nlon', 'nz', 'q', 'c')
        assert calls[0].kwarguments == (('t', 'update_q_t'),)

    # Ensure a single outer parallel loop in driver
    with pragmas_attached(driver, Loop):
        driver_loops = FindNodes(Loop).visit(driver.body)
        assert len(driver_loops) == 1
        assert driver_loops[0].variable == 'b'
        assert driver_loops[0].bounds == '1:nb'
        assert driver_loops[0].pragma[0].keyword == 'acc'
        assert driver_loops[0].pragma[0].content == 'parallel loop gang vector_length(nlon)'

        # Ensure we have a kernel call in the driver loop
        kernel_calls = FindNodes(CallStatement).visit(driver_loops[0])
        assert len(kernel_calls) == 1
        assert kernel_calls[0].name == 'compute_column'

    # Ensure that the intermediate kernel contains two wrapped loops and an unwrapped call statement
    with pragmas_attached(outer_kernel, Loop):
        outer_kernel_loops = FindNodes(Loop).visit(outer_kernel.body)
        assert len(outer_kernel_loops) == 2
        assert outer_kernel_loops[0].variable == 'jl'
        assert outer_kernel_loops[0].bounds == 'start:end'
        assert outer_kernel_loops[0].pragma[0].keyword == 'acc'
        assert outer_kernel_loops[0].pragma[0].content == 'loop vector'
        assert outer_kernel_loops[1].variable == 'jl'
        assert outer_kernel_loops[1].bounds == 'start:end'
        assert outer_kernel_loops[1].pragma[0].keyword == 'acc'
        assert outer_kernel_loops[1].pragma[0].content == 'loop vector'

        # Ensure we still have a call, but not in the loops
        assert len(FindNodes(CallStatement).visit(outer_kernel_loops[0])) == 0
        assert len(FindNodes(CallStatement).visit(outer_kernel_loops[1])) == 0
        assert len(FindNodes(CallStatement).visit(outer_kernel.body)) == 1

        # Ensure the routine has been marked properly
        outer_kernel_pragmas = FindNodes(Pragma).visit(outer_kernel.ir)
        assert len(outer_kernel_pragmas) == 3
        assert outer_kernel_pragmas[0].keyword == 'acc'
        assert outer_kernel_pragmas[0].content == 'routine vector'
        assert outer_kernel_pragmas[1].keyword == 'acc'
        assert outer_kernel_pragmas[1].content == 'data present(q, update_q_t)'
        assert outer_kernel_pragmas[2].keyword == 'acc'
        assert outer_kernel_pragmas[2].content == 'end data'

    # Ensure that the leaf kernel contains two nested loops
    with pragmas_attached(inner_kernel, Loop):
        inner_kernel_loops = FindNodes(Loop).visit(inner_kernel.body)
        assert len(inner_kernel_loops) == 2
        assert inner_kernel_loops[1] in FindNodes(Loop).visit(inner_kernel_loops[0].body)
        assert inner_kernel_loops[0].variable == 'jl'
        assert inner_kernel_loops[0].bounds == 'start:end'
        assert inner_kernel_loops[0].pragma[0].keyword == 'acc'
        assert inner_kernel_loops[0].pragma[0].content == 'loop vector'
        assert inner_kernel_loops[1].variable == 'jk'
        assert inner_kernel_loops[1].bounds == '2:nz'
        assert inner_kernel_loops[1].pragma[0].keyword == 'acc'
        assert inner_kernel_loops[1].pragma[0].content == 'loop seq'

        # Ensure the leaf routine has been marked properly; its hoisted
        # temporary keeps the local name `t` inside the leaf kernel
        inner_kernel_pragmas = FindNodes(Pragma).visit(inner_kernel.ir)
        assert len(inner_kernel_pragmas) == 3
        assert inner_kernel_pragmas[0].keyword == 'acc'
        assert inner_kernel_pragmas[0].content == 'routine vector'
        assert inner_kernel_pragmas[1].keyword == 'acc'
        assert inner_kernel_pragmas[1].content == 'data present(q, t)'
        assert inner_kernel_pragmas[2].keyword == 'acc'
        assert inner_kernel_pragmas[2].content == 'end data'

    # check that kind import was added to driver
    imports = FindNodes(Import).visit(driver.spec)
    assert imports
    assert imports[0].module.lower() == 'iso_fortran_env'
    assert 'real64' in imports[0].symbols
    assert driver.variable_map['update_q_t'].type.kind.name.lower() == 'real64'


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_hoist_nested_inline_openacc(frontend, horizontal, vertical, blocking):
    """
    Test the correct addition of OpenACC pragmas to SCC format code
    when hoisting array temporaries to driver.

    The inner kernel ``update_q`` is inlined into ``compute_column`` before the
    SCC hoist pipeline runs, so its temporary ``t`` becomes a local of the outer
    kernel and is hoisted to the driver as ``compute_column_t``.
    """

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nz, q, nb)
    INTEGER, INTENT(IN)   :: nlon, nz, nb  ! Size of the horizontal and vertical
    REAL, INTENT(INOUT)   :: q(nlon,nz,nb)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    do b=1, nb
      call compute_column(start, end, nlon, nz, q(:,:,b))
    end do
  END SUBROUTINE column_driver
"""

    fcode_outer_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + 1.0
    END DO

    !$loki inline
    call update_q(start, end, nlon, nz, q, c)

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""

    fcode_inner_kernel = """
  SUBROUTINE update_q(start, end, nlon, nz, q, c)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL, INTENT(IN)    :: c
    REAL :: t(nlon,nz)
    INTEGER :: jl, jk

    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO
  END SUBROUTINE update_q
"""

    # Mimic the scheduler internal mechanism to apply the transformation cascade
    outer_kernel_source = Sourcefile.from_source(fcode_outer_kernel, frontend=frontend)
    inner_kernel_source = Sourcefile.from_source(fcode_inner_kernel, frontend=frontend)
    driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend)
    driver = driver_source['column_driver']
    outer_kernel = outer_kernel_source['compute_column']
    inner_kernel = inner_kernel_source['update_q']
    outer_kernel.enrich(inner_kernel)  # Attach inner kernel source to outer kernel call
    driver.enrich(outer_kernel)  # Attach kernel source to driver call

    driver_item = ProcedureItem(name='#column_driver', source=driver)
    outer_kernel_item = ProcedureItem(name='#compute_column', source=outer_kernel)
    inner_kernel_item = ProcedureItem(name='#update_q', source=inner_kernel)

    scc_hoist = SCCHoistPipeline(
        horizontal=horizontal, block_dim=blocking,
        dim_vars=vertical.sizes, directive='openacc'
    )

    # Inline the marked call to ``update_q`` before running the SCC pipeline
    InlineTransformation(allowed_aliases=horizontal.index).apply(outer_kernel)

    graph_dic = {driver_item: [outer_kernel_item], outer_kernel_item: [inner_kernel_item]}
    graph = SGraph.from_dict(graph_dic)

    # Apply in reverse order to ensure hoisting analysis gets run on kernel first
    scc_hoist.apply(inner_kernel, role='kernel', item=inner_kernel_item)
    scc_hoist.apply(
        outer_kernel, role='kernel', item=outer_kernel_item,
        targets=['compute_q']
    )
    scc_hoist.apply(
        driver, role='driver', item=driver_item,
        targets=['compute_column'], sub_sgraph=graph
    )

    # Ensure calls have correct arguments
    # driver
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 1
    assert calls[0].arguments == ('start', 'end', 'nlon', 'nz', 'q(:, :, b)',
            'compute_column_t(:, :, b)')

    # Ensure a single outer parallel loop in driver
    with pragmas_attached(driver, Loop):
        driver_loops = FindNodes(Loop).visit(driver.body)
        assert len(driver_loops) == 1
        assert driver_loops[0].variable == 'b'
        assert driver_loops[0].bounds == '1:nb'
        assert driver_loops[0].pragma[0].keyword == 'acc'
        assert driver_loops[0].pragma[0].content == 'parallel loop gang vector_length(nlon)'

        # Ensure we have a kernel call in the driver loop
        kernel_calls = FindNodes(CallStatement).visit(driver_loops[0])
        assert len(kernel_calls) == 1
        assert kernel_calls[0].name == 'compute_column'

    # Ensure that the intermediate kernel contains a single vector loop with the
    # inlined vertical loop nested inside it, and no remaining call statement
    with pragmas_attached(outer_kernel, Loop):
        outer_kernel_loops = FindNodes(Loop).visit(outer_kernel.body)
        assert len(outer_kernel_loops) == 2
        assert outer_kernel_loops[0].variable == 'jl'
        assert outer_kernel_loops[0].bounds == 'start:end'
        assert outer_kernel_loops[0].pragma[0].keyword == 'acc'
        assert outer_kernel_loops[0].pragma[0].content == 'loop vector'

        # check correctly nested vertical loop from inlined routine
        assert outer_kernel_loops[1] in FindNodes(Loop).visit(outer_kernel_loops[0].body)

        # Ensure the call was inlined
        assert not FindNodes(CallStatement).visit(outer_kernel.body)

        # Ensure the routine has been marked properly
        outer_kernel_pragmas = FindNodes(Pragma).visit(outer_kernel.ir)
        assert len(outer_kernel_pragmas) == 3
        assert outer_kernel_pragmas[0].keyword == 'acc'
        assert outer_kernel_pragmas[0].content == 'routine vector'
        assert outer_kernel_pragmas[1].keyword == 'acc'
        assert outer_kernel_pragmas[1].content == 'data present(q, t)'
        assert outer_kernel_pragmas[2].keyword == 'acc'
        assert outer_kernel_pragmas[2].content == 'end data'
loki-ecmwf-0.3.6/loki/transformations/single_column/tests/test_scc_vertical.py0000664000175000017500000002310615167130205030204 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Dimension
from loki.frontend import available_frontends
from loki.ir import FindNodes, Loop, FindVariables
from loki.transformations.single_column import SCCFuseVerticalLoops


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """Horizontal dimension: index ``jl`` over ``start:end``, size ``nlon`` (alias ``nproma``)."""
    dimension_spec = {
        'name': 'horizontal',
        'size': 'nlon',
        'index': 'jl',
        'bounds': ('start', 'end'),
        'aliases': ('nproma',),
    }
    return Dimension(**dimension_spec)

@pytest.fixture(scope='module', name='horizontal_bounds_aliases')
def fixture_horizontal_bounds_aliases():
    """Horizontal dimension whose bounds may also appear as ``bnds%start``/``bnds%end``."""
    loop_bounds = ('start', 'end')
    derived_bounds = ('bnds%start', 'bnds%end')
    return Dimension(
        name='horizontal_bounds_aliases',
        size='nlon',
        index='jl',
        bounds=loop_bounds,
        aliases=('nproma',),
        bounds_aliases=derived_bounds,
    )

@pytest.fixture(scope='module', name='vertical')
def fixture_vertical():
    """Vertical dimension: index ``jk``, size ``nz`` (alias ``nlev``)."""
    vertical_dim = Dimension(index='jk', size='nz', name='vertical', aliases=('nlev',))
    return vertical_dim

@pytest.fixture(scope='module', name='blocking')
def fixture_blocking():
    """Blocking dimension: index ``b``, size ``nb``."""
    blocking_dim = Dimension(index='b', size='nb', name='blocking')
    return blocking_dim


@pytest.mark.parametrize('frontend', available_frontends())
def test_simple_scc_fuse_verticals_transformation(frontend, horizontal, vertical):
    """
    Test simple example of vertical loop fusion and demotion of temporaries.
    """

    fcode_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q, t)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t(nlon,nz)
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL :: temp_t(nlon, nz)
    REAL :: temp_q(nlon, nz)
    INTEGER :: jl, JK
    REAL :: c

    c = 5.345
    !$loki loop-fusion group(1)
    DO JK = 1, nz
      DO jl = start, end
        temp_t(jl, jk) = c
        temp_q(jl, JK) = c
      END DO
    END DO

    !$loki loop-fusion group(1)
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = temp_t(jl, jk) * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * temp_q(jl, jk)
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # Ensure we have five loops in the kernel prior to transformation
    # (two vertical loops, each with a nested horizontal loop, plus the final scaling loop)
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    assert len(kernel_loops) == 5

    # no-op as 'compute_column' is not within apply_to
    SCCFuseVerticalLoops(vertical=vertical, apply_to=('another_kernel',)).apply(kernel, role='kernel')
    # Ensure all five loops are still present after the no-op application
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    assert len(kernel_loops) == 5

    # actual loop fusion and demotion ... (as apply_to is not provided and therefore all routines are dispatched)
    SCCFuseVerticalLoops(vertical=vertical).apply(kernel, role='kernel')

    # Ensure the two vertical loops are fused into a single one
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    assert len(kernel_loops) == 4
    assert kernel_loops[0].variable.name.lower() == 'jk'
    assert kernel_loops[-1].variable.name.lower() == 'jl'
    assert len([loop for loop in kernel_loops if loop.variable.name.lower() == 'jk']) == 1
    # Temporaries only used within the fused vertical loop are demoted to the horizontal dimension
    kernel_var_map = kernel.variable_map
    assert kernel_var_map['temp_t'].shape == (horizontal.size,)
    assert kernel_var_map['temp_q'].shape == (horizontal.size,)
    kernel_vars = [var for var in FindVariables().visit(kernel.body) if var.name.lower() in ['temp_t', 'temp_q']]
    for var in kernel_vars:
        assert var.shape == (horizontal.size,)
        assert var.dimensions == (horizontal.index,)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('ignore', (False, True))
def test_scc_fuse_verticals_transformation(frontend, horizontal, vertical, ignore):
    """
    Test somewhat more sophisticated example of vertical loop fusion
    and demotion of temporaries.

    With ``ignore=True`` the ``k-caching ignore(temp_q2)`` pragma must keep
    ``temp_q2`` at its full 2D shape; otherwise it is demoted to 1D.
    """

    fcode_kernel = f"""
  SUBROUTINE compute_column(start, end, nlon, nz, q, t)
    INTEGER, INTENT(IN) :: start, end  ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: t(nlon,nz)
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL :: temp_t(nlon, nz)
    REAL :: temp_t2(nlon, nz)
    REAL :: temp_q(nlon, nz)
    REAL :: temp_q2(nlon, nz)
    REAL :: temp_cld(nlon, nz, 5)
    INTEGER :: jl, jk, jm
    REAL :: c

    {'!$loki k-caching ignore(temp_q2)' if ignore else ''}

    c = 5.345
    !$loki loop-fusion group(1-init)
    DO jk = 1, nz
      DO jl = start, end
        temp_t(jl, jk) = c
        temp_q(jl, jk) = c
        temp_t2(jl, jk) = 2*c
      END DO
    END DO

    !$loki loop-fusion group(1)
    !$loki loop-interchange
    DO jm=1,5
      DO jk = 1, nz
        DO jl = start, end
          temp_cld(jl, jk, jm) = 3.1415
        END DO
      END DO
    END DO

    DO jl = start, end
      q(jl, jk) = 0.
    END DO

    !$loki loop-fusion group(1) insert
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = temp_t(jl, jk) * temp_t2(jl, jk-1) * temp_cld(jl, jk, 1)
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * temp_q(jl, jk)
      END DO
    END DO

    CALL nested_kernel(start, end, nlon, nz, q)

    !$loki loop-fusion group(2)
    DO jk = 2, nz
      DO jl = start, end
        temp_q2(jl, jk) = 3.1415
      END DO
    END DO

    !$loki loop-fusion group(2)
    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = t(jl, jk) + 3.1415
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * temp_q(jl, jk) + temp_q2(jl, jk)
      END DO
    END DO

    ! The scaling is purposefully upper-cased
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""


    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # Ensure we have all 13 loops (vertical, horizontal and jm) in the kernel prior to transformation
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    assert len(kernel_loops) == 13
    SCCFuseVerticalLoops(vertical=vertical).apply(kernel, role='kernel')

    # Ensure the vertical loops are fused into three remaining vertical loops
    kernel_loops = FindNodes(Loop).visit(kernel.body)
    assert len(kernel_loops) == 12
    vertical_loops = [loop for loop in kernel_loops if loop.variable.name.lower() == vertical.index]
    assert len(vertical_loops) == 3

    shape1D = (horizontal.size,)
    shape2D = (horizontal.size, vertical.size)
    dimension1D = (horizontal.index,)
    dimension2D = (horizontal.index,vertical.index)
    dimension2DI1 = (horizontal.index, f'{vertical.index}-1')

    # Expected demoted shape/dimensions of temp_q2, depending on the k-caching ignore pragma.
    # NB: the conditional must be parenthesised; ``a == b if c else d`` parses as
    # ``(a == b) if c else d`` and would make the assertion vacuous for ``ignore=False``.
    shape_q2 = shape2D if ignore else shape1D
    dimension_q2 = dimension2D if ignore else dimension1D

    vertical_loop_0_vars = FindVariables().visit(vertical_loops[0].body)
    vertical_loop_0_var_names = [var.name.lower() for var in vertical_loop_0_vars]
    vertical_loop_0_var_dict = dict(zip(vertical_loop_0_var_names, vertical_loop_0_vars))
    assert 'temp_t2' in vertical_loop_0_var_names
    assert 'temp_t' not in vertical_loop_0_var_names
    assert 'temp_q' not in vertical_loop_0_var_names
    assert 'temp_q2' not in vertical_loop_0_var_names
    assert 'temp_cld' not in vertical_loop_0_var_names
    assert vertical_loop_0_var_dict['temp_t2'].shape == shape2D
    assert vertical_loop_0_var_dict['temp_t2'].dimensions == dimension2D

    vertical_loop_1_vars = FindVariables().visit(vertical_loops[1].body)
    vertical_loop_1_var_names = [var.name.lower() for var in vertical_loop_1_vars]
    vertical_loop_1_var_dict = dict(zip(vertical_loop_1_var_names, vertical_loop_1_vars))
    assert 'temp_t2' in vertical_loop_1_var_names
    assert 'temp_t' in vertical_loop_1_var_names
    assert 'temp_q' in vertical_loop_1_var_names
    assert 'temp_q2' not in vertical_loop_1_var_names
    assert 'temp_cld' in vertical_loop_1_var_names
    assert vertical_loop_1_var_dict['temp_t2'].shape == shape2D
    assert vertical_loop_1_var_dict['temp_t2'].dimensions == dimension2DI1
    assert vertical_loop_1_var_dict['temp_t'].shape == shape1D
    assert vertical_loop_1_var_dict['temp_t'].dimensions == dimension1D
    assert vertical_loop_1_var_dict['temp_q'].shape == shape2D
    assert vertical_loop_1_var_dict['temp_q'].dimensions == dimension2D
    assert vertical_loop_1_var_dict['temp_cld'].shape == shape1D + (5,)
    assert vertical_loop_1_var_dict['temp_cld'].dimensions in (dimension1D + (1,), dimension1D + ('jm',))

    vertical_loop_2_vars = FindVariables().visit(vertical_loops[2].body)
    vertical_loop_2_var_names = [var.name.lower() for var in vertical_loop_2_vars]
    vertical_loop_2_var_dict = dict(zip(vertical_loop_2_var_names, vertical_loop_2_vars))
    assert 'temp_t2' not in vertical_loop_2_var_names
    assert 'temp_t' not in vertical_loop_2_var_names
    assert 'temp_q' in vertical_loop_2_var_names
    assert 'temp_q2' in vertical_loop_2_var_names
    assert 'temp_cld' not in vertical_loop_2_var_names
    assert vertical_loop_2_var_dict['temp_q'].shape == shape2D
    assert vertical_loop_2_var_dict['temp_q'].dimensions == dimension2D
    assert vertical_loop_2_var_dict['temp_q2'].shape == shape_q2
    assert vertical_loop_2_var_dict['temp_q2'].dimensions == dimension_q2

    kernel_var_map = kernel.variable_map
    assert kernel_var_map['temp_t'].shape == shape1D
    assert kernel_var_map['temp_t2'].shape == shape2D
    assert kernel_var_map['temp_q'].shape == shape2D
    assert kernel_var_map['temp_q2'].shape == shape_q2
loki-ecmwf-0.3.6/loki/transformations/single_column/scc_cuf.py0000664000175000017500000012256515167130205024760 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Single-Column-Coalesced CUDA Fortran (SCC-CUF) transformation.
"""

from loki.logging import info
from loki.batch import Transformation
from loki.expression import symbols as sym
from loki.ir import (
    nodes as ir, FindNodes, Transformer, FindVariables,
    SubstituteExpressions
)
from loki.tools import CaseInsensitiveDict, as_tuple, flatten
from loki.types import BasicType, DerivedType, SymbolAttributes

from loki.transformations.array_indexing import resolve_vector_notation
from loki.transformations.temporaries.hoist_variables import HoistVariablesTransformation
from loki.transformations.sanitise import do_resolve_associates
from loki.transformations.single_column.base import SCCBaseTransformation
from loki.transformations.single_column.devector import RemoveLoopTransformer
from loki.transformations.utilities import single_variable_declaration
from loki.ir.pragma_utils import get_pragma_parameters

__all__ = [
    'HoistTemporaryArraysDeviceAllocatableTransformation',
    'HoistTemporaryArraysPragmaOffloadTransformation',
    'SccLowLevelLaunchConfiguration',
    'SccLowLevelDataOffload',
]


class HoistTemporaryArraysDeviceAllocatableTransformation(HoistVariablesTransformation):
    """
    Synthesis part for variable/array hoisting for CUDA Fortran (CUF) (transformation).

    Hoisted arrays are declared in the driver as ``device, allocatable`` with a
    deferred shape and are explicitly allocated and deallocated in the driver body.
    """

    def driver_variable_declaration(self, routine, variables):
        """
        CUDA Fortran (CUF) Variable/Array device declaration including
        allocation and de-allocation.

        Each variable is declared in ``routine`` as ``device, allocatable`` with a
        deferred shape (``:``) of the same rank as its dimensions. Its allocation
        is inserted after the last top-level allocation of the body (or prepended
        if there is none); its deallocation after the last top-level deallocation
        (or appended if there is none).

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine to add the variable declarations to
        variables: tuple of :any:`Variable`
            The variables to be declared
        """
        for var in variables:
            vtype = var.type.clone(device=True, allocatable=True)
            deferred_shape = as_tuple([sym.RangeIndex((None, None))] * (len(var.dimensions)))
            routine.variables += (var.clone(scope=routine, dimensions=deferred_shape, type=vtype),)

            self._insert_after_last_top_level(
                routine, ir.Allocation, ir.Allocation((var.clone(),)), append_if_missing=False
            )
            self._insert_after_last_top_level(
                routine, ir.Deallocation, ir.Deallocation((var.clone(dimensions=None),)),
                append_if_missing=True
            )

    @staticmethod
    def _insert_after_last_top_level(routine, node_type, node, append_if_missing):
        """
        Insert ``node`` right after the last top-level ``node_type`` node of ``routine.body``.

        Only direct children of the body are considered, so (de)allocations nested in
        conditionals or loops are never used as insertion anchors (looking them up in
        the top-level body would fail). Without a top-level anchor, ``node`` is appended
        when ``append_if_missing`` is set and prepended otherwise.
        """
        positions = [
            pos for pos, child in enumerate(routine.body.body) if isinstance(child, node_type)
        ]
        if positions:
            routine.body.insert(positions[-1] + 1, node)
        elif append_if_missing:
            routine.body.append(node)
        else:
            routine.body.prepend(node)


class HoistTemporaryArraysPragmaOffloadTransformation(HoistVariablesTransformation):
    """
    Synthesis part for variable/array hoisting, offload via pragmas e.g., OpenACC.
    """

    def driver_variable_declaration(self, routine, variables):
        """
        Standard Variable/Array declaration including
        device offload via pragmas.

        The variables are declared in ``routine`` and the body is wrapped in
        ``!$loki unstructured-data create(...)`` / ``!$loki exit unstructured-data
        delete(...)`` pragmas naming them. Nothing is done when ``variables`` is
        empty, to avoid emitting pragmas with empty clauses.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine to add the variable declarations to
        variables: tuple of :any:`Variable`
            The variables to be declared
        """
        if not variables:
            return

        routine.variables += tuple(var.clone(scope=routine) for var in variables)

        vnames = ', '.join(v.name for v in variables)

        pragma = ir.Pragma(keyword='loki', content=f'unstructured-data create({vnames})')
        pragma_post = ir.Pragma(keyword='loki', content=f'exit unstructured-data delete({vnames})')
        # Add comments around standalone pragmas to avoid false attachment
        routine.body.prepend((ir.Comment(''), pragma, ir.Comment('')))
        routine.body.append((ir.Comment(''), pragma_post, ir.Comment('')))


def remove_non_loki_pragmas(routine):
    """
    Remove all pragmas except Loki pragmas from the body of ``routine``.

    Pragmas whose keyword is ``loki`` (case-insensitive) are kept; all other
    pragmas (e.g. ``acc``, ``omp``) are removed.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The subroutine in which to remove the non-Loki pragmas
    """
    pragma_map = {
        pragma: None
        for pragma in FindNodes(ir.Pragma).visit(routine.body)
        if pragma.keyword.lower() != "loki"
    }
    routine.body = Transformer(pragma_map).visit(routine.body)

def device_subroutine_prefix(routine, depth):
    """
    Append the CUDA attribute matching ``depth`` to the prefix of ``routine``.

    ``ATTRIBUTES(GLOBAL)`` is appended for ``depth == 1`` and ``ATTRIBUTES(DEVICE)``
    for ``depth > 1``; for any other depth the prefix is left untouched.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The subroutine (kernel/device subroutine) to add a prefix/specifier
    depth: int
        The subroutine's depth
    """
    if depth == 1:
        attribute = "ATTRIBUTES(GLOBAL)"
    elif depth > 1:
        attribute = "ATTRIBUTES(DEVICE)"
    else:
        return
    routine.prefix += (attribute,)

class SccLowLevelLaunchConfiguration(Transformation):
    """
    Part of the pipeline for generating Single Column Coalesced
    Low Level GPU (CUDA Fortran, CUDA C, HIP, ...) for block-indexed gridpoint/single-column
    routines (responsible for the launch configuration including the chevron notation).
    """

    def __init__(self, horizontal, vertical, block_dim, transformation_type='parametrise', mode="CUF"):
        """
        Part of the pipeline for generating Single Column Coalesced
        Low Level GPU (CUDA Fortran, CUDA C, HIP, ...) for block-indexed gridpoint/single-column
        routines responsible for the launch configuration including the chevron notation.

        .. note::
            Depending on the transformation type ``transformation_type``, further
            transformations are necessary:

            * ``transformation_type = 'parametrise'`` requires a subsequent
              :any:`ParametriseTransformation` transformation with the necessary information
              to parametrise (at least) the ``vertical`` `size`
            * ``transformation_type = 'hoist'`` requires subsequent :any:`HoistVariablesAnalysis`
              and :class:`HoistVariablesTransformation` transformations (e.g.
              :any:`HoistTemporaryArraysAnalysis` for analysis and
              :any:`HoistTemporaryArraysDeviceAllocatableTransformation` or
              :any:`HoistTemporaryArraysPragmaOffloadTransformation` for synthesis)

        Parameters
        ----------
        horizontal : :any:`Dimension`
            :any:`Dimension` object describing the variable conventions used in code
            to define the horizontal data dimension and iteration space.
        vertical : :any:`Dimension`
            :any:`Dimension` object describing the variable conventions used in code
            to define the vertical dimension, as needed to decide array privatization.
        block_dim : :any:`Dimension`
            :any:`Dimension` object to define the blocking dimension
            to use for hoisted column arrays if hoisting is enabled.
        transformation_type : str
            Kind of transformation/Handling of temporaries/local arrays

            - `parametrise`: parametrising the array dimensions to make the vertical dimension
              a compile-time constant
            - `hoist`: host side hoisting of (relevant) arrays
        mode: str
            Mode/language to target (case-insensitive)

            - `CUF` - CUDA Fortran
            - `CUDA` - CUDA C
            - `HIP` - HIP
        """
        self.horizontal = horizontal
        self.vertical = vertical
        self.block_dim = block_dim
        self.mode = mode.lower()
        assert self.mode in ['cuf', 'cuda', 'hip']

        self.transformation_type = transformation_type
        # `parametrise` : parametrising the array dimensions
        # `hoist`: host side hoisting
        # Validate before logging so that an invalid type is never reported as being applied
        assert self.transformation_type in ['parametrise', 'hoist']
        info(f"[SccLowLevelLaunchConfiguration] Applying transformation type: '{self.transformation_type}'")
        self.transformation_description = {'parametrise': 'parametrised array dimensions of local arrays',
                                           'hoist': 'host side hoisted local arrays'}

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the launch-configuration transformation to ``routine``.

        Non-Loki pragmas are removed, declarations are split into one variable
        each, the ``ATTRIBUTES`` prefix matching the routine's depth is added and,
        in CUF mode, ``cudafor`` is imported. The routine is then processed
        according to its role, and keyword arguments of calls to ``targets`` are
        converted into positional arguments.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to transform.
        **kwargs :
            ``role`` (``'driver'`` or ``'kernel'``), ``item``, ``targets`` and
            ``depths`` (mapping from item to depth). Without ``depths``, the depth
            is 1 for kernels and 0 otherwise.
        """
        role = kwargs.get('role')
        item = kwargs.get('item', None)
        targets = kwargs.get('targets', None)
        depths = kwargs.get('depths', None)

        if depths is not None:
            depth = depths[item]
        else:
            depth = 1 if role == 'kernel' else 0

        remove_non_loki_pragmas(routine)
        single_variable_declaration(routine=routine)
        device_subroutine_prefix(routine, depth)

        if self.mode == 'cuf':
            routine.spec.prepend(ir.Import(module="cudafor"))

        if role == 'driver':
            self.process_driver(routine, targets=targets)
        elif role == 'kernel':
            self.process_kernel(routine, depth=depth, targets=targets)

        target_names = as_tuple(targets)
        target_calls = [
            call for call in FindNodes(ir.CallStatement).visit(routine.body)
            if call.name in target_names
        ]
        for call in target_calls:
            call.convert_kwargs_to_args()

    def process_kernel(self, routine, depth=1, targets=None):
        """
        Kernel/Device subroutine specific changes/transformations.

        Delegates to :meth:`kernel_cuf` with this transformation's horizontal,
        vertical and block dimensions.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine (kernel/device subroutine) to process
        depth: int
            The subroutine's depth
        targets: tuple of str, optional
            Names of the callees to process
        """
        dimensions = (self.horizontal, self.vertical, self.block_dim)
        self.kernel_cuf(routine, *dimensions, depth=depth, targets=targets)

    def process_driver(self, routine, targets=None):
        """
        Driver subroutine specific changes/transformations.

        The launch configuration is obtained from ``driver_launch_configuration``.
        In ``cuda``/``hip`` mode, every targeted kernel routine gets the loop upper
        bound and step appended as ``intent(in)`` arguments (if it lacks them) plus
        the block/grid dimension variables and their assignments. Each kernel
        routine is modified only once, even if it is called from several call
        sites, but every call site receives the additional keyword arguments.
        In ``cuf`` mode, a ``!@cuf print`` comment reporting the transformation
        type is prepended to the driver body.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine (driver) to process
        targets: tuple of str, optional
            Names of the kernels to process (compared against the lower-cased call name)
        """

        upper, step, _, blockdim_var, griddim_var, blockdim_assignment, griddim_assignment =\
                self.driver_launch_configuration(routine=routine, block_dim=self.block_dim, targets=targets)

        if self.mode in ['cuda', 'hip']:
            call_map = {}
            # Keyword arguments to add at call sites, keyed by ``id`` of the kernel routine.
            # Without this, a second call to an already-extended kernel would find
            # ``upper``/``step`` in its signature and not be given the keyword arguments,
            # and the block/grid variables and assignments would be added twice.
            kernel_kwargs = {}
            for call in FindNodes(ir.CallStatement).visit(routine.body):
                if str(call.name).lower() not in as_tuple(targets):
                    continue
                kernel = call.routine
                if id(kernel) not in kernel_kwargs:
                    new_args = tuple(
                        var.clone(type=var.type.clone(intent='in'), scope=kernel)
                        for var in (upper, step) if var.name not in kernel.arguments
                    )
                    if new_args:
                        kernel.arguments = list(kernel.arguments) + list(new_args)
                    kernel.variables += (blockdim_var, griddim_var)
                    kernel.body = (blockdim_assignment, griddim_assignment) + as_tuple(kernel.body)
                    kernel_kwargs[id(kernel)] = tuple((arg.name, arg) for arg in new_args)
                new_kwargs = kernel_kwargs[id(kernel)]
                if new_kwargs:
                    call_map[call] = call.clone(kwarguments=as_tuple(list(call.kwarguments) + list(new_kwargs)))
            routine.body = Transformer(call_map).visit(routine.body)
        elif self.mode == 'cuf':
            routine.body.prepend(ir.Comment(f"!@cuf print *, 'executing SCC-CUF type: {self.transformation_type} - "
                                            f"{self.transformation_description[self.transformation_type]}'"))
            routine.body.prepend(ir.Comment(""))

    def kernel_cuf(self, routine, horizontal, vertical, block_dim,
               depth, targets=None):
        """
        Kernel/device routine transformation for the low-level GPU launch configuration.

        Elemental routines only lose their ``ELEMENTAL`` prefix. For all other
        routines the horizontal and block loops are replaced by their bodies and
        private locals are demoted. At ``depth == 1`` the block/horizontal index
        (and horizontal bound) assignments are removed, the indices are instead
        taken from the block/thread indices, and the body is guarded by a bounds
        check. Finally, both indices are added as ``intent(in)`` arguments to every
        targeted, non-elemental callee and passed as keyword arguments at each
        call site that does not already provide them.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine (kernel/device subroutine) to process
        horizontal: :any:`Dimension`
            The horizontal dimension
        vertical: :any:`Dimension`
            The vertical dimension
        block_dim: :any:`Dimension`
            The blocking dimension
        depth: int
            The subroutine's depth
        targets: tuple of str, optional
            Names of callees to process (compared against the lower-cased routine name)
        """

        if SCCBaseTransformation.is_elemental(routine):
            # TODO: correct "definition" of elemental/pure routines and corresponding removing
            #  of subroutine prefix(es)/specifier(s)
            routine.prefix = as_tuple([prefix for prefix in routine.prefix if prefix not in ["ELEMENTAL"]])
            return

        single_variable_declaration(routine, variables=(horizontal.index, block_dim.index))

        #  this does not make any difference ...
        self.kernel_demote_private_locals(routine, horizontal, vertical)

        # Replace block and horizontal loops by their bodies (implicit thread/block "loops")
        loop_map = {}
        for loop in FindNodes(ir.Loop).visit(routine.body):
            if loop.variable == block_dim.index or loop.variable.name in block_dim.sizes:
                loop_map[loop] = loop.body
            if loop.variable == horizontal.index or loop.variable.name in horizontal.sizes:
                loop_map[loop] = loop.body
        routine.body = Transformer(loop_map).visit(routine.body)

        if depth == 1:

            ## bit hacky ...
            assignments = FindNodes(ir.Assignment).visit(routine.body)
            assignments2remove = as_tuple(block_dim.index) + horizontal.bounds
            assignment_map = {assign: None for assign in assignments if assign.lhs.name in assignments2remove}
            routine.body = Transformer(assignment_map).visit(routine.body)
            ##end: bit hacky

            # CUDA thread mapping
            if self.mode == 'cuf':
                var_thread_idx = sym.Variable(name="THREADIDX")
                var_x = sym.Variable(name="X", parent=var_thread_idx)
            else:
                ctype = SymbolAttributes(DerivedType(name="threadIdx"))
                var_thread_idx = sym.Variable(name="threadIdx", case_sensitive=True)
                var_x = sym.Variable(name="x", parent=var_thread_idx, case_sensitive=True, type=ctype)
            horizontal_assignment = ir.Assignment(lhs=routine.variable_map[horizontal.index], rhs=var_x)

            if self.mode == 'cuf':
                var_thread_idx = sym.Variable(name="BLOCKIDX")
                var_x = sym.Variable(name="Z", parent=var_thread_idx)
            else:
                ctype = SymbolAttributes(DerivedType(name="blockIdx"))
                var_thread_idx = sym.Variable(name="blockIdx", case_sensitive=True)
                var_x = sym.Variable(name="x", parent=var_thread_idx, case_sensitive=True, type=ctype)
            block_dim_assignment = ir.Assignment(lhs=routine.variable_map[block_dim.index], rhs=var_x)

            condition = sym.LogicalAnd((sym.Comparison(routine.variable_map[block_dim.index], '<=',
                                                       routine.variable_map[block_dim.size]),
                                        sym.Comparison(routine.variable_map[horizontal.index], '<=',
                                                       routine.variable_map[horizontal.size])))

            routine.body = ir.Section((horizontal_assignment, block_dim_assignment, ir.Comment(''),
                            ir.Conditional(condition=condition, body=as_tuple(routine.body), else_body=())))

        target_names = as_tuple(targets)
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            callee = call.routine
            if callee.name.lower() not in target_names or SCCBaseTransformation.is_elemental(callee):
                continue
            additional_args = ()
            additional_kwargs = ()
            # Horizontal and block index are treated identically
            for index_name in (horizontal.index, block_dim.index):
                new_args, new_kwargs = self._index_argument_for_call(routine, call, index_name)
                additional_args += new_args
                additional_kwargs += new_kwargs
            if additional_kwargs:
                call._update(kwarguments=call.kwarguments+additional_kwargs)
            if additional_args:
                callee.arguments += additional_args

    @staticmethod
    def _index_argument_for_call(routine, call, index_name):
        """
        Determine how the index variable ``index_name`` of ``routine`` is added to ``call``.

        Returns a pair ``(args, kwargs)``. ``args`` holds an ``intent(in)`` clone of the
        variable scoped to the callee when the callee does not yet have such a dummy
        argument (an existing local of that name is switched to ``intent(in)``).
        ``kwargs`` holds the ``(name, variable)`` pair to pass when the call does not
        already provide a value for it.
        """
        index = routine.variable_map[index_name]
        callee = call.routine
        args = ()
        kwargs = ()
        if index.name not in callee.arguments:
            if index.name in callee.variables:
                callee.symbol_attrs.update({index.name:\
                        callee.variable_map[index.name].type.clone(intent='in')})
            # NB: ``scope`` belongs to the variable clone, not to the type clone
            args = (index.clone(type=index.type.clone(intent='in'), scope=callee),)
        if index.name not in call.arg_map:
            kwargs = ((index.name, index.clone(scope=routine)),)
        return args, kwargs

    @staticmethod
    def kernel_demote_private_locals(routine, horizontal, vertical):
        """
        Demote purely local arrays whose leading dimension is the horizontal size.

        Candidates are the declared variables of :data:`routine` plus every
        variable occurrence in its body. A candidate is kept only if it

        * is not an argument of :data:`routine`,
        * is a :any:`Array` with a known (non-``None``) shape,
        * has no shape entry involving ``vertical.size``, and
        * does not compare equal (by name) to a positional or keyword argument
          expression of a call in the body.

        For each kept array whose first shape entry is one of
        ``horizontal.size_expressions``, all horizontal size expressions are
        removed from its shape and the first index is dropped from its
        dimensions. The substitution is applied to both body and spec.

        Array variables whose dimensions include only the vector dimension
        or known (short) constant dimensions (eg. local vector or matrix arrays)
        can be privatized without requiring shared GPU memory. Array variables
        with unknown (at compile time) dimensions (eg. the vertical dimension)
        cannot be privatized at the vector loop level and should therefore not
        be demoted here.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine in which to demote the private locals
        horizontal: :any:`Dimension`
            The dimension object specifying the horizontal vector dimension
        vertical: :any:`Dimension`
            The dimension object specifying the vertical loop dimension
        """
        # Collect all candidates and record their shapes before any cloning:
        # shapes of all instances of a variable are linked via caching, so cloning
        # early could void the shape of a not-yet-processed instance.
        candidates = list(routine.variables) + list(FindVariables(unique=False).visit(routine.body))

        arg_lookup = CaseInsensitiveDict({arg.name: arg for arg in routine.arguments})

        # Anything passed down the call tree must keep its full shape
        kernel_calls = FindNodes(ir.CallStatement).visit(routine.body)
        passed_down = flatten(kcall.arguments for kcall in kernel_calls)
        passed_down += flatten(list(dict(kcall.kwarguments).values()) for kcall in kernel_calls)

        def _is_demotable(var):
            # Checks are ordered so that ``shape`` is only inspected on arrays
            if var.name in arg_lookup:
                return False
            if not isinstance(var, sym.Array):
                return False
            if var.shape is None:
                return False
            if any(vertical.size in dim for dim in var.shape):
                return False
            return var.name not in passed_down

        demotable = [var for var in candidates if _is_demotable(var)]

        shapes = CaseInsensitiveDict({var.name: var.shape for var in demotable})
        vmap = {}
        for var in demotable:
            shape = shapes[var.name]
            if not shape or shape[0] not in horizontal.size_expressions:
                continue
            # TODO: "s for s in old_shape if s not in expressions" sufficient?
            reduced_shape = as_tuple(s for s in shape if s not in horizontal.size_expressions)
            vmap[var] = var.clone(
                dimensions=var.dimensions[1:] or None,
                type=var.type.clone(shape=reduced_shape or None)
            )

        routine.body = SubstituteExpressions(vmap).visit(routine.body)
        routine.spec = SubstituteExpressions(vmap).visit(routine.spec)

    def driver_launch_configuration(self, routine, block_dim, targets=None):
        """
        Launch configuration for kernel calls within the driver with the
        CUDA Fortran (CUF) specific chevron syntax `<<>>`.

        ``BLOCKDIM`` and ``GRIDDIM`` (of derived type ``dim3``) and an integer
        ``istat`` are declared (``BLOCKDIM``/``GRIDDIM`` only in ``'cuf'`` mode).

        For every call to one of ``targets``, the ``upper`` and optional ``step``
        parameters of its ``removed_loop`` pragma are looked up as variables of
        :data:`routine`. The block size becomes ``step`` and the grid size
        ``ceiling(real(upper) / real(step))``.

        In ``'cuf'`` mode each such call is replaced by the ``BLOCKDIM`` and
        ``GRIDDIM`` assignments, the call with a ``(GRIDDIM, BLOCKDIM)`` chevron,
        and ``istat = cudaDeviceSynchronize()``. In other modes only the
        assignments are built, and those of the last target call are returned.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine to specify the launch configurations for kernel calls.
        block_dim: :any:`Dimension`
            The dimension object specifying the block loop dimension
        targets : tuple of str
            Tuple of subroutine call names that are processed in this traversal

        Returns
        -------
        tuple
            ``(upper, step, block_size, blockdim_var, griddim_var,
            blockdim_assignment, griddim_assignment)``. ``upper`` and ``step``
            stem from the last processed target call and are ``None`` if there
            is none. ``block_size`` is ``routine.variable_map[block_dim.size]``.

        Raises
        ------
        RuntimeError
            If a target call has no ``removed_loop`` pragma providing ``upper``.
        """

        d_type = SymbolAttributes(DerivedType("dim3"))
        blockdim_var = sym.Variable(name="BLOCKDIM", type=d_type)
        griddim_var = sym.Variable(name="GRIDDIM", type=d_type)
        if self.mode == 'cuf':
            routine.spec.append(ir.VariableDeclaration(symbols=(griddim_var, blockdim_var)))

        # istat: status of CUDA runtime function (e.g. for cudaDeviceSynchronize(), cudaMalloc(), cudaFree(), ...)
        i_type = SymbolAttributes(BasicType.INTEGER)
        routine.spec.append(ir.VariableDeclaration(symbols=(sym.Variable(name="istat", type=i_type),)))

        # Initialised so that a routine without any target call does not leave
        # ``upper``/``step`` unbound at the ``return`` below
        upper = None
        step = None
        blockdim_assignment = None
        griddim_assignment = None
        mapper = {}

        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if call.name not in as_tuple(targets):
                continue

            # Use an empty mapping (not a tuple) if there is no pragma, so that a
            # missing loop bound is reported explicitly instead of as a TypeError
            if call.pragma:
                parameters = get_pragma_parameters(call.pragma, starts_with='removed_loop')
            else:
                parameters = {}
            if 'upper' not in parameters:
                raise RuntimeError(
                    '[Loki] SCC low-level launch configuration: no "removed_loop" pragma with an '
                    f'"upper" bound found for call to {call.name}'
                )

            assignment_lhs = routine.variable_map["istat"]
            assignment_rhs = sym.InlineCall(
                function=sym.ProcedureSymbol(name="cudaDeviceSynchronize", scope=routine),
                parameters=())

            upper = routine.variable_map[parameters['upper']]

            # Fall back to a unit step if no step is given or it does not name a
            # variable of the routine
            step_name = parameters.get('step')
            if isinstance(step_name, str) and step_name in routine.variable_map:
                step = routine.variable_map[step_name]
            else:
                step = sym.IntLiteral(1)

            if self.mode == 'cuf':
                func_dim3 = sym.ProcedureSymbol(name="DIM3", scope=routine)
                func_ceiling = sym.ProcedureSymbol(name="CEILING", scope=routine)

                # BLOCKDIM
                lhs = routine.variable_map["blockdim"]
                rhs = sym.InlineCall(function=func_dim3, parameters=(step, sym.IntLiteral(1), sym.IntLiteral(1)))
                blockdim_assignment = ir.Assignment(lhs=lhs, rhs=rhs)

                # GRIDDIM
                lhs = routine.variable_map["griddim"]
                rhs = sym.InlineCall(function=func_dim3, parameters=(sym.IntLiteral(1), sym.IntLiteral(1),
                                                                    sym.InlineCall(function=func_ceiling,
                                                                                    parameters=as_tuple(
                                                                                        sym.Cast(name="REAL",
                                                                                                expression=upper) /
                                                                                        sym.Cast(name="REAL",
                                                                                                expression=step)))))
                griddim_assignment = ir.Assignment(lhs=lhs, rhs=rhs)
                mapper[call] = (blockdim_assignment, griddim_assignment, ir.Comment(""),
                        call.clone(chevron=(routine.variable_map["GRIDDIM"], routine.variable_map["BLOCKDIM"]),),
                        ir.Assignment(lhs=assignment_lhs, rhs=assignment_rhs))
            else:
                func_dim3 = sym.ProcedureSymbol(name="dim3", scope=routine)
                func_ceiling = sym.ProcedureSymbol(name="ceil", scope=routine)

                # BLOCKDIM
                lhs = blockdim_var
                rhs = sym.InlineCall(function=func_dim3, parameters=(step, sym.IntLiteral(1), sym.IntLiteral(1)))
                blockdim_assignment = ir.Assignment(lhs=lhs, rhs=rhs)
                # GRIDDIM
                lhs = griddim_var
                rhs = sym.InlineCall(function=func_dim3, parameters=(sym.InlineCall(function=func_ceiling,
                    parameters=as_tuple(
                        sym.Cast(name="REAL", expression=upper) /
                        sym.Cast(name="REAL", expression=step))),
                    sym.IntLiteral(1), sym.IntLiteral(1)))
                griddim_assignment = ir.Assignment(lhs=lhs, rhs=rhs)

        routine.body = Transformer(mapper=mapper).visit(routine.body)
        return upper, step, routine.variable_map[block_dim.size], blockdim_var, griddim_var,\
                blockdim_assignment, griddim_assignment


class SccLowLevelDataOffload(Transformation):
    """
    Part of the pipeline for generating Single Column Coalesced
    Low Level GPU (CUDA Fortran, CUDA C, HIP, ...) for block-indexed gridpoint/single-column
    routines (responsible for the data offload).
    """

    def __init__(self, horizontal, vertical, block_dim, transformation_type='parametrise',
                 derived_types=None, mode="CUF"):
        """
        Part of the pipeline for generating Single Column Coalesced
        Low Level GPU (CUDA Fortran, CUDA C, HIP, ...) for block-indexed gridpoint/single-column
        routines responsible for the data offload.

        .. note::
            Depending on the transformation type ``transformation_type``, further
            transformations are necessary:

            * ``transformation_type = 'parametrise'`` requires a subsequent
              :any:`ParametriseTransformation` transformation with the necessary information
              to parametrise (at least) the ``vertical`` `size`
            * ``transformation_type = 'hoist'`` requires subsequent :any:`HoistVariablesAnalysis`
              and :class:`HoistVariablesTransformation` transformations (e.g.
              :any:`HoistTemporaryArraysAnalysis` for analysis and
              :any:`HoistTemporaryArraysTransformationDeviceAllocatable` or
              :any:`HoistTemporaryArraysPragmaOffloadTransformation` for synthesis)

        Parameters
        ----------
        horizontal : :any:`Dimension`
            :any:`Dimension` object describing the variable conventions used in code
            to define the horizontal data dimension and iteration space.
        vertical : :any:`Dimension`
            :any:`Dimension` object describing the variable conventions used in code
            to define the vertical dimension, as needed to decide array privatization.
        block_dim : :any:`Dimension`
            :any:`Dimension` object to define the blocking dimension
            to use for hoisted column arrays if hoisting is enabled.
        derived_types: tuple
            Names of derived types whose variables get device copies in the driver
            (matched upper-cased against the variable's type)
        transformation_type : str
            Kind of transformation/Handling of temporaries/local arrays

            - `parametrise`: parametrising the array dimensions to make the vertical dimension
              a compile-time constant
            - `hoist`: host side hoisting of (relevant) arrays
        mode: str
            Mode/language to target (case-insensitive)

            - `CUF` - CUDA Fortran
            - `CUDA` - CUDA C
            - `HIP` - HIP
        """
        self.horizontal = horizontal
        self.vertical = vertical
        self.block_dim = block_dim
        self.mode = mode.lower()
        assert self.mode in ['cuf', 'cuda', 'hip']

        self.transformation_type = transformation_type
        # `parametrise` : parametrising the array dimensions
        # `hoist`: host side hoisting
        assert self.transformation_type in ['parametrise', 'hoist']
        self.transformation_description = {'parametrise': 'parametrised array dimensions of local arrays',
                                           'hoist': 'host side hoisted local arrays'}

        if derived_types is None:
            self.derived_types = ()
        else:
            self.derived_types = [_.upper() for _ in derived_types]
        self.derived_type_variables = ()

    @staticmethod
    def _intent_of(var):
        """
        Return the lower-cased ``intent`` of :data:`var`, or ``''`` if it has none.

        Arguments declared without an explicit intent carry ``intent=None``;
        calling ``.lower()`` on that directly would raise an ``AttributeError``.
        """
        intent = var.type.intent
        return intent.lower() if intent else ''

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the data offload changes to :data:`routine`.

        Non-Loki pragmas are removed and ``single_variable_declaration`` is applied
        (grouped by shape). In ``'cuf'`` mode an import of ``cudafor`` is
        prepended to the spec. Driver and kernel routines are processed
        according to ``role``. Finally, keyword arguments of calls to ``targets``
        are converted to positional arguments.
        """
        role = kwargs.get('role')
        targets = kwargs.get('targets', None)

        remove_non_loki_pragmas(routine)
        single_variable_declaration(routine=routine, group_by_shape=True)

        if self.mode == 'cuf':
            routine.spec.prepend(ir.Import(module="cudafor"))

        if role == 'driver':
            self.process_driver(routine, targets=targets)
        if role == 'kernel':
            self.process_kernel(routine)

        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if str(call.name).lower() in as_tuple(targets):
                call.convert_kwargs_to_args()

    def process_driver(self, routine, targets=None):
        """
        Driver subroutine specific changes/transformations.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine (driver) to process
        targets : tuple of str
            Tuple of subroutine call names that are processed in this traversal
        """

        # The returned variables are remembered for the subsequent kernel processing
        self.derived_type_variables = self.device_derived_types(
            routine=routine, derived_types=self.derived_types, targets=targets
        )
        # create variables needed for the device execution, especially generate device versions of arrays
        self.driver_device_variables(routine=routine, targets=targets)

    def process_kernel(self, routine):
        """
        Kernel/Device subroutine specific changes/transformations.

        Associates are resolved, vector notation is resolved, loops over the
        horizontal dimension are removed, and then :meth:`kernel_cuf` is applied.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine (kernel/device subroutine) to process
        """

        do_resolve_associates(routine)
        resolve_vector_notation(routine)
        routine.body = RemoveLoopTransformer(dimension=self.horizontal).visit(routine.body)

        self.kernel_cuf(
            routine, self.horizontal, self.block_dim, self.transformation_type,
            derived_type_variables=self.derived_type_variables
        )

    def kernel_cuf(self, routine, horizontal, block_dim, transformation_type,
               derived_type_variables):
        """
        Adapt kernel declarations and variable uses for low-level GPU execution.

        * Scalar ``intent(in)`` arguments not contained in
          :data:`derived_type_variables` are declared with ``value``.
        * Local arrays whose declared dimensions contain ``horizontal.size`` are
          either extended by the block size (``transformation_type == 'hoist'``)
          or have ``horizontal.size`` removed and are marked ``device``
          (otherwise). Uses of these arrays in the body either get the block
          index appended or lose their first index, respectively.

          NOTE(review): dropping the first index assumes the horizontal dimension
          is the leading one — confirm for all kernels.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The kernel to process
        horizontal: :any:`Dimension`
            The horizontal dimension
        block_dim: :any:`Dimension`
            The blocking dimension (size used for hoisted declarations, index
            for hoisted uses)
        transformation_type: str
            ``'parametrise'`` or ``'hoist'``
        derived_type_variables: iterable
            Variables excluded from the ``value`` attribute treatment
        """

        relevant_local_arrays = []
        var_map = {}
        for var in routine.variables:
            if var in routine.arguments:
                if isinstance(var, sym.Scalar) and var not in derived_type_variables\
                        and self._intent_of(var) == 'in':
                    var_map[var] = var.clone(type=var.type.clone(value=True))
            else:
                if isinstance(var, sym.Array):
                    dimensions = list(var.dimensions)
                    shape = list(var.shape)
                    if horizontal.size in list(FindVariables().visit(var.dimensions)):
                        if transformation_type == 'hoist':
                            dimensions += [routine.variable_map[block_dim.size]]
                            shape = list(var.shape) + [routine.variable_map[block_dim.size]]
                            vtype = var.type.clone(shape=as_tuple(shape))
                            relevant_local_arrays.append(var.name)
                        else:
                            dimensions.remove(horizontal.size)
                            shape.remove(horizontal.size)
                            relevant_local_arrays.append(var.name)
                            vtype = var.type.clone(device=True, shape=shape)
                        var_map[var] = var.clone(dimensions=as_tuple(dimensions), type=vtype)

        routine.spec = SubstituteExpressions(var_map).visit(routine.spec)

        var_map = {}
        arguments_name = [var.name for var in routine.arguments]
        for var in FindVariables().visit(routine.body):
            if var.name not in arguments_name:
                if transformation_type == 'hoist':
                    if var.name in relevant_local_arrays:
                        var_map[var] = var.clone(dimensions=var.dimensions + (routine.variable_map[block_dim.index],))
                else:
                    if var.name in relevant_local_arrays:
                        dimensions = list(var.dimensions)
                        var_map[var] = var.clone(dimensions=as_tuple(dimensions[1:]))

        routine.body = SubstituteExpressions(var_map).visit(routine.body)

    def device_derived_types(self, routine, derived_types, targets=None):
        """
        Create device versions of variables of specific derived types including
        host-device-synchronisation.

        For each matching variable ``var`` a declaration of ``var_d`` (marked
        ``device``) is appended to the spec and ``var_d = var`` is prepended to
        the body. Positional arguments of calls to ``targets`` are remapped to
        the device versions.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine to create device versions of the specified derived type variables.
        derived_types: tuple
            Tuple of derived types within the routine
        targets : tuple of str
            Tuple of subroutine call names that are processed in this traversal

        Returns
        -------
        list
            The (host) variables for which device versions were created
        """
        # Each variable is selected at most once, even if several of the given
        # type names occur in its type; otherwise it would be declared twice.
        variables = [
            var for var in FindVariables().visit(routine.ir)
            if any(derived_type in str(var.type) for derived_type in derived_types)
        ]

        var_map = {}
        for var in variables:
            new_var = var.clone(name=f"{var.name}_d", type=var.type.clone(intent=None, imported=None,
                                                                          allocatable=None, device=True,
                                                                          module=None))
            var_map[var] = new_var
            routine.spec.append(ir.VariableDeclaration((new_var,)))
            routine.body.prepend(ir.Assignment(lhs=new_var, rhs=var))

        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if call.name not in as_tuple(targets):
                continue
            arguments = tuple(var_map.get(arg, arg) for arg in call.arguments)
            call._update(arguments=arguments)
        return variables

    def driver_device_variables(self, routine, targets=None):
        """
        Driver device variable versions including
        * variable declaration
        * allocation
        * host-device synchronisation
        * de-allocation

        In ``'cuda'``/``'hip'`` mode, ``!$loki data``/``end data`` pragmas are
        instead replaced by ``structured-data`` in/inout/out pragmas derived from
        the argument intents of the called routines.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine (driver) to handle the device variables
        targets : tuple of str
            Tuple of subroutine call names that are processed in this traversal
        """

        relevant_arrays = []
        calls = tuple(
            call for call in FindNodes(ir.CallStatement).visit(routine.body)
            if call.name in as_tuple(targets)
        )
        for call in calls:
            relevant_arrays.extend([arg for arg in call.arguments if isinstance(arg, sym.Array)])

        relevant_arrays = list(dict.fromkeys(relevant_arrays))

        if self.mode in ['cuda', 'hip']:
            # Collect the three types of device data accesses from calls
            inargs = ()
            inoutargs = ()
            outargs = ()

            for call in calls:
                if call.routine is BasicType.DEFERRED:
                    # Without the (enriched) callee definition the argument intents are unknown
                    continue
                for param, arg in call.arg_iter():
                    if not isinstance(param, sym.Array):
                        continue
                    intent = self._intent_of(param)
                    if intent == 'in':
                        inargs += (str(arg.name).lower(),)
                    elif intent == 'inout':
                        inoutargs += (str(arg.name).lower(),)
                    elif intent == 'out':
                        outargs += (str(arg.name).lower(),)

            # Sanitize data access categories to avoid double-counting variables
            inoutargs += tuple(v for v in inargs if v in outargs)
            inargs = tuple(v for v in inargs if v not in inoutargs)
            outargs = tuple(v for v in outargs if v not in inoutargs)

            # Filter for duplicates
            inargs = tuple(dict.fromkeys(inargs))
            outargs = tuple(dict.fromkeys(outargs))
            inoutargs = tuple(dict.fromkeys(inoutargs))

            copy_pragmas = []
            copy_end_pragmas = []
            if outargs:
                copy_pragmas += [ir.Pragma(keyword='loki', content=f'structured-data out({", ".join(outargs)})')]
                copy_end_pragmas += [ir.Pragma(keyword='loki', content='end structured-data')]
            if inoutargs:
                copy_pragmas += [ir.Pragma(keyword='loki', content=f'structured-data inout({", ".join(inoutargs)})')]
                copy_end_pragmas += [ir.Pragma(keyword='loki', content='end structured-data')]
            if inargs:
                copy_pragmas += [ir.Pragma(keyword='loki', content=f'structured-data in({", ".join(inargs)})')]
                copy_end_pragmas += [ir.Pragma(keyword='loki', content='end structured-data')]

            if copy_pragmas:
                pragma_map = {}
                for pragma in FindNodes(ir.Pragma).visit(routine.body):
                    if pragma.content == 'data' and 'loki' == pragma.keyword:
                        pragma_map[pragma] = as_tuple(copy_pragmas)
                if pragma_map:
                    routine.body = Transformer(pragma_map).visit(routine.body)
            if copy_end_pragmas:
                pragma_map = {}
                for pragma in FindNodes(ir.Pragma).visit(routine.body):
                    if pragma.content == 'end data' and 'loki' == pragma.keyword:
                        pragma_map[pragma] = as_tuple(copy_end_pragmas)
                if pragma_map:
                    routine.body = Transformer(pragma_map).visit(routine.body)
        else:
            # Declaration
            routine.spec.append(ir.Comment(''))
            routine.spec.append(ir.Comment('! Device arrays'))
            for array in relevant_arrays:
                vtype = array.type.clone(device=True, allocatable=True, intent=None, shape=None)
                vdimensions = [sym.RangeIndex((None, None))] * len(array.shape)
                var = array.clone(name=f"{array.name}_d", type=vtype, dimensions=as_tuple(vdimensions))
                routine.spec.append(ir.VariableDeclaration(symbols=as_tuple(var)))

            # Allocation
            for array in reversed(relevant_arrays):
                vtype = array.type.clone(device=True, allocatable=True, intent=None, shape=None)
                routine.body.prepend(ir.Allocation((array.clone(name=f"{array.name}_d", type=vtype,
                                                            dimensions=routine.variable_map[array.name].dimensions),)))
            routine.body.prepend(ir.Comment('! Device array allocation'))
            routine.body.prepend(ir.Comment(''))

            # Copy host to device, right after the last *top-level* allocation.
            # The index refers to ``routine.body.body``, so nested allocations
            # (e.g. inside a conditional) cannot serve as an anchor.
            insert_index = None
            for index, node in enumerate(routine.body.body):
                if isinstance(node, ir.Allocation):
                    insert_index = index + 1
            for array in reversed(relevant_arrays):
                vtype = array.type.clone(device=True, allocatable=True, intent=None, shape=None)
                lhs = array.clone(name=f"{array.name}_d", type=vtype, dimensions=())
                rhs = array.clone(dimensions=())
                if insert_index is not None:
                    routine.body.insert(insert_index, ir.Assignment(lhs=lhs, rhs=rhs))
                else:
                    routine.body.prepend(ir.Assignment(lhs=lhs, rhs=rhs))
            # ``insert`` needs an integer position; without an anchor the header
            # comments are prepended, just like the copy assignments above.
            if insert_index is not None:
                routine.body.insert(insert_index, ir.Comment('! Copy host to device'))
                routine.body.insert(insert_index, ir.Comment(''))
            else:
                routine.body.prepend(ir.Comment('! Copy host to device'))
                routine.body.prepend(ir.Comment(''))

            # TODO: this just assumes that host-device-synchronisation is only needed at the beginning and end
            # Copy device to host
            insert_index = None
            for call in FindNodes(ir.CallStatement).visit(routine.body):
                if "THREAD_END" in str(call.name):  # TODO: fix/check: very specific to CLOUDSC
                    insert_index = routine.body.body.index(call) + 1

            if insert_index is None:
                routine.body.append(ir.Comment(''))
                routine.body.append(ir.Comment('! Copy device to host'))
            for v in reversed(relevant_arrays):
                if v.type.intent != "in":
                    lhs = v.clone(dimensions=())
                    vtype = v.type.clone(device=True, allocatable=True, intent=None, shape=None)
                    rhs = v.clone(name=f"{v.name}_d", type=vtype, dimensions=())
                    if insert_index is None:
                        routine.body.append(ir.Assignment(lhs=lhs, rhs=rhs))
                    else:
                        routine.body.insert(insert_index, ir.Assignment(lhs=lhs, rhs=rhs))
            if insert_index is not None:
                routine.body.insert(insert_index, ir.Comment('! Copy device to host'))

            # De-allocation
            routine.body.append(ir.Comment(''))
            routine.body.append(ir.Comment('! De-allocation'))
            for array in relevant_arrays:
                routine.body.append(ir.Deallocation((array.clone(name=f"{array.name}_d", dimensions=()),)))

            # Pass the device arrays to the target calls
            call_map = {}
            for call in calls:
                arguments = []
                for arg in call.arguments:
                    if arg in relevant_arrays:
                        vtype = arg.type.clone(device=True, allocatable=True, intent=None)
                        arguments.append(arg.clone(name=f"{arg.name}_d", type=vtype, dimensions=()))
                    else:
                        arguments.append(arg)
                call_map[call] = call.clone(arguments=as_tuple(arguments))
            routine.body = Transformer(call_map).visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/single_column/hoist.py0000664000175000017500000001214715167130205024473 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.expression import symbols as sym
from loki.ir import nodes as ir

from loki.transformations.temporaries.hoist_variables import HoistVariablesTransformation
from loki.transformations.temporaries.stack_allocator import BaseStackTransformation
from loki.transformations.utilities import get_integer_variable, substitute_variables_for_definitions


__all__ = ['SCCHoistTemporaryArraysTransformation']


class SCCHoistTemporaryArraysTransformation(HoistVariablesTransformation):
    """
    **Specialisation** for the *Synthesis* part of the hoist variables
    transformation that declares hoisted temporaries in the driver with an
    additional blocking dimension.

    Besides the declarations, ``!$loki unstructured-data`` pragmas are added
    to the driver to request device-side creation and deletion of the hoisted
    temporaries (presumably lowered to OpenACC data directives later on).

    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object to define the blocking dimension
        to use for hoisted array arguments on the driver side.
    """

    def __init__(self, block_dim=None, **kwargs):
        self.block_dim = block_dim
        super().__init__(**kwargs)

    def _check_block_dim(self):
        """
        Raise a :any:`RuntimeError` if no blocking dimension was configured,
        since both driver-side synthesis steps depend on it.
        """
        if self.block_dim:
            return
        raise RuntimeError(
            '[Loki] SingleColumnCoalescedTransform: No blocking dimension found '
            'for array argument hoisting.'
        )

    def driver_variable_declaration(self, routine, variables):
        """
        Declare the hoisted arrays in the driver with the block size appended
        as trailing dimension. A ``create`` pragma is placed at a Loki pragma
        location if one is accepted by the stack insertion helper, or prepended
        otherwise, and a matching ``delete ... finalize`` pragma is appended.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The driver routine to add the declarations to.
        variables : tuple of :any:`Variable`
            The arrays to be declared, created and deleted.

        Raises
        ------
        RuntimeError
            If no blocking dimension was given.
        """
        self._check_block_dim()

        nblocks = get_integer_variable(routine, self.block_dim.size)
        nblocks = substitute_variables_for_definitions(routine, variables=nblocks)[0]

        blocked_vars = []
        for var in variables:
            blocked_dims = var.dimensions + (nblocks,)
            blocked_type = var.type.clone(shape=var.shape + (nblocks,))
            blocked_vars.append(var.clone(dimensions=blocked_dims, type=blocked_type))
        routine.variables += tuple(blocked_vars)

        # Explicit device-side allocation/deallocation of the hoisted temporaries
        names = ', '.join(var.name for var in variables)
        if not names:
            return

        create_pragma = ir.Pragma(keyword='loki', content=f'unstructured-data create({names})')
        # ``finalize`` sets the dynamic reference counter to zero instead of just
        # decrementing it. This shouldn't be needed, and likely points to an
        # OpenACC runtime bug.
        delete_pragma = ir.Pragma(keyword='loki', content=f'exit unstructured-data delete({names}) finalize')

        # Empty comments around the standalone pragmas avoid false attachment
        if not BaseStackTransformation._insert_stack_at_loki_pragma(routine, create_pragma):
            routine.body.prepend((ir.Comment(''), create_pragma, ir.Comment('')))
        routine.body.append((ir.Comment(''), delete_pragma, ir.Comment('')))

    def driver_call_argument_remapping(self, routine, call, variables):
        """
        Return a clone of :data:`call` with the hoisted arrays appended as
        arguments.

        Each hoisted array is passed as a slice with a full range ``:`` in every
        existing dimension and the block index as trailing index. With
        ``as_kwarguments`` the slices are appended as keyword arguments named
        after the kernel-side argument, otherwise as positional arguments.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The driver routine containing :data:`call`.
        call : :any:`CallStatement`
            Call object to which hoisted arrays will be added.
        variables : tuple
            With ``as_kwarguments``: tuple of ``(kernel_argument, driver_variable)``
            pairs; otherwise: tuple of driver-side :any:`Variable`.

        Raises
        ------
        RuntimeError
            If no blocking dimension was given.
        """
        self._check_block_dim()
        block_idx = get_integer_variable(routine, self.block_dim.index)

        def _block_slice(var):
            full_ranges = tuple(sym.RangeIndex((None, None)) for _ in var.dimensions)
            return var.clone(dimensions=full_ranges + (block_idx,))

        if not self.as_kwarguments:
            extra_args = tuple(_block_slice(var) for var in variables)
            return call.clone(arguments=call.arguments + extra_args)

        extra_kwargs = tuple((kernel_arg.name, _block_slice(var)) for kernel_arg, var in variables)
        existing_kwargs = () if call.kwarguments is None else call.kwarguments
        return call.clone(kwarguments=existing_kwargs + extra_kwargs)
loki-ecmwf-0.3.6/loki/transformations/single_column/scc_low_level.py0000664000175000017500000004233515167130205026167 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from functools import partial

from loki.batch import Pipeline, Transformation
from loki.transformations.temporaries.hoist_variables import HoistTemporaryArraysAnalysis
from loki.transformations.single_column.base import SCCBaseTransformation
from loki.transformations.single_column.demote import SCCDemoteTransformation
from loki.transformations.single_column.devector import SCCDevectorTransformation
from loki.transformations.single_column.revector import SCCRevectorTransformation
from loki.transformations.single_column.scc_cuf import (
    HoistTemporaryArraysDeviceAllocatableTransformation,
    HoistTemporaryArraysPragmaOffloadTransformation,
    SccLowLevelDataOffload, SccLowLevelLaunchConfiguration
)
from loki.transformations.block_index_transformations import (
        InjectBlockIndexTransformation,
        LowerBlockIndexTransformation, LowerBlockLoopTransformation
)
from loki.transformations.transform_derived_types import DerivedTypeArgumentsTransformation
from loki.transformations.data_offload import (
    GlobalVariableAnalysis, GlobalVarHoistTransformation
)
from loki.transformations.parametrise import ParametriseTransformation
from loki.transformations.inline import (
    inline_constant_parameters, inline_elemental_functions
)

# Public pipelines exported by this module
__all__ = [
    'SCCLowLevelCufHoist',
    'SCCLowLevelCufParametrise',
    'SCCLowLevelHoist',
    'SCCLowLevelParametrise',
    'SCCLowLevelCuf',
]

def inline_elemental_kernel(routine, **kwargs):
    """
    Inline constant parameters and elemental functions into a kernel.

    Routines whose role is not ``'kernel'`` are left untouched.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine to process (modified in place by the inlining utilities).
    role : str
        Role of ``routine`` in the call tree, passed via ``kwargs``
        (required; a missing ``role`` raises :class:`KeyError`).
    """
    if kwargs['role'] != 'kernel':
        return

    # Restrict parameter inlining to external constants (``external_only=True``)
    inline_constant_parameters(routine, external_only=True)
    inline_elemental_functions(routine)


class InlineTransformation(Transformation):
    """
    :any:`Transformation` that inlines constant parameters and elemental
    functions in kernel routines.

    For routines with role ``'kernel'`` this applies
    :any:`inline_constant_parameters` (with ``external_only=True``) followed
    by :any:`inline_elemental_functions`; routines with any other role are
    left unchanged.
    """

    def transform_subroutine(self, routine, **kwargs):
        """
        Inline constant parameters and elemental functions in ``routine``
        if it is a kernel.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        role : str
            Role of the subroutine in the call tree (passed via ``kwargs``);
            only ``'kernel'`` routines are modified.
        """
        # Delegate to the module-level helper so that the function-style and
        # the transformation-style entry points share one implementation.
        inline_elemental_kernel(routine, **kwargs)


SCCLowLevelCuf = partial(
    Pipeline, classes=(
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCRevectorTransformation,
        LowerBlockIndexTransformation,
        InjectBlockIndexTransformation,
        LowerBlockLoopTransformation,
        SccLowLevelLaunchConfiguration,
        SccLowLevelDataOffload,
    )
)
"""
The basic Single Column Coalesced low-level GPU via CUDA-Fortran (SCC-CUF).

This transformation will convert kernels with innermost vectorisation
along a common horizontal dimension to a GPU-friendly loop-layout via
loop inversion and local array variable demotion. The resulting kernel
remains "vector-parallel", but with the ``horizontal`` loop as the
outermost iteration dimension (as far as data dependencies
allow). This allows local temporary arrays to be demoted to scalars,
where possible.

Kernels are specified via ``'GLOBAL'`` and the number of threads that
execute the kernel for a given call is specified via the chevron syntax.

This :any:`Pipeline` applies the following :any:`Transformation`
classes in sequence:

1. :any:`SCCBaseTransformation` - Ensure utility variables and resolve
   problematic code constructs.
2. :any:`SCCDevectorTransformation` - Remove horizontal vector loops.
3. :any:`SCCDemoteTransformation` - Demote local temporary array
   variables where appropriate.
4. :any:`SCCRevectorTransformation` - Re-insert the vector loops outermost,
   according to identified vector sections.
5. :any:`LowerBlockIndexTransformation` - Lower the block index (for
   array argument definitions).
6. :any:`InjectBlockIndexTransformation` - Complete the previous step
   and inject the block index for the relevant arrays.
7. :any:`LowerBlockLoopTransformation` - Lower the block loop
   from driver to kernel(s).
8. :any:`SccLowLevelLaunchConfiguration` - Create launch configuration
   and related things.
9. :any:`SccLowLevelDataOffload` - Create/handle data offload
   and related things.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
derived_types : tuple
    List of relevant derived types
transformation_type : str
    Kind of transformation/Handling of temporaries/local arrays

    - `parametrise`: parametrising the array dimensions to make the vertical dimension
      a compile-time constant
    - `hoist`: host side hoisting of (relevant) arrays
mode : str
    Mode/language to target

    - `CUF` - CUDA Fortran
    - `CUDA` - CUDA C
    - `HIP` - HIP
"""

SCCLowLevelCufParametrise = partial(
    Pipeline, classes=(
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCRevectorTransformation,
        LowerBlockIndexTransformation,
        InjectBlockIndexTransformation,
        LowerBlockLoopTransformation,
        SccLowLevelLaunchConfiguration,
        SccLowLevelDataOffload,
        ParametriseTransformation
    )
)
"""
The Single Column Coalesced low-level GPU via CUDA-Fortran (SCC-CUF)
handling temporaries via parametrisation.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCLowLevelCuf`.

In addition, this pipeline will invoke
:any:`ParametriseTransformation` as its final step to parametrise
relevant array dimensions to allow having temporary arrays.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
derived_types : tuple
    List of relevant derived types
transformation_type : str
    Kind of transformation/Handling of temporaries/local arrays

    - `parametrise`: parametrising the array dimensions to make the vertical dimension
      a compile-time constant
    - `hoist`: host side hoisting of (relevant) arrays
mode : str
    Mode/language to target

    - `CUF` - CUDA Fortran
    - `CUDA` - CUDA C
    - `HIP` - HIP
dic2p : dict
    Dictionary of variable names and corresponding values to be parametrised.
"""

SCCLowLevelCufHoist = partial(
    Pipeline, classes=(
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCRevectorTransformation,
        LowerBlockIndexTransformation,
        InjectBlockIndexTransformation,
        LowerBlockLoopTransformation,
        SccLowLevelLaunchConfiguration,
        SccLowLevelDataOffload,
        HoistTemporaryArraysAnalysis,
        HoistTemporaryArraysDeviceAllocatableTransformation
    )
)
"""
The Single Column Coalesced low-level GPU via CUDA-Fortran (SCC-CUF)
handling temporaries via hoisting.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCLowLevelCuf`.

In addition, this pipeline will invoke
:any:`HoistTemporaryArraysAnalysis` and
:any:`HoistTemporaryArraysDeviceAllocatableTransformation`
as its final steps to hoist temporary arrays.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
derived_types : tuple
    List of relevant derived types
transformation_type : str
    Kind of transformation/Handling of temporaries/local arrays

    - `parametrise`: parametrising the array dimensions to make the vertical dimension
      a compile-time constant
    - `hoist`: host side hoisting of (relevant) arrays
mode : str
    Mode/language to target

    - `CUF` - CUDA Fortran
    - `CUDA` - CUDA C
    - `HIP` - HIP
"""

SCCLowLevelParametrise = partial(
    Pipeline, classes=(
        InlineTransformation,
        GlobalVariableAnalysis,
        GlobalVarHoistTransformation,
        DerivedTypeArgumentsTransformation,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCRevectorTransformation,
        LowerBlockIndexTransformation,
        InjectBlockIndexTransformation,
        LowerBlockLoopTransformation,
        SccLowLevelLaunchConfiguration,
        SccLowLevelDataOffload,
        ParametriseTransformation
    )
)
"""
The Single Column Coalesced low-level GPU via low-level C-style
kernel language (CUDA, HIP, ...) handling temporaries via parametrisation.

This transformation will convert kernels with innermost vectorisation
along a common horizontal dimension to a GPU-friendly loop-layout via
loop inversion and local array variable demotion. The resulting kernel
remains "vector-parallel", but with the ``horizontal`` loop as the
outermost iteration dimension (as far as data dependencies
allow). This allows local temporary arrays to be demoted to scalars,
where possible.

Kernels are specified via e.g., ``'__global__'`` and the number of threads that
execute the kernel for a given call is specified via the chevron syntax.

This :any:`Pipeline` applies the following :any:`Transformation`
classes in sequence:

1. :any:`InlineTransformation` - Inline constants and elemental
   functions.
2. :any:`GlobalVariableAnalysis` - Analysis of global variables
3. :any:`GlobalVarHoistTransformation` - Hoist global variables
   to the driver.
4. :any:`DerivedTypeArgumentsTransformation` - Flatten derived types/
   remove derived types from procedure signatures by replacing the
   (relevant) derived type arguments by their member variables.
5. :any:`SCCBaseTransformation` - Ensure utility variables and resolve
   problematic code constructs.
6. :any:`SCCDevectorTransformation` - Remove horizontal vector loops.
7. :any:`SCCDemoteTransformation` - Demote local temporary array
   variables where appropriate.
8. :any:`SCCRevectorTransformation` - Re-insert the vector loops outermost,
   according to identified vector sections.
9. :any:`LowerBlockIndexTransformation` - Lower the block index (for
   array argument definitions).
10. :any:`InjectBlockIndexTransformation` - Complete the previous step
    and inject the block index for the relevant arrays.
11. :any:`LowerBlockLoopTransformation` - Lower the block loop
    from driver to kernel(s).
12. :any:`SccLowLevelLaunchConfiguration` - Create launch configuration
    and related things.
13. :any:`SccLowLevelDataOffload` - Create/handle data offload
    and related things.
14. :any:`ParametriseTransformation` - Parametrise according to ``dic2p``.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
derived_types : tuple
    List of relevant derived types
transformation_type : str
    Kind of transformation/Handling of temporaries/local arrays

    - `parametrise`: parametrising the array dimensions to make the vertical dimension
      a compile-time constant
    - `hoist`: host side hoisting of (relevant) arrays
mode : str
    Mode/language to target

    - `CUF` - CUDA Fortran
    - `CUDA` - CUDA C
    - `HIP` - HIP
dic2p : dict
    Dictionary of variable names and corresponding values to be parametrised.
"""

SCCLowLevelHoist = partial(
    Pipeline, classes=(
        InlineTransformation,
        GlobalVariableAnalysis,
        GlobalVarHoistTransformation,
        DerivedTypeArgumentsTransformation,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCRevectorTransformation,
        LowerBlockIndexTransformation,
        InjectBlockIndexTransformation,
        LowerBlockLoopTransformation,
        SccLowLevelLaunchConfiguration,
        SccLowLevelDataOffload,
        HoistTemporaryArraysAnalysis,
        HoistTemporaryArraysPragmaOffloadTransformation
    )
)
"""
The Single Column Coalesced low-level GPU via low-level C-style
kernel language (CUDA, HIP, ...) handling temporaries via hoisting.

This transformation will convert kernels with innermost vectorisation
along a common horizontal dimension to a GPU-friendly loop-layout via
loop inversion and local array variable demotion. The resulting kernel
remains "vector-parallel", but with the ``horizontal`` loop as the
outermost iteration dimension (as far as data dependencies
allow). This allows local temporary arrays to be demoted to scalars,
where possible.

Kernels are specified via e.g., ``'__global__'`` and the number of threads that
execute the kernel for a given call is specified via the chevron syntax.

This :any:`Pipeline` applies the following :any:`Transformation`
classes in sequence:

1. :any:`InlineTransformation` - Inline constants and elemental
   functions.
2. :any:`GlobalVariableAnalysis` - Analysis of global variables
3. :any:`GlobalVarHoistTransformation` - Hoist global variables
   to the driver.
4. :any:`DerivedTypeArgumentsTransformation` - Flatten derived types/
   remove derived types from procedure signatures by replacing the
   (relevant) derived type arguments by their member variables.
5. :any:`SCCBaseTransformation` - Ensure utility variables and resolve
   problematic code constructs.
6. :any:`SCCDevectorTransformation` - Remove horizontal vector loops.
7. :any:`SCCDemoteTransformation` - Demote local temporary array
   variables where appropriate.
8. :any:`SCCRevectorTransformation` - Re-insert the vector loops outermost,
   according to identified vector sections.
9. :any:`LowerBlockIndexTransformation` - Lower the block index (for
   array argument definitions).
10. :any:`InjectBlockIndexTransformation` - Complete the previous step
    and inject the block index for the relevant arrays.
11. :any:`LowerBlockLoopTransformation` - Lower the block loop
    from driver to kernel(s).
12. :any:`SccLowLevelLaunchConfiguration` - Create launch configuration
    and related things.
13. :any:`SccLowLevelDataOffload` - Create/handle data offload
    and related things.
14. :any:`HoistTemporaryArraysAnalysis` - Analysis part of hoisting.
15. :any:`HoistTemporaryArraysPragmaOffloadTransformation` - Synthesis
    part of hoisting.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
derived_types : tuple
    List of relevant derived types
transformation_type : str
    Kind of transformation/Handling of temporaries/local arrays

    - `parametrise`: parametrising the array dimensions to make the vertical dimension
      a compile-time constant
    - `hoist`: host side hoisting of (relevant) arrays
mode : str
    Mode/language to target

    - `CUF` - CUDA Fortran
    - `CUDA` - CUDA C
    - `HIP` - HIP
"""
loki-ecmwf-0.3.6/loki/transformations/single_column/base.py0000664000175000017500000001140015167130205024246 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation

from loki.transformations.array_indexing import resolve_vector_dimension
from loki.transformations.sanitise import do_resolve_associates
from loki.transformations.utilities import (
    check_routine_sequential, rename_variables
)


__all__ = ['SCCBaseTransformation']


class SCCBaseTransformation(Transformation):
    """
    A basic set of utilities used in the SCC transformation. These utilities
    can either be used as a transformation in their own right, or the contained
    class methods can be called directly.

    For ``'kernel'`` routines this (optionally) renames horizontal index
    aliases and resolves associates and vector notation along the
    ``horizontal`` dimension; for ``'driver'`` routines it resolves
    associates only.

    Parameters
    ----------
    horizontal : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the horizontal data dimension and iteration space.
    """

    def __init__(self, horizontal):
        self.horizontal = horizontal
        # Default for index-alias renaming; can be overridden per call via the
        # ``rename_index_aliases`` keyword or the item configuration.
        self.rename_indices = False

    @staticmethod
    def rename_index_aliases(routine, dimension):
        """
        Rename index aliases: map all index aliases ``dimension.indices`` to
        ``dimension.index``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine in which to rename index aliases.
        dimension : :any:`Dimension`
            :any:`Dimension` object whose index aliases (all entries of
            ``dimension.indices`` after the first) are renamed to the
            first/former index ``dimension.index``.
        """
        # Nothing to rename unless there is at least one alias
        if len(dimension.indices) > 1:
            symbol_map = {alias: dimension.index for alias in dimension.indices[1:]}
            rename_variables(routine, symbol_map)

    # TODO: correct "definition" of a pure/elemental routine (take e.g. loki serial into account ...)
    @staticmethod
    def is_elemental(routine):
        """
        Check whether :any:`Subroutine` ``routine`` is an elemental routine.

        Needed to distinguish elemental from non-elemental routines, which
        are transformed in different ways.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to check.

        Returns
        -------
        bool
            ``True`` if any of the routine's prefixes is ``elemental``
            (case-insensitive), ``False`` otherwise.
        """
        return any(prefix.lower() == 'elemental' for prefix in routine.prefix)

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply SCCBase utilities to a :any:`Subroutine`.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        role : string
            Role of the subroutine in the call tree; ``"kernel"`` or
            ``"driver"``. Routines with any other role are left unchanged.
        item : optional
            Scheduler item with a ``config`` mapping; if given, its
            ``rename_index_aliases`` entry takes precedence.
        rename_index_aliases : bool, optional
            Whether to rename horizontal index aliases in kernels; defaults
            to ``self.rename_indices``.
        """
        role = kwargs['role']
        item = kwargs.get('item', None)

        # Precedence: item config > call keyword > instance default
        rename_indices = kwargs.get('rename_index_aliases', self.rename_indices)
        if item:
            rename_indices = item.config.get('rename_index_aliases', rename_indices)

        if role == 'kernel':
            self.process_kernel(routine, rename_indices=rename_indices)
        elif role == 'driver':
            self.process_driver(routine)

    def process_kernel(self, routine, rename_indices=False):
        """
        Applies the SCCBase utilities to a "kernel". This consists of
        optionally renaming horizontal index aliases, resolving associations
        and resolving vector notation along the horizontal dimension.

        Routines marked as sequential and elemental routines are skipped.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        rename_indices : bool, optional
            If ``True``, rename aliases of the horizontal index via
            :meth:`rename_index_aliases` before any other processing.
        """

        # Bail if routine is marked as sequential
        if check_routine_sequential(routine):
            return

        # Bail if routine is elemental
        if self.is_elemental(routine):
            return

        if rename_indices:
            self.rename_index_aliases(routine, dimension=self.horizontal)

        # Associates at the highest level, so they don't interfere
        # with the sections we need to do for detecting subroutine calls
        do_resolve_associates(routine)

        # Resolve vector notation, eg. VARIABLE(KIDIA:KFDIA)
        resolve_vector_dimension(routine, dimension=self.horizontal)

    def process_driver(self, routine):
        """
        Applies the SCCBase utilities to a "driver". This consists simply
        of resolving associations.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        """

        # Resolve associates, since the PGI compiler cannot deal with
        # implicit derived type component offload by calling device
        # routines.
        do_resolve_associates(routine)
loki-ecmwf-0.3.6/loki/transformations/single_column/scc.py0000664000175000017500000007413615167130205024123 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from functools import partial

from loki.batch import Pipeline

from loki.transformations.temporaries import (
        HoistTemporaryArraysAnalysis, TemporariesPoolAllocatorTransformation,
        TemporariesRawStackTransformation,
        FtrPtrStackTransformation, DirectIdxStackTransformation,
        EcstackPoolAllocatorTransformation
)

from loki.transformations.single_column.base import SCCBaseTransformation
from loki.transformations.single_column.annotate import SCCAnnotateTransformation
from loki.transformations.single_column.demote import SCCDemoteTransformation
from loki.transformations.single_column.hoist import SCCHoistTemporaryArraysTransformation
from loki.transformations.single_column.devector import SCCDevectorTransformation
from loki.transformations.single_column.revector import (
    SCCVecRevectorTransformation, SCCSeqRevectorTransformation
)
from loki.transformations.single_column.vertical import SCCFuseVerticalLoops
from loki.transformations.pragma_model import PragmaModelTransformation
from loki.transformations.remove_code import RemoveCodeTransformation

__all__ = [
    'SCCVectorPipeline', 'SCCVVectorPipeline', 'SCCSVectorPipeline',
    'SCCHoistPipeline', 'SCCVHoistPipeline', 'SCCSHoistPipeline',
    'SCCStackPipeline', 'SCCVStackPipeline', 'SCCSStackPipeline',
    'SCCStackFtrPtrPipeline', 'SCCVStackFtrPtrPipeline', 'SCCSStackFtrPtrPipeline',
    'SCCStackDirectIdxPipeline', 'SCCVStackDirectIdxPipeline', 'SCCSStackDirectIdxPipeline',
    'SCCRawStackPipeline', 'SCCVRawStackPipeline', 'SCCSRawStackPipeline',
    'SCCSEcStackPipeline'
]


class RemoveUnusedVarTransformation(RemoveCodeTransformation):
    """
    Specialised :any:`RemoveCodeTransformation`, intended as a temporary
    solution for removing unused temporaries/arrays before a transformation
    that handles temporaries on device (hoist, stack) is applied.

    Only :any:`do_remove_unused_vars` is enabled: marked-region removal is
    switched off and only kernels are processed.

    Parameters
    ----------
    remove_unused_vars : boolean
        Remove unused variables/locals from routines.
    remove_only_arrays : boolean
        Whether to restrict removal to unused arrays, rather than all
        unused variables/locals.
    """
    def __init__(self, remove_unused_vars=False, remove_only_arrays=True, **kwargs): # pylint: disable=unused-argument
        # Any further keyword arguments are deliberately ignored
        options = {
            'remove_unused_vars': remove_unused_vars,
            'remove_only_arrays': remove_only_arrays,
            'remove_marked_regions': False,
            'kernel_only': True,
        }
        super().__init__(**options)


SCCVVectorPipeline = partial(
    Pipeline, classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCVecRevectorTransformation,
        SCCAnnotateTransformation,
        PragmaModelTransformation
    )
)
"""
The basic Single Column Coalesced (SCC) transformation with
vector-level kernel parallelism.

This transformation will convert kernels with innermost vectorisation
along a common horizontal dimension to a GPU-friendly loop-layout via
loop inversion and local array variable demotion. The resulting kernel
remains "vector-parallel", but with the ``horizontal`` loop as the
outermost iteration dimension (as far as data dependencies
allow). This allows local temporary arrays to be demoted to scalars,
where possible.

The outer "driver" loop over blocks is used as the secondary dimension
of parallelism, where the outer data indexing dimension
(``block_dim``) is resolved in the first call to a "kernel"
routine. This is equivalent to a so-called "gang-vector" parallelisation
scheme.

This :any:`Pipeline` applies the following :any:`Transformation`
classes in sequence:

1. :any:`SCCFuseVerticalLoops` - Fuse vertical loops where possible.
2. :any:`SCCBaseTransformation` - Ensure utility variables and resolve
   problematic code constructs.
3. :any:`SCCDevectorTransformation` - Remove horizontal vector loops.
4. :any:`SCCDemoteTransformation` - Demote local temporary array
   variables where appropriate.
5. :any:`SCCVecRevectorTransformation` - Re-insert the vector loops outermost,
   according to identified vector sections.
6. :any:`SCCAnnotateTransformation` - Annotate loops according to
   programming model (``directive``).
7. :any:`PragmaModelTransformation` - Process the pragma annotations
   according to the programming model (``directive``).

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
"""

# alias for backwards compatibility
SCCVectorPipeline = SCCVVectorPipeline

SCCSVectorPipeline = partial(
    Pipeline, classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCSeqRevectorTransformation,
        SCCAnnotateTransformation,
        PragmaModelTransformation
    )
)
"""
The basic Single Column Coalesced (SCC) transformation with
sequential kernels.

This transformation will convert kernels with innermost vectorisation
along a common horizontal dimension to a GPU-friendly loop-layout via
loop inversion and local array variable demotion. The resulting kernel
becomes sequential as the ``horizontal`` loop is hoisted to the driver
and the loop index becomes an argument to the kernel(s).
Moreover, this allows local temporary arrays to be demoted to scalars,
where possible.

The outer "driver" loop over blocks is used as the secondary dimension
of parallelism, where the outer data indexing dimension
(``block_dim``) is resolved in the first call to a "kernel"
routine. This is equivalent to a so-called "gang-vector" parallelisation
scheme.

This :any:`Pipeline` applies the following :any:`Transformation`
classes in sequence:

1. :any:`SCCFuseVerticalLoops` - Fuse vertical loops where possible.
2. :any:`SCCBaseTransformation` - Ensure utility variables and resolve
   problematic code constructs.
3. :any:`SCCDevectorTransformation` - Remove horizontal vector loops.
4. :any:`SCCDemoteTransformation` - Demote local temporary array
   variables where appropriate.
5. :any:`SCCSeqRevectorTransformation` - Re-insert the vector loops outermost,
   according to identified vector sections.
6. :any:`SCCAnnotateTransformation` - Annotate loops according to
   programming model (``directive``).
7. :any:`PragmaModelTransformation` - Process the pragma annotations
   according to the programming model (``directive``).

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
"""

SCCVHoistPipeline = partial(
    Pipeline, classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCVecRevectorTransformation,
        RemoveUnusedVarTransformation,
        HoistTemporaryArraysAnalysis,
        SCCHoistTemporaryArraysTransformation,
        SCCAnnotateTransformation,
        PragmaModelTransformation
    )
)
"""
SCC-style transformation with "vector-parallel" kernels
that additionally hoists local temporary
arrays that cannot be demoted to the outer driver call.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCVVectorPipeline`.

In addition, this pipeline will invoke
:any:`RemoveUnusedVarTransformation` (to drop unused temporaries),
:any:`HoistTemporaryArraysAnalysis` and
:any:`SCCHoistTemporaryArraysTransformation` before the final
annotation step to hoist multi-dimensional local temporary array
variables to the "driver" routine, where they will be allocated on
device and passed down as arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
dim_vars : tuple of str, optional
    Variables that are to appear within the dimensions of the arrays to be
    hoisted. If not provided, no checks will be done for the array
    dimensions in :any:`HoistTemporaryArraysAnalysis`.
"""

# Sequential-kernel SCC pipeline that additionally hoists non-demotable
# local temporaries to the driver; transformations are listed in the
# order in which the pipeline applies them.
SCCSHoistPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCSeqRevectorTransformation,
        RemoveUnusedVarTransformation,
        # hoisting analysis + transformation, run before the final annotation
        HoistTemporaryArraysAnalysis,
        SCCHoistTemporaryArraysTransformation,
        SCCAnnotateTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with sequential kernels
that additionally hoists local temporary
arrays that cannot be demoted to the outer driver call.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCSVectorPipeline`

In addition, this pipeline will invoke
:any:`HoistTemporaryArraysAnalysis` and
:any:`SCCHoistTemporaryArraysTransformation` before the final
annotation step to hoist multi-dimensional local temporary array
variables to the "driver" routine, where they will be allocated on
device and passed down as arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
dim_vars : tuple of str, optional
    Variables expected to appear in the dimensions of the arrays to be
    hoisted. If not provided, no checks are performed on the array
    dimensions in :any:`HoistTemporaryArraysAnalysis`.
"""

# Alias for backwards compatibility: the unprefixed name refers to the
# "vector-parallel" (V) variant of the hoisting pipeline.
SCCHoistPipeline = SCCVHoistPipeline

# "Vector-parallel" SCC pipeline whose remaining local arrays are backed
# by a pre-allocated "stack" pool allocator (see docstring below).
SCCVStackPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCVecRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # pool-allocator backing of remaining temporaries
        TemporariesPoolAllocatorTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with "vector-parallel" kernels
that additionally pre-allocates a "stack"
pool allocator and associates local arrays with preallocated memory.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCVVectorPipeline`

In addition, this pipeline will invoke
:any:`TemporariesPoolAllocatorTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
"""

# Alias for backwards compatibility: the unprefixed name refers to the
# "vector-parallel" (V) variant of the pool-allocator stack pipeline.
SCCStackPipeline = SCCVStackPipeline

# Sequential-kernel counterpart of SCCVStackPipeline: identical stages
# except for the sequential re-vectorisation step.
SCCSStackPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCSeqRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # pool-allocator backing of remaining temporaries
        TemporariesPoolAllocatorTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with sequential kernels
that additionally pre-allocates a "stack"
pool allocator and associates local arrays with preallocated memory.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCSVectorPipeline`

In addition, this pipeline will invoke
:any:`TemporariesPoolAllocatorTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
"""


# "Vector-parallel" SCC pipeline whose remaining local arrays are backed
# by a "stack" via FtrPtrStackTransformation (see docstring below).
SCCVStackFtrPtrPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCVecRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # stack backing of remaining temporaries
        FtrPtrStackTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with "vector-parallel" kernels
that additionally pre-allocates a "stack"
pool allocator and associates local arrays with preallocated memory.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCVVectorPipeline`

In addition, this pipeline will invoke
:any:`FtrPtrStackTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
"""

# Alias for backwards compatibility: the unprefixed name refers to the
# "vector-parallel" (V) variant of the FtrPtr stack pipeline. (This
# previously pointed at SCCVStackPipeline, which uses the pool-allocator
# transformation instead of FtrPtrStackTransformation.)
SCCStackFtrPtrPipeline = SCCVStackFtrPtrPipeline

# Sequential-kernel counterpart of SCCVStackFtrPtrPipeline: identical
# stages except for the sequential re-vectorisation step.
SCCSStackFtrPtrPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCSeqRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # stack backing of remaining temporaries
        FtrPtrStackTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with sequential kernels
that additionally pre-allocates a "stack"
pool allocator and associates local arrays with preallocated memory.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCSVectorPipeline`

In addition, this pipeline will invoke
:any:`FtrPtrStackTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
"""

# "Vector-parallel" SCC pipeline that replaces local temporaries with
# indexed sub-arrays of a pre-allocated stack (see docstring below).
SCCVStackDirectIdxPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCVecRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # direct-index stack backing of remaining temporaries
        DirectIdxStackTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with "vector-parallel" kernels
that additionally pre-allocates a "stack"
pool allocator and replaces local temporaries with indexed sub-arrays
of this preallocated array.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCVVectorPipeline`

In addition, this pipeline will invoke
:any:`DirectIdxStackTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
driver_horizontal : str, optional
    Override string if a separate variable name should be used for the
    horizontal when allocating the stack in the driver.
"""

# Alias for backwards compatibility: the unprefixed name refers to the
# "vector-parallel" (V) variant of the direct-index stack pipeline.
SCCStackDirectIdxPipeline = SCCVStackDirectIdxPipeline

# Sequential-kernel counterpart of SCCVStackDirectIdxPipeline: identical
# stages except for the sequential re-vectorisation step.
SCCSStackDirectIdxPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCSeqRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # direct-index stack backing of remaining temporaries
        DirectIdxStackTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with sequential kernels
that additionally pre-allocates a "stack"
pool allocator and replaces local temporaries with indexed sub-arrays
of this preallocated array.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCSVectorPipeline`

In addition, this pipeline will invoke
:any:`DirectIdxStackTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
driver_horizontal : str, optional
    Override string if a separate variable name should be used for the
    horizontal when allocating the stack in the driver.
"""

# "Vector-parallel" SCC pipeline backed by TemporariesRawStackTransformation.
# Note: unlike most pipelines here, it does not start with vertical loop fusion.
SCCVRawStackPipeline = partial(
    Pipeline,
    classes=(
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCVecRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # raw-stack backing of remaining temporaries
        TemporariesRawStackTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with "vector-parallel" kernels
that additionally pre-allocates a "stack"
pool allocator and replaces local temporaries with indexed sub-arrays
of this preallocated array.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCVVectorPipeline`

In addition, this pipeline will invoke
:any:`TemporariesRawStackTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
driver_horizontal : str, optional
    Override string if a separate variable name should be used for the
    horizontal when allocating the stack in the driver.
"""

# Alias for backwards compatibility: the unprefixed name refers to the
# "vector-parallel" (V) variant of the raw-stack pipeline.
SCCRawStackPipeline = SCCVRawStackPipeline

# Sequential-kernel counterpart of SCCVRawStackPipeline: identical stages
# except for the sequential re-vectorisation step. Like its vector-parallel
# twin, it does not start with vertical loop fusion.
SCCSRawStackPipeline = partial(
    Pipeline,
    classes=(
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCSeqRevectorTransformation,
        # Previously missing here only: every other SCC pipeline, including
        # SCCVRawStackPipeline, removes unused variables before annotation.
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # raw-stack backing of remaining temporaries
        TemporariesRawStackTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with sequential kernels
that additionally pre-allocates a "stack"
pool allocator and replaces local temporaries with indexed sub-arrays
of this preallocated array.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCSVectorPipeline`

In addition, this pipeline will invoke
:any:`TemporariesRawStackTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that is
pre-allocated in the driver routine and passed down via arguments.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
driver_horizontal : str, optional
    Override string if a separate variable name should be used for the
    horizontal when allocating the stack in the driver.
"""

# "Vector-parallel" SCC pipeline whose remaining local arrays are backed
# by an externally provided "ecstack" pool allocator (see docstring below).
SCCVEcStackPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCVecRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # ecstack backing of remaining temporaries
        EcstackPoolAllocatorTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with "vector-parallel" kernels
that additionally pre-allocates a "stack"
pool allocator and associates local arrays with preallocated memory.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCVVectorPipeline`

In addition, this pipeline will invoke
:any:`EcstackPoolAllocatorTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that requests
a chunk of offloaded memory from an externally defined module.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
"""

# Alias for backwards compatibility: the unprefixed name refers to the
# "vector-parallel" (V) variant of the ecstack pipeline. (This previously
# pointed at SCCVStackPipeline, which uses the pool-allocator transformation
# instead of EcstackPoolAllocatorTransformation.)
SCCEcStackPipeline = SCCVEcStackPipeline

# Sequential-kernel counterpart of SCCVEcStackPipeline: identical stages
# except for the sequential re-vectorisation step.
SCCSEcStackPipeline = partial(
    Pipeline,
    classes=(
        SCCFuseVerticalLoops,
        SCCBaseTransformation,
        SCCDevectorTransformation,
        SCCDemoteTransformation,
        SCCSeqRevectorTransformation,
        RemoveUnusedVarTransformation,
        SCCAnnotateTransformation,
        # ecstack backing of remaining temporaries
        EcstackPoolAllocatorTransformation,
        PragmaModelTransformation,
    ),
)
"""
SCC-style transformation with sequential kernels
that additionally pre-allocates a "stack"
pool allocator and associates local arrays with preallocated memory.

For details of the kernel and driver-side transformations, please
refer to :any:`SCCSVectorPipeline`

In addition, this pipeline will invoke
:any:`EcstackPoolAllocatorTransformation` to back the remaining
locally allocated arrays from a "stack" pool allocator that requests
a chunk of offloaded memory from an externally defined module.

Parameters
----------
horizontal : :any:`Dimension`
    :any:`Dimension` object describing the variable conventions used in code
    to define the horizontal data dimension and iteration space.
block_dim : :any:`Dimension`
    Optional ``Dimension`` object to define the blocking dimension
    to use for hoisted column arrays if hoisting is enabled.
directive : string or None
    Directives flavour to use for parallelism annotations; either
    ``'openacc'``, ``'omp-gpu'`` or ``None``.
trim_vector_sections : bool
    Flag to trigger trimming of extracted vector sections to remove
    nodes that are not assignments involving vector parallel arrays.
demote_local_arrays : bool
    Flag to trigger local array demotion to scalar variables where possible
check_bounds : bool, optional
    Insert bounds-checks in the kernel to make sure the allocated
    stack size is not exceeded (default: `True`)
"""
loki-ecmwf-0.3.6/loki/transformations/single_column/devector.py0000664000175000017500000002672415167130205025166 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from more_itertools import split_at

from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation
from loki.ir import (
    nodes as ir, FindNodes, FindScopes, FindVariables, Transformer,
    NestedTransformer, is_loki_pragma, pragmas_attached,
)
from loki.tools import as_tuple, flatten
from loki.types import BasicType
from loki.expression import symbols as sym

from loki.transformations.utilities import (
    find_driver_loops, check_routine_sequential
)


__all__ = [
    'RemoveLoopTransformer', 'SCCDevectorTransformation',
]


class RemoveLoopTransformer(Transformer):
    """
    A :any:`Transformer` that strips every loop whose iteration variable
    matches the index of the given dimension, splicing the (recursively
    processed) loop body into the surrounding IR in its place.

    All other loops are kept, but their children are still visited so that
    nested matching loops are removed as well.

    Parameters
    ----------
    dimension : :any:`Dimension`
        The dimension whose ``index`` identifies the loops to remove
        (e.g. the horizontal vector dimension).
    """
    # pylint: disable=unused-argument

    def __init__(self, dimension, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dimension = dimension

    def visit_Loop(self, loop, **kwargs):
        if loop.variable != self.dimension.index:
            # Not a loop over the target dimension: keep it, but process its children
            new_children = self.visit(loop.children, **kwargs)
            return self._rebuild(loop, new_children)

        # Loop over the target dimension: replace it with its processed body
        return self.visit(loop.body, **kwargs)


class SCCDevectorTransformation(Transformation):
    """
    A set of utilities that can be used to strip vector loops from a :any:`Subroutine`
    and determine the regions of the IR to be placed within thread-parallel loop directives.

    Parameters
    ----------
    horizontal : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the horizontal data dimension and iteration space.
    trim_vector_sections : bool
        Flag to trigger trimming of extracted vector sections to remove
        nodes that are not assignments involving vector parallel arrays.
    """

    # Control-flow node types whose outermost instance is used as a separator
    # between vector sections when a separating node is nested inside them
    _separator_node_types = (ir.Loop, ir.Conditional, ir.MultiConditional)

    def __init__(self, horizontal, trim_vector_sections=False):
        self.horizontal = horizontal
        self.trim_vector_sections = trim_vector_sections

    @classmethod
    def _add_separator(cls, node, section, separator_nodes):
        """
        Add either the current node or its outermost parent node from the list of types
        defining a vector region separator (:attr:`_separator_node_types`) to the list of
        separator nodes.

        Parameters
        ----------
        node : :any:`Node`
            The node that forces a split (e.g. a call, pragma or assignment).
        section : tuple of :any:`Node`
            The section of nodes that is being split.
        separator_nodes : list of :any:`Node`
            Separator nodes collected so far; this list is appended to in-place.

        Returns
        -------
        list of :any:`Node`
            The updated list of separator nodes.
        """

        if node in section:
            # If the node is at the current section's level, it's a separator
            separator_nodes.append(node)

        else:
            # If the node is deeper in the IR tree, its highest ancestor is used
            ancestors = flatten(FindScopes(node).visit(section))
            ancestor_scopes = [a for a in ancestors if isinstance(a, cls._separator_node_types)]
            if ancestor_scopes and ancestor_scopes[0] not in separator_nodes:
                separator_nodes.append(ancestor_scopes[0])

        return separator_nodes

    @classmethod
    def extract_vector_sections(cls, section, horizontal):
        """
        Extract contiguous sections of nodes that contain vector-level
        computations and are not interrupted by recursive subroutine calls
        or nested control-flow structures.

        Parameters
        ----------
        section : tuple of :any:`Node`
            A section of nodes from which to extract vector-level sub-sections
        horizontal: :any:`Dimension`
            The dimension specifying the horizontal vector dimension

        Returns
        -------
        list of tuple of :any:`Node`
            The extracted vector sections, including those found recursively
            inside separating loops and conditionals.
        """

        # Identify outer "scopes" (loops/conditionals) constrained by recursive routine calls
        calls = FindNodes(ir.CallStatement).visit(section)
        separator_nodes = []

        for call in calls:

            # check if calls have been enriched
            if call.routine is not BasicType.DEFERRED:
                # check if called routine is marked as sequential
                if check_routine_sequential(routine=call.routine):
                    continue

            separator_nodes = cls._add_separator(call, section, separator_nodes)

        for pragma in FindNodes(ir.Pragma).visit(section):
            # Reductions over thread-parallel regions should be marked as a separator node
            if (is_loki_pragma(pragma, starts_with='vector-reduction') or
                is_loki_pragma(pragma, starts_with='end vector-reduction') or
                is_loki_pragma(pragma, starts_with='separator')):

                separator_nodes = cls._add_separator(pragma, section, separator_nodes)

        for assign in FindNodes(ir.Assignment).visit(section):
            # Pointer assignments to arrays shaped by the horizontal size separate sections
            if assign.ptr and isinstance(assign.rhs, sym.Array):
                if any(s in assign.rhs.shape for s in horizontal.size_expressions):
                    separator_nodes = cls._add_separator(assign, section, separator_nodes)

            if isinstance(assign.rhs, sym.InlineCall):
                # filter out array arguments
                # we can't use arg_map here because intrinsic functions are not enriched
                _params = assign.rhs.parameters + as_tuple(assign.rhs.kw_parameters.values())
                _params = [p for p in _params if isinstance(p, sym.Array)]

                # check if a horizontal array is passed as an argument, meaning we have a vector
                # InlineCall, e.g. an array reduction intrinsic
                for p in _params:
                    if any(s in (p.dimensions or p.shape) for s in horizontal.size_expressions):
                        separator_nodes = cls._add_separator(assign, section, separator_nodes)
                        # One horizontal argument is enough; don't register the node repeatedly
                        break

        # Extract contiguous node sections between separator nodes
        assert all(n in section for n in separator_nodes)
        subsections = [as_tuple(s) for s in split_at(section, lambda n: n in separator_nodes)]

        # Filter sub-sections that do not use the horizontal loop index variable
        subsections = [s for s in subsections if horizontal.index in list(FindVariables().visit(s))]

        # Recurse on all separator nodes that might contain further vector sections
        for separator in separator_nodes:

            if isinstance(separator, ir.Loop):
                subsec_body = cls.extract_vector_sections(separator.body, horizontal)
                if subsec_body:
                    subsections += subsec_body

            if isinstance(separator, ir.Conditional):
                subsec_body = cls.extract_vector_sections(separator.body, horizontal)
                if subsec_body:
                    subsections += subsec_body
                # we need to prevent that all (possibly nested) 'else_bodies' are completely wrapped as a section,
                # as 'Conditional's rely on the fact that the first element of each 'else_body'
                # (if 'has_elseif') is a Conditional itself
                for ebody in separator.else_bodies:
                    subsections += cls.extract_vector_sections(ebody, horizontal)

            if isinstance(separator, (ir.MultiConditional, ir.TypeConditional)):
                for body in separator.bodies:
                    subsec_body = cls.extract_vector_sections(body, horizontal)
                    if subsec_body:
                        subsections += subsec_body
                subsec_else = cls.extract_vector_sections(separator.else_body, horizontal)
                if subsec_else:
                    subsections += subsec_else

        return subsections

    @classmethod
    def get_trimmed_sections(cls, routine, horizontal, sections):
        """
        Trim extracted vector sections to remove nodes that are not assignments
        involving vector parallel arrays.

        Each section is cut down to the span between the first and the last of
        its top-level nodes whose used symbols include the horizontal index.
        Sections in which no top-level node uses the index are kept unchanged.

        Parameters
        ----------
        routine : :any:`Subroutine` or :any:`Node`
            IR object on which dataflow analysis is attached for the duration
            of the trimming.
        horizontal : :any:`Dimension`
            The dimension specifying the horizontal vector dimension
        sections : iterable of tuple of :any:`Node`
            The vector sections to trim.

        Returns
        -------
        tuple of tuple of :any:`Node`
            The trimmed sections, in the same order as ``sections``.
        """

        trimmed_sections = ()
        with dataflow_analysis_attached(routine):
            for sec in sections:
                # Use positions rather than ``sec.index(node)``: tuple.index is
                # equality-based and could match an earlier value-equal node,
                # which would truncate the trimmed section.
                vec_positions = [
                    pos for pos, node in enumerate(sec)
                    if horizontal.index.lower() in node.uses_symbols
                ]
                if not vec_positions:
                    # Nothing to anchor the trim on; keep the section as-is
                    # instead of failing with an IndexError
                    trimmed_sections += (sec,)
                    continue

                start, end = vec_positions[0], vec_positions[-1]
                trimmed_sections += (sec[start:end+1],)

        return trimmed_sections

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply SCCDevector utilities to a :any:`Subroutine`.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        role : string
            Role of the subroutine in the call tree; ``"kernel"`` and
            ``"driver"`` are processed, any other role is left untouched.
        targets : tuple of str, optional
            Names of subroutines considered part of the transformation
            call tree (only used for drivers).
        """
        role = kwargs['role']
        targets = kwargs.get('targets', ())

        if role == 'kernel':
            self.process_kernel(routine)
        if role == "driver":
            self.process_driver(routine, targets=targets)

    def process_kernel(self, routine):
        """
        Applies the SCCDevector utilities to a "kernel". This consists simply
        of stripping vector loops and determining which sections of the IR can be
        placed within thread-parallel loops.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        """

        # Remove all vector loops over the specified dimension
        routine.body = RemoveLoopTransformer(dimension=self.horizontal).visit(routine.body)

        # Extract vector-level compute sections from the kernel
        sections = self.extract_vector_sections(routine.body.body, self.horizontal)

        if self.trim_vector_sections:
            sections = self.get_trimmed_sections(routine, self.horizontal, sections)

        # Replace sections with marked Section node
        section_mapper = {s: ir.Section(body=s, label='vector_section') for s in sections}
        routine.body = NestedTransformer(section_mapper).visit(routine.body)

    def process_driver(self, routine, targets=()):
        """
        Applies the SCCDevector utilities to a "driver". This consists simply
        of stripping vector loops and determining which sections of the IR can be
        placed within thread-parallel loops.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        targets : list or string
            List of subroutines that are to be considered as part of
            the transformation call tree.
        """

        with pragmas_attached(routine, ir.Loop, attach_pragma_post=True):
            driver_loops = find_driver_loops(section=routine.body, targets=targets)

        # remove vector loops
        driver_loop_map = {}
        for loop in driver_loops:
            new_driver_loop = RemoveLoopTransformer(dimension=self.horizontal).visit(loop.body)
            new_driver_loop = loop.clone(body=new_driver_loop)
            sections = self.extract_vector_sections(new_driver_loop.body, self.horizontal)
            if self.trim_vector_sections:
                sections = self.get_trimmed_sections(new_driver_loop, self.horizontal, sections)
            section_mapper = {s: ir.Section(body=s, label='vector_section') for s in sections}
            new_driver_loop = NestedTransformer(section_mapper).visit(new_driver_loop)
            driver_loop_map[loop] = new_driver_loop
        routine.body = Transformer(driver_loop_map).visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/single_column/vertical.py0000664000175000017500000002557015167130205025162 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.expression import (
    symbols as sym
)
from loki.ir import (
    nodes as ir, FindNodes, Transformer,
    is_loki_pragma, pragmas_attached,
    get_pragma_parameters, FindVariables
)
from loki.tools import as_tuple, CaseInsensitiveDict, OrderedSet
from loki.transformations.transform_loop import do_loop_fusion, do_loop_interchange
from loki.transformations.array_indexing import demote_variables
from loki.transformations.utilities import get_local_arrays
from loki.logging import info

__all__ = ['SCCFuseVerticalLoops']

class SCCFuseVerticalLoops(Transformation):
    """
    A transformation to fuse vertical loops and demote temporaries in the vertical
    dimension if possible.

    .. note::
        This transformation currently relies on pragmas being inserted in the input
        source files. Relevant pragmas are `!$loki loop-interchange` to expose the
        vertical loops (in case vertical loops are nested) and `!$loki loop-fusion`
        possibly grouped via `group()`. Further, if there are loops
        that initialize multilevel arrays (`jk +/- 1`) it is possible to mark those
        loops as `!$loki loop-fusion group(-init)`. This allows to split
        the relevant node and moves the initialization of those arrays to the top of
        the group.

    Parameters
    ----------
    vertical : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the vertical data dimension and iteration space.
    apply_to : list of str or str, optional
        Routine name(s) to apply this transformation to, matched case-insensitively;
        if not provided or None apply to all routines (default: None)
    """

    def __init__(self, vertical=None, apply_to=None):
        self.vertical = vertical
        # Routine names are compared in lower case in `transform_subroutine`, so the
        # configured names are normalised to lower case as well. `as_tuple` also
        # keeps a single string from being treated as a sequence of characters
        # (which would turn the membership test into a substring test).
        self.apply_to = tuple(name.lower() for name in as_tuple(apply_to))

        if self.vertical is None:
            info('[SCCFuseVerticalLoops] is not applied as the vertical dimension is not defined!')

    def transform_subroutine(self, routine, **kwargs):
        """
        Fuse vertical loops and demote temporaries in the vertical dimension
        if possible.

        Only routines with role ``"kernel"`` (and, if ``apply_to`` is given,
        listed in it) are processed; nothing is done if no vertical dimension
        was provided.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine in which the vertical loops should be fused and
            temporaries be demoted.
        role : str
            Role of the subroutine in the call tree.
        """
        if self.vertical is None:
            return
        role = kwargs['role']
        if role == 'kernel':
            if self.apply_to and routine.name.lower() not in self.apply_to:
                return
            self.process_kernel(routine)

    def process_kernel(self, routine):
        """
        Current logic (simplified):

        1. loop interchange to expose vertical loops
        2. fuse vertical loops (possibly into multiple groups)
        3. find local arrays to be demoted and apply heuristics to check whether this is safe
        4. demote those arrays which are safe to be demoted
        """
        # find local arrays with a vertical dimension
        relevant_local_arrays = self.find_relevant_local_arrays(routine)
        # find "multilevel" thus "jk +/- 1" arrays
        multilevel_relevant_local_arrays = self.identify_multilevel_arrays(relevant_local_arrays)
        # loop interchange to expose vertical loops as outermost loops
        do_loop_interchange(routine)
        # handle initialization of "jk +/- 1" arrays
        multilevel_relevant_local_arrays_names = OrderedSet(
            arr.name.lower() for arr in multilevel_relevant_local_arrays
        )
        self.correct_init_of_multilevel_arrays(routine, multilevel_relevant_local_arrays_names)
        # fuse vertical loops
        do_loop_fusion(routine)
        # multilevel arrays are never demoted
        relevant_local_arrays_names = OrderedSet(arr.name.lower() for arr in relevant_local_arrays)
        demote_candidates = relevant_local_arrays_names - multilevel_relevant_local_arrays_names
        # check which variables are safe to demote in the vertical
        safe_to_demote = self.check_safe_to_demote(routine, demote_candidates)
        # demote locals in vertical dimension
        dimensions_to_demote = self.vertical.size_expressions + (f"{self.vertical.size}+1",)
        demote_variables(routine, safe_to_demote, dimensions_to_demote)

    def check_safe_to_demote(self, routine, demote_candidates):
        """
        Check whether variables that are candidates to be demoted in the vertical dimension are really
        safe to be demoted.

        Current heuristic: If the candidate is used in more than one vertical loop group, assume it is
        NOT safe to demote! Candidates that do not appear in any ``fused-loop`` vertical loop are not
        demoted either. Loops in group ``ignore`` are not considered.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to inspect.
        demote_candidates : iterable of str
            Lower-case names of the candidate arrays.

        Returns
        -------
        tuple of str
            Names of the candidates that are considered safe to demote.
        """
        loop_var_map = CaseInsensitiveDict()
        with pragmas_attached(routine, ir.Loop):
            for loop in FindNodes(ir.Loop).visit(routine.body):
                if loop.variable == self.vertical.index:
                    if is_loki_pragma(loop.pragma, starts_with='fused-loop'):
                        parameters = get_pragma_parameters(loop.pragma, starts_with='fused-loop')
                        group = parameters.get('group', 'default')
                        if group == 'ignore':
                            continue
                        for var in FindVariables().visit(loop.body):
                            if isinstance(var, sym.Array):
                                # Store lower-case keys explicitly so the lookup below is
                                # case-insensitive regardless of how `setdefault` is implemented
                                loop_var_map.setdefault(var.name.lower(), set()).add(group)

        safe_to_demote = ()
        for var in demote_candidates:
            if var in loop_var_map and len(loop_var_map[var]) <= 1:
                safe_to_demote += (var,)

        return safe_to_demote

    def find_relevant_local_arrays(self, routine):
        """
        Find local arrays/temporaries that do have the vertical dimension.

        Arrays explicitly marked to be ignored (see
        :meth:`find_local_arrays_to_be_ignored`) are filtered out.
        """
        # local/temporary arrays
        local_arrays = get_local_arrays(routine, routine.body)
        # only those with the vertical size within shape
        relevant_local_arrays = [
            arr for arr in local_arrays
            if any(s in self.vertical.sizes for s in FindVariables().visit(arr.shape))
        ]
        # filter arrays to be ignored (for whatever reason)
        ignore_names = self.find_local_arrays_to_be_ignored(routine)
        if ignore_names:
            relevant_local_arrays = [arr for arr in relevant_local_arrays if arr.name.lower() not in ignore_names]
        return relevant_local_arrays

    def find_local_arrays_to_be_ignored(self, routine):
        """
        Identify variables to be ignored regarding demotion for whatever reason.

        Reasons are:

        * explicitly marked to be ignored via pragmas within the input source file, e.g.,
          'loki k-caching ignore(var1, var2, ...)'

        Returns
        -------
        set of str
            Lower-case names of the variables to ignore.
        """
        ignore = ()
        pragmas = FindNodes(ir.Pragma).visit(routine.body)
        # look for 'loki k-caching ignore(var1, var2, ...)' pragmas within routine and ignore those vars
        for pragma in pragmas:
            if is_loki_pragma(pragma, starts_with='k-caching'):
                if pragma_ignore := get_pragma_parameters(pragma, starts_with='k-caching').get('ignore', None):
                    # A repeated `ignore(...)` clause yields a list of values; merge them first
                    pragma_ignore = ', '.join(as_tuple(pragma_ignore))
                    ignore += tuple(v.strip() for v in pragma_ignore.split(','))
        ignore_names = {var.lower() for var in ignore if var}
        return ignore_names

    def identify_multilevel_arrays(self, local_arrays):
        """
        Identify local arrays/temporaries that have an access in the vertical dimension
        that differs from the plain vertical index ``jk``, e.g., ``jk +/- 1``.

        Each array is reported at most once.
        """
        multilevel_local_arrays = []
        for arr in local_arrays:
            for dim in arr.dimensions:
                if self.vertical.index in FindVariables().visit(dim):
                    # dim is not equal to vertical.index e.g., vertical.index +/- 1
                    if dim != self.vertical.index:
                        multilevel_local_arrays.append(arr)
                        # one offset access is enough; avoid duplicate entries
                        break
        return multilevel_local_arrays

    def correct_init_of_multilevel_arrays(self, routine, multilevel_local_arrays):
        """
        Possibly handle initialization of those multilevel local arrays by
        splitting relevant loops, i.e., creating a new loop and moving the
        relevant nodes into it.

        For each loop marked ``!$loki loop-fusion group(<name>-init)``, the
        assignments that reference a multilevel array are kept in a new loop
        (annotated ``!$loki fused-loop group(ignore)``) placed in front of the
        original loop; the remaining assignments stay in the original loop,
        whose pragma has the ``-init`` suffix removed.

        .. note::
            This relies on pragmas being inserted in the input source code!

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to transform.
        multilevel_local_arrays : collection of str
            Lower-case names of the multilevel arrays.
        """
        # find/identify loops with pragma 'loop-fusion group(-init)'
        with pragmas_attached(routine, ir.Loop):
            loop_map = {}
            for loop in FindNodes(ir.Loop).visit(routine.body):
                if not is_loki_pragma(loop.pragma, starts_with='loop-fusion'):
                    continue
                parameters = get_pragma_parameters(loop.pragma, starts_with='loop-fusion')
                group = parameters.get('group', 'default')
                if not group.endswith('-init'):
                    continue

                # Assignments touching multilevel arrays are removed from the original
                # loop (node_map); all others are removed from the init loop (node_map_init)
                node_map = {}
                node_map_init = {}
                for node in FindNodes(ir.Assignment).visit(loop.body):
                    node_vars = FindVariables().visit(node)
                    if any(node_var.name.lower() in multilevel_local_arrays for node_var in node_vars):
                        node_map[node] = None
                    else:
                        node_map_init[node] = None

                # Nothing to split off if no assignment uses a multilevel array
                if not node_map:
                    continue

                # Drop the '-init' suffix so the remaining loop joins the regular group
                new_pragmas = tuple(
                    pragma.clone(content=pragma.content.replace('-init', ''))
                    if '-init' in pragma.content else pragma
                    for pragma in as_tuple(loop.pragma)
                )
                # init part; group 'ignore' excludes it from `check_safe_to_demote`
                init_pragma = ir.Pragma(keyword='loki', content='fused-loop group(ignore)')
                transf_init = Transformer(node_map_init).visit(loop.clone(pragma=(init_pragma,)))
                # rest of the original node/loop
                transf_orig = Transformer(node_map).visit(loop.clone(pragma=new_pragmas))
                loop_map[loop] = (ir.Comment('! Loki generated loop for init ...'), transf_init, transf_orig)

            if loop_map:
                routine.body = Transformer(loop_map).visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/single_column/annotate.py0000664000175000017500000003552215167130205025160 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import defaultdict
from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation
from loki.expression import symbols as sym, is_dimension_constant
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, Transformer,
    pragmas_attached, is_loki_pragma, get_pragma_parameters,
    pragma_regions_attached
)
from loki.logging import info, warning
from loki.tools import as_tuple, flatten
from loki.types import DerivedType

from loki.transformations.utilities import (
    find_driver_loops, get_local_arrays
)


__all__ = ['SCCAnnotateTransformation']


class SCCAnnotateTransformation(Transformation):
    """
    A set of utilities to insert generic Loki directives. This includes both :any:`Loop` and
    :any:`Subroutine` level annotations.

    Parameters
    ----------
    block_dim : :any:`Dimension`
        ``Dimension`` object describing the blocking dimension; arrays whose shape
        contains one of its size expressions are not privatised in driver loops.
    privatise_derived_types : bool, default: False
        Flag to enable privatising derived-type objects in driver loops.
    """

    def __init__(self, block_dim, privatise_derived_types=False):
        self.block_dim = block_dim
        self.privatise_derived_types = privatise_derived_types

    def annotate_vector_loops(self, routine):
        """
        Add ``private`` clauses to previously inserted ``!$loki loop vector`` pragmas.

        Local arrays declared with a constant shape are privatised in every marked
        vector loop that writes to them. Variables already listed in an explicit
        ``private(...)`` clause are kept. Pragmas carrying a ``reduction`` clause
        are left untouched.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine whose vector loops are annotated.
        """

        # Find any local arrays that need explicit privatisation
        private_arrays = get_local_arrays(routine, section=routine.spec)
        private_arrays = [
            v for v in private_arrays
            if all(is_dimension_constant(d) for d in v.shape)
        ]

        if private_arrays:
            # Log private arrays in vector regions, as these can impact performance
            info(
                f'[Loki-SCC::Annotate] Candidates for array privatisation in {routine.name}: '
                f'{[a.name for a in private_arrays]}'
            )

        with pragmas_attached(routine, ir.Loop):
            with dataflow_analysis_attached(routine):
                for loop in FindNodes(ir.Loop).visit(routine.body):
                    for pragma in as_tuple(loop.pragma):
                        if not is_loki_pragma(pragma, starts_with='loop vector'):
                            continue
                        if not private_arrays:
                            continue
                        pragma_params = get_pragma_parameters(pragma, starts_with='loop vector')
                        if 'reduction' in pragma_params:
                            continue

                        # An explicit `private(...)` clause is returned as a string (or a list
                        # of strings if repeated); keep those names as given
                        private_names = [
                            name.strip()
                            for name in ', '.join(as_tuple(pragma_params.get('private'))).split(',')
                            if name.strip()
                        ]

                        # Add local arrays written in this loop; read-only arrays are skipped
                        private_names += [
                            v.name for v in private_arrays
                            if v.name.lower() in loop.defines_symbols
                        ]

                        # Drop duplicates case-insensitively, keeping the first spelling
                        unique_names = {}
                        for name in private_names:
                            unique_names.setdefault(name.lower(), name)

                        if unique_names:
                            pragma_params['private'] = ', '.join(unique_names.values())
                        else:
                            # An empty variable list must not become a bare `private` keyword
                            pragma_params.pop('private', None)

                        pragma_content = [f'{kw}({val})' if val else kw for kw, val in pragma_params.items()]
                        pragma._update(content=f'loop vector {" ".join(pragma_content)}'.strip())

    def warn_vec_within_seq_loops(self, routine):
        """
        Check for vector inside sequential loops and print warning.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine in which to check for vector inside sequential loops
        """
        with pragmas_attached(routine, ir.Loop):
            for loop in FindNodes(ir.Loop).visit(routine.body):
                if not is_loki_pragma(loop.pragma, starts_with='loop seq'):
                    continue
                # Warn if we detect vector inside sequential loop nesting
                nested_loops = FindNodes(ir.Loop).visit(loop.body)
                loop_pragmas = flatten(as_tuple(nested.pragma) for nested in as_tuple(nested_loops))
                if any('loop vector' in pragma.content for pragma in loop_pragmas):
                    info(f'[Loki-SCC::Annotate] Detected vector loop in sequential loop in {routine.name}')

    def annotate_kernel_routine(self, routine):
        """
        Move ``!$loki routine seq/vector`` directives to the spec and wrap
        subroutine body in ``!$loki device-present`` directives covering all
        array and derived-type arguments.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to which annotations will be added
        """

        # Move `!$loki routine seq/vector` pragmas to spec
        routine_pragmas = [
            pragma for pragma in FindNodes(ir.Pragma).visit(routine.body)
            if is_loki_pragma(pragma, starts_with='routine')
        ]
        routine.spec.append(routine_pragmas)
        routine.body = Transformer({pragma: None for pragma in routine_pragmas}).visit(routine.body)

        # Get the names of all array and derived type arguments
        args = [a for a in routine.arguments if isinstance(a, sym.Array)]
        args += [a for a in routine.arguments if isinstance(a.type.dtype, DerivedType)]
        argnames = [str(a.name) for a in args]

        if argnames:
            # Open the device-present region at the start of the body
            content = f'device-present vars({", ".join(argnames)})'
            routine.body.prepend(ir.Pragma(keyword='loki', content=content))
            # Add comment to prevent false-attachment in case it is preceded by an "END DO" statement
            content = 'end device-present'
            routine.body.append((ir.Comment(text=''), ir.Pragma(keyword='loki', content=content)))

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply pragma annotations according to ``!$loki`` placeholder
        directives.

        For kernels, this adds ``private`` clauses to vector loop
        annotations and wraps the routine body in ``!$loki device-present``
        clauses. For drivers, vector loops are annotated likewise and
        driver loops are turned into ``!$loki loop gang`` loops.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        role : string
            Role of the subroutine in the call tree; ``"kernel"`` or ``"driver"``
        targets : tuple or str, optional
            Target routine names used to identify driver loops
        """

        role = kwargs['role']
        targets = as_tuple(kwargs.get('targets'))

        if role == 'kernel':
            # Bail if this routine has been processed before
            for p in FindNodes(ir.Pragma).visit(routine.ir):
                # Check if `!$acc routine` has already been added,
                #  e.g., this transformation has already been applied
                if p.keyword.lower() == 'acc' and 'routine' in p.content.lower():
                    return

            # Add private clauses to `!$loki loop vector` loops
            self.annotate_vector_loops(routine)

            # Check for vector loops within sequential loops
            self.warn_vec_within_seq_loops(routine)

            # Wrap the routine body in `!$loki device-present vars(...)` markers to
            # ensure all arguments are device-resident.
            self.annotate_kernel_routine(routine)


        if role == 'driver':
            # Add private clauses to `!$loki loop vector` loops
            self.annotate_vector_loops(routine)

            # Check for vector loops within sequential loops
            self.warn_vec_within_seq_loops(routine)

            with pragma_regions_attached(routine):
                with pragmas_attached(routine, ir.Loop, attach_pragma_post=True):
                    # Find variables with existing OpenACC data declarations
                    acc_vars = self.find_acc_vars(routine, targets)

                    driver_loops = find_driver_loops(section=routine.body, targets=targets)
                    for loop in driver_loops:
                        self.annotate_driver_loop(loop, acc_vars.get(loop, []))

    def find_acc_vars(self, routine, targets):
        """
        Find variables already specified in loki/acc data clauses.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        targets : list or string
            List of subroutines that are to be considered as part of
            the transformation call tree.

        Returns
        -------
        dict
            Maps each driver loop inside a data region to the lower-case names
            of variables covered by that region.
        """

        acc_vars = defaultdict(list)

        for region in FindNodes(ir.PragmaRegion).visit(routine.body):
            pragma_keyword = region.pragma.keyword.lower()
            if pragma_keyword == 'acc':
                parameters = get_pragma_parameters(region.pragma, starts_with='data', only_loki_pragmas=False)
            elif pragma_keyword == 'loki':
                parameters = get_pragma_parameters(
                    region.pragma, starts_with='structured-data', only_loki_pragmas=False
                )
            else:
                continue

            # Without any parameters there are no data clauses to collect
            if not parameters:
                continue

            driver_loops = find_driver_loops(section=region.body, targets=targets)
            if not driver_loops:
                continue

            # When a key is given multiple times, get_pragma_parameters returns a list
            # We merge them here into single entries to make our life easier below
            parameters = {key: ', '.join(as_tuple(value)) for key, value in parameters.items()}
            if (default := parameters.get('default', None)):
                if 'none' not in [p.strip().lower() for p in default.split(',')]:
                    for loop in driver_loops:
                        acc_vars[loop] += [var.name.lower() for var in FindVariables(unique=True).visit(loop)]
            else:
                _vars = [
                    p.strip().lower()
                    for category in ('present', 'copy', 'copyin', 'copyout', 'deviceptr')
                    for p in parameters.get(category, '').split(',')
                    if p.strip()
                ]

                for loop in driver_loops:
                    acc_vars[loop] += _vars

        return acc_vars

    @classmethod
    def device_alloc_column_locals(cls, routine, column_locals):
        """
        Add explicit OpenACC statements for creating device variables for hoisted column locals.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        column_locals : list
            List of column locals to be hoisted to driver layer
        """

        if column_locals:
            vnames = ', '.join(v.name for v in column_locals)
            pragma = ir.Pragma(keyword='loki', content=f'unstructured-data create({vnames})')
            pragma_post = ir.Pragma(keyword='loki', content=f'exit unstructured-data delete({vnames})')
            # Add comments around standalone pragmas to avoid false attachment
            routine.body.prepend((ir.Comment(''), pragma, ir.Comment('')))
            routine.body.append((ir.Comment(''), pragma_post, ir.Comment('')))

    def annotate_driver_loop(self, loop, acc_vars):
        """
        Annotate driver block loop with generic Loki pragmas.

        A ``!$loki loop driver`` pragma is replaced by ``!$loki loop gang`` with
        ``private``, ``vlength`` and ``async`` clauses as applicable, and a
        matching ``!$loki end loop gang`` post-pragma is added.

        Parameters
        ----------
        loop : :any:`Loop`
            Driver :any:`Loop` to wrap in generic Loki pragmas.
        acc_vars : list
            Variables already declared in generic Loki data directives.
        """
        sizes = self.block_dim.size_expressions

        # Privatise local arrays, except dummy arguments, pointers, variables
        # covered by data directives and arrays dimensioned by the block size
        loop_vars = FindVariables(unique=True).visit(loop)
        arrays = [v for v in loop_vars if isinstance(v, sym.Array)]
        arrays = [v for v in arrays if not v.type.intent]
        arrays = [v for v in arrays if not v.type.pointer]
        arrays = [v for v in arrays if v.name_parts[0].lower() not in acc_vars]
        arrays = [v for v in arrays if not any(d in sizes for d in as_tuple(v.shape))]
        private_sym = arrays

        if self.privatise_derived_types:
            # Derived-types are classified as "aggregate variables" in the OpenACC and OpenMP offload
            # standards and have the same implicit data attributes as arrays. Therefore, local derived-type
            # scalars must also be privatised.
            structs = [v for v in loop_vars if isinstance(v.type.dtype, DerivedType)]
            structs = [v for v in structs if v.name_parts[0].lower() not in acc_vars]
            structs = [v for v in structs if not v.type.intent]
            structs = [v for v in structs if v not in arrays]

            # only privatise derived-type parent
            private_sym = [v for v in private_sym if v.name_parts[0].lower() not in structs]

            if (dynamic_structs := [v.name for v in structs if (v.type.pointer or v.type.allocatable)]):
                warning(f'[Loki-SCC::Annotate] dynamically allocated structs are being privatised: {dynamic_structs}')

            # Filter out arrays that are explicitly allocated with block dimension
            private_sym += [
                v for v in structs
                if not any(d in sizes for d in as_tuple(getattr(v, 'shape', [])))
            ]

        private_vars = ', '.join(dict.fromkeys(v.name for v in private_sym))
        private_clause = '' if not private_vars else f' private({private_vars})'

        for pragma in as_tuple(loop.pragma):
            if not is_loki_pragma(pragma, starts_with='loop driver'):
                continue

            # Replace `!$loki loop driver` pragma with `!$loki loop gang`
            params = get_pragma_parameters(pragma, starts_with='loop driver')
            vlength = params.get('vector_length')
            asynchronous = params.get('async')
            vlength_clause = f' vlength({vlength})' if vlength else ''
            asynchronous_clause = f' async({asynchronous})' if asynchronous else ''

            content = f'loop gang{private_clause}{vlength_clause}{asynchronous_clause}'
            pragma_new = ir.Pragma(keyword='loki', content=content)
            pragma_post = ir.Pragma(keyword='loki', content='end loop gang')

            # Replace existing loki pragma and add post-pragma
            loop_pragmas = tuple(p for p in as_tuple(loop.pragma) if p is not pragma)
            loop._update(
                pragma=loop_pragmas + (pragma_new,),
                pragma_post=(pragma_post,) + as_tuple(loop.pragma_post)
            )
loki-ecmwf-0.3.6/loki/transformations/single_column/revector.py0000664000175000017500000004654615167130205025210 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re

from loki.batch import Transformation
from loki.expression import symbols as sym
from loki.ir import (
    nodes as ir, FindNodes, Transformer, is_loki_pragma,
    pragmas_attached, pragma_regions_attached
)
from loki.tools import as_tuple

from loki.transformations.utilities import (
    get_integer_variable, get_loop_bounds, find_driver_loops,
    check_routine_sequential, single_variable_declaration
)


__all__ = [
    'SCCRevectorTransformation', 'SCCVecRevectorTransformation',
    'SCCSeqRevectorTransformation', 'wrap_vector_section',
    'RevectorSectionTransformer'
]


def wrap_vector_section(section, routine, bounds, index, insert_pragma=True):
    """
    Wrap a section of nodes in a vector-level loop across the horizontal.

    Parameters
    ----------
    section : tuple of :any:`Node`
        A section of nodes to be wrapped in a vector-level loop
    routine : :any:`Subroutine`
        The subroutine containing the section; used to resolve the
        integer loop variable ``index``.
    bounds : tuple
        Bounds of the created loop, passed to ``LoopRange``.
    index : str
        Name of the loop variable of the created loop.
    insert_pragma: bool, optional
        Attach a ``!$loki loop vector`` marker pragma to the created loop

    Returns
    -------
    tuple of :any:`Node`
        The new :any:`Loop`, preceded and followed by an empty :any:`Comment`.
    """
    # Create a single loop around the horizontal from a given body
    index = get_integer_variable(routine, index)
    bounds = sym.LoopRange(bounds)

    # Ensure we clone all body nodes, to avoid recursion issues
    body = Transformer().visit(section)

    # Add a marker pragma for later annotations
    pragma = (ir.Pragma('loki', content='loop vector'),) if insert_pragma else None
    vector_loop = ir.Loop(variable=index, bounds=bounds, body=body, pragma=pragma)

    # Add a comment before and after the pragma-annotated loop to ensure
    # we do not overlap with neighbouring pragmas
    return (ir.Comment(''), vector_loop, ir.Comment(''))


class RevectorSectionTransformer(Transformer):
    """
    :any:`Transformer` that replaces :any:`Section` objects labelled
    with ``"vector_section"`` with vector-level loops across the
    horizontal.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine being transformed; used to derive the horizontal
        loop bounds and to resolve the loop variable.
    horizontal: :any:`Dimension`
        The dimension specifying the horizontal vector dimension
    insert_pragma: bool, optional
        Attach a ``!$loki loop vector`` marker pragma to each created loop
    """
    # pylint: disable=unused-argument

    def __init__(self, routine, horizontal, *args, insert_pragma=True, **kwargs):
        super().__init__(*args, **kwargs)

        self.routine = routine
        self.horizontal = horizontal

        self.insert_pragma = insert_pragma

    def visit_Section(self, s, **kwargs):
        """
        Wrap ``"vector_section"``-labelled sections in a horizontal loop;
        any other section is rebuilt after recursing into its children.
        """
        if s.label == 'vector_section':
            # Derive the loop bounds and wrap the section body in a vector loop
            bounds = get_loop_bounds(self.routine, dimension=self.horizontal)
            return wrap_vector_section(
                s.body, self.routine, bounds=bounds, index=self.horizontal.index,
                insert_pragma=self.insert_pragma
            )

        # Rebuild section after recursing to children
        return self._rebuild(s, self.visit(s.children))


class BaseRevectorTransformation(Transformation):
    """
    A base/parent class for transformation to wrap thread-parallel IR sections within a horizontal loop.
    This transformation relies on markers placed by :any:`SCCDevectorTransformation`.

    Parameters
    ----------
    horizontal : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the horizontal data dimension and iteration space.
    """

    _reduction_match_pattern = r'reduction\([\+\*\.\w \t]+:[\w\, \t]+\)'

    def __init__(self, horizontal):
        self.horizontal = horizontal

    def mark_vector_reductions(self, routine, section):
        """
        Mark vector-reduction loops in marked vector-reduction
        regions.

        If a region explicitly marked with
        ``!$loki vector-reduction()``/
        ``!$loki end vector-reduction`` is encountered, we replace
        existing ``!$loki loop vector`` loop pragmas and add the
        reduction keyword and clause. These will be turned into
        OpenACC equivalents by :any:`SCCAnnotate`.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine whose pragma regions are attached during the search.
        section : tuple of :any:`Node`
            Code section in which to search for vector-reduction regions.
        """
        with pragma_regions_attached(routine):
            for region in FindNodes(ir.PragmaRegion).visit(section):
                if not is_loki_pragma(region.pragma, starts_with='vector-reduction'):
                    continue
                reduction_clause = re.search(self._reduction_match_pattern, region.pragma.content)
                if not reduction_clause:
                    continue

                loops = FindNodes(ir.Loop).visit(region)
                assert len(loops) == 1, (
                    '[Loki-SCC::Revector] Expected exactly one loop in vector-reduction region, '
                    f'found {len(loops)}'
                )
                pragma = ir.Pragma(keyword='loki', content=f'loop vector {reduction_clause[0]}')
                # Update loop and region in place to remove marker pragmas
                loops[0]._update(pragma=(pragma,))
                region._update(pragma=None, pragma_post=None)

    def mark_seq_loops(self, section):
        """
        Mark interior sequential loops in a thread-parallel section
        with ``!$loki loop seq`` for later annotation.

        Loops over the horizontal index and loops carrying a ``nodep``
        pragma are left unchanged; for all other loops the existing
        pragmas are replaced by ``!$loki loop seq``.

        This utility requires loop-pragmas to be attached via
        :any:`pragmas_attached`. It also updates loops in-place.

        Parameters
        ----------
        section : tuple of :any:`Node`
            Code section in which to mark "seq loops".
        """
        for loop in FindNodes(ir.Loop).visit(section):

            # Skip loops explicitly marked with `!$loki/claw nodep`
            if loop.pragma and any('nodep' in p.content.lower() for p in as_tuple(loop.pragma)):
                continue

            # Mark loop as sequential with `!$loki loop seq`
            if loop.variable != self.horizontal.index:
                loop._update(pragma=(ir.Pragma(keyword='loki', content='loop seq'),))

    def mark_driver_loop(self, routine, loop):
        """
        Add ``!$loki loop driver`` pragmas to outer block loops and
        add a ``vector_length(size)`` clause for later annotations.

        An existing ``!$loki driver-loop`` marker is replaced, keeping any
        clauses that follow the marker keyword. Loops that already carry an
        ``omp``/``acc`` parallel pragma are left unchanged.

        This method assumes that pragmas have been attached via
        :any:`pragmas_attached`.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine containing the loop; used to resolve the horizontal size.
        loop : :any:`Loop`
            The driver loop to mark (updated in-place).
        """

        # Skip loops with existing parallel annotations
        if any(
            pragma.keyword.lower() in ['omp', 'acc'] and 'parallel' in pragma.content.lower()
            for pragma in as_tuple(loop.pragma)
        ):
            return

        # Find a horizontal size variable to mark vector_length
        symbol_map = routine.symbol_map
        sizes = tuple(
            routine.resolve_typebound_var(size, symbol_map) for size in self.horizontal.size_expressions
            if size.split('%')[0] in symbol_map
        )
        vector_length = f' vector_length({sizes[0]})' if sizes else ''

        # Replace existing `!$loki driver-loop` markers, but leave all others
        loop_pragmas = []
        driver_pragma = None
        for p in as_tuple(loop.pragma):
            if is_loki_pragma(p, starts_with='driver-loop'):
                driver_pragma = p
            else:
                loop_pragmas.append(p)
        loop_pragmas = tuple(loop_pragmas)
        driver_content = f'loop driver{vector_length}'
        if driver_pragma is not None:
            driver_content = driver_pragma.content.replace('driver-loop', driver_content)
        driver_pragma = ir.Pragma(keyword='loki', content=driver_content)
        loop._update(pragma=loop_pragmas + (driver_pragma,))


class SCCVecRevectorTransformation(BaseRevectorTransformation):
    """
    A transformation to wrap thread-parallel IR sections within a horizontal loop.

    Thread-parallel sections marked by :any:`SCCDevectorTransformation` are
    re-wrapped in horizontal loops. Kernels are then tagged as vector routines,
    and in drivers the enclosing driver loops are tagged for later annotation.

    Parameters
    ----------
    horizontal : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the horizontal data dimension and iteration space.
    """

    def transform_subroutine(self, routine, **kwargs):
        """
        Wrap vector-parallel sections in vector :any:`Loop` objects.

        For "kernel" routines the whole body is revectorised and the routine
        is tagged ``!$loki routine vector``; for "driver" routines only the
        bodies of the driver loops are revectorised and each such loop is
        tagged as driver loop. The markers placed by
        :any:`SCCDevectorTransformation` are removed.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        role : str
            Must be either ``"kernel"`` or ``"driver"``
        targets : tuple or str
            Tuple of target routine names for determining "driver" loops
        """
        role = kwargs['role']
        targets = kwargs.get('targets', ())

        if role == 'kernel':
            # Kernels explicitly marked `!$loki routine seq` are left untouched
            if check_routine_sequential(routine):
                return

            revector = RevectorSectionTransformer(routine, self.horizontal)
            routine.body = revector.visit(routine.body)

            with pragmas_attached(routine, ir.Loop):
                # Labelled vector-reduction regions and sequential inner loops
                self.mark_vector_reductions(routine, routine.body)
                self.mark_seq_loops(routine.body)

            # Tag the kernel as vector-parallel for later annotation
            routine.spec.append(ir.Pragma(keyword='loki', content='routine vector'))
            return

        if role == 'driver':
            with pragmas_attached(routine, ir.Loop):
                for loop in find_driver_loops(section=routine.body, targets=targets):
                    # Revectorise the marked sections inside this driver loop only
                    revector = RevectorSectionTransformer(routine, self.horizontal)
                    loop._update(body=revector.visit(loop.body))

                    self.mark_vector_reductions(routine, loop.body)
                    self.mark_seq_loops(loop.body)

                    # Finally tag the outer loop itself
                    self.mark_driver_loop(routine, loop)

# Alias for backwards compatibility with code that still uses the name
# `SCCRevectorTransformation`
SCCRevectorTransformation = SCCVecRevectorTransformation

class SCCSeqRevectorTransformation(BaseRevectorTransformation):
    """
    A transformation to wrap thread-parallel IR sections within a horizontal loop
    in a way that the horizontal loop is hoisted/moved to the driver level while
    the horizontal/loop index is passed as an argument.

    Kernels become ``!$loki routine seq`` routines that receive the horizontal
    index as an ``intent(in)`` argument, and their vector-section markers are
    unwrapped without inserting a loop. In drivers, each call to a target
    routine inside a driver loop is wrapped in a horizontal loop and the
    horizontal index is passed to it.

    This transformation relies on markers placed by :any:`SCCDevectorTransformation`.

    Parameters
    ----------
    horizontal : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the horizontal data dimension and iteration space.
    """

    # NOTE: presumably makes the scheduler apply this transformation to
    # ignored items as well; the kernel path below also updates calls into
    # routines listed in ``ignore``.
    process_ignored_items = True

    def remove_vector_sections(self, section):
        """
        Remove all thread-parallel :any:`Section` objects within a given
        code section, keeping their bodies in place.

        Unlike the vector variant, no horizontal loop is inserted here, since
        the horizontal loop is hoisted to the driver.

        Parameters
        ----------
        section : :any:`Node` or tuple of :any:`Node`
            Code section in which to unwrap vector-parallel
            :any:`Section` objects.

        Returns
        -------
        The transformed code section.
        """
        # Replace each `vector_section` marker by its body (i.e. unwrap it)
        mapper = {
            s: s.body
            for s in FindNodes(ir.Section).visit(section)
            if s.label == 'vector_section'
        }
        return Transformer(mapper).visit(section)

    def mark_vector_reductions(self, routine, section):
        """
        Vector reductions are not applicable to sequential routines,
        so we raise an exception if one is found.

        Raises
        ------
        RuntimeError
            If ``section`` contains a ``!$loki vector-reduction`` region whose
            pragma content matches the reduction pattern.
        """

        with pragma_regions_attached(routine):
            for region in FindNodes(ir.PragmaRegion).visit(section):
                if is_loki_pragma(region.pragma, starts_with='vector-reduction'):
                    if re.search(self._reduction_match_pattern, region.pragma.content):
                        raise RuntimeError(f'[Loki::SCCSeq] Vector reduction invalid in seq routine {routine.name}')

    @staticmethod
    def _get_loop_bound(bound, call_arg_map):
        """
        Resolve a horizontal loop bound of the callee in the caller's scope.

        Parameters
        ----------
        bound : str or tuple of str
            Name of the bound variable, or a tuple of alternative names
            (aliases). An alias may be a derived-type component access such
            as ``'bnds%kidia'``.
        call_arg_map : dict
            Map from lower-case dummy argument name to the actual argument
            passed at the call site.

        Returns
        -------
        tuple
            ``(actual_argument, component)``, where ``component`` is the
            name of the derived-type component to access on
            ``actual_argument``, or ``None`` for a plain variable.

        Raises
        ------
        RuntimeError
            If ``bound`` is a tuple and none of its aliases is passed at the call.
        KeyError
            If ``bound`` is a single name that is not passed at the call.
        """
        if isinstance(bound, tuple):
            for alias in bound:
                parts = alias.split('%')
                if parts[0].lower() in call_arg_map:
                    elem = parts[1] if len(parts) > 1 else None
                    return (call_arg_map[parts[0].lower()], elem)
            # Previously this fell through to `bound.lower()` on a tuple and
            # failed with an unhelpful AttributeError
            raise RuntimeError(
                f'[Loki::SCCSeq] None of the horizontal bound aliases {bound} '
                'is passed as an argument to the call'
            )
        return (call_arg_map[bound.lower()], None)

    def _add_horizontal_index_argument(self, routine):
        """
        Make the horizontal index an ``intent(in)`` integer dummy argument of ``routine``.

        If the variable already exists, its declaration is first split off
        (via :any:`single_variable_declaration`) before its intent is set.
        Calling this more than once leaves the argument list unchanged.
        """
        index = self.horizontal.index
        if index not in routine.variables:
            jl = get_integer_variable(routine, index)
            routine.arguments += (jl.clone(type=jl.type.clone(intent='in')),)
        else:
            single_variable_declaration(routine, variables=(index,))
            routine.symbol_attrs.update({index: routine.variable_map[index].type.clone(intent='in')})
            if index not in routine.arguments:
                routine.arguments += (get_integer_variable(routine, index),)

    def _wrap_call_in_horizontal_loop(self, routine, call):
        """
        Return the driver-side replacement for a target ``call``.

        The replacement passes the horizontal index as a keyword argument and
        wraps the call in a horizontal loop. The loop bounds are the actual
        arguments passed for the callee's horizontal bounds.
        """
        index = self.horizontal.index
        new_kwarg = (index, get_integer_variable(routine, index))
        updated_call = call.clone(kwarguments=call.kwarguments + (new_kwarg,))

        # Map callee dummy names to actual arguments, so the callee's bounds
        # can be expressed in the driver's scope
        call_arg_map = {k.name.lower(): v for (k, v) in call.arg_map.items()}

        # Loop bound(s) could be components of derived-type arguments
        lower_arg, lower_comp = self._get_loop_bound(self.horizontal.lower, call_arg_map)
        upper_arg, upper_comp = self._get_loop_bound(self.horizontal.upper, call_arg_map)
        lower = lower_arg if lower_comp is None else sym.Variable(name=lower_comp, parent=lower_arg)
        upper = upper_arg if upper_comp is None else sym.Variable(name=upper_comp, parent=upper_arg)

        return wrap_vector_section((updated_call,), routine, bounds=(lower, upper),
                                   insert_pragma=True, index=index)

    def _transform_kernel(self, routine, targets, ignore):
        """Turn a kernel into a sequential routine that takes the horizontal index as argument."""
        # Skip if kernel is marked as `!$loki routine seq`
        if check_routine_sequential(routine):
            return

        self._add_horizontal_index_argument(routine)

        # Pass the horizontal index on to calls of target and ignored routines
        index = self.horizontal.index
        call_map = {}
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if not (str(call.name).lower() in targets or call.routine.name.lower() in ignore):
                continue
            if check_routine_sequential(call.routine):
                continue
            if index not in call.arg_map:
                new_kwarg = (index, get_integer_variable(routine, index))
                call_map[call] = call.clone(kwarguments=call.kwarguments + (new_kwarg,))
            if call.routine.name.lower() in ignore:
                # Give the ignored callee a matching dummy argument
                self._add_horizontal_index_argument(call.routine)
        routine.body = Transformer(call_map).visit(routine.body)

        # Unwrap the marked vector sections; the horizontal loop lives in the driver
        routine.body = self.remove_vector_sections(routine.body)

        with pragmas_attached(routine, ir.Loop):
            # Vector reductions are rejected in sequential routines
            self.mark_vector_reductions(routine, routine.body)

            # Mark sequential loops inside vector sections
            self.mark_seq_loops(routine.body)

        # Mark subroutine as seq for later annotation
        routine.spec.append(ir.Pragma(keyword='loki', content='routine seq'))

    def _transform_driver(self, routine, targets):
        """Wrap target calls inside driver loops in horizontal loops and tag the driver loops."""
        # Declare the horizontal index, e.g. 'jl', in the driver
        index = self.horizontal.index
        routine.variables += (get_integer_variable(routine, index),)

        with pragmas_attached(routine, ir.Loop):
            driver_loops = find_driver_loops(section=routine.body, targets=targets)

            for loop in driver_loops:
                call_map = {}
                for call in FindNodes(ir.CallStatement).visit(loop.body):
                    if str(call.name).lower() in targets and index not in call.arg_map:
                        call_map[call] = self._wrap_call_in_horizontal_loop(routine, call)
                loop._update(body=Transformer(call_map).visit(loop.body))

                # Revector all marked sections within the driver loop body
                loop._update(body=RevectorSectionTransformer(routine, self.horizontal).visit(loop.body))

                # Use the base-class handling of vector-reduction regions here;
                # this class's own override raises on any vector reduction
                super().mark_vector_reductions(routine, loop.body)

                # Mark sequential loops inside vector sections
                self.mark_seq_loops(loop.body)

                # Mark outer driver loops
                self.mark_driver_loop(routine, loop)

    def transform_subroutine(self, routine, **kwargs):
        """
        Wrap vector-parallel sections in vector :any:`Loop` objects.

        This wraps all thread-parallel sections within "kernel"
        routines or within the parallel loops in "driver" routines.
        The markers placed by :any:`SCCDevectorTransformation` are removed.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine to apply this transformation to.
        role : str
            Must be either ``"kernel"`` or ``"driver"``
        targets : tuple or str
            Tuple of target routine names for determining "driver" loops
        item : :any:`Item`, optional
            Scheduler item; its ``ignore`` list names routines whose calls
            also receive the horizontal index.
        ignore : tuple or str, optional
            Additional ignored routine names, merged with ``item.ignore``.
        """
        role = kwargs['role']
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))

        # Combine the item's ignore list with any explicitly passed names. The
        # previous unparenthesised conditional expression silently dropped the
        # `ignore` kwarg whenever an item was given.
        item = kwargs.get('item', None)
        item_ignore = as_tuple(item.ignore) if item else ()
        ignore = tuple(str(t).lower() for t in item_ignore + as_tuple(kwargs.get('ignore', None)))

        if role == 'kernel':
            self._transform_kernel(routine, targets, ignore)

        if role == 'driver':
            self._transform_driver(routine, targets)
loki-ecmwf-0.3.6/loki/transformations/idempotence.py0000664000175000017500000000122515167130205022776 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation


class IdemTransformation(Transformation):
    """
    A custom transformation that does absolutely nothing!

    This can be used to test simple parse-unparse cycles.
    """

    def transform_subroutine(self, routine, **kwargs):
        pass
loki-ecmwf-0.3.6/loki/transformations/drhook.py0000664000175000017500000000732215167130205021774 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Utility transformations to update or remove calls to DR_HOOK.
"""

from loki.batch import Transformation
from loki.expression import Literal
from loki.ir import (
    FindNodes, Transformer, CallStatement, Conditional, Import
)
from loki.tools import as_tuple


__all__ = ['DrHookTransformation']


def remove_unused_drhook_import(routine):
    """
    Remove unused DRHOOK imports and the corresponding handle.

    All imports from module ``yomhook`` in the routine's spec are dropped,
    and any variable named ``zhook_handle`` is removed from the routine.
    :any:`DrHookTransformation` calls this after removing the ``DR_HOOK``
    calls.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine from which to remove DRHOOK import/handle.
    """
    yomhook_imports = [
        imp for imp in FindNodes(Import).visit(routine.spec)
        if imp.module.lower() == 'yomhook'
    ]
    if yomhook_imports:
        # Map every matching import node to None, i.e. delete it
        routine.spec = Transformer(dict.fromkeys(yomhook_imports)).visit(routine.spec)

    # Drop the DR_HOOK handle variable
    routine.variables = as_tuple(var for var in routine.variables if var != 'zhook_handle')


class DrHookTransformation(Transformation):
    """
    Re-write or remove the DrHook label markers either by appending a
    suffix string or by applying an explicit mapping.

    In addition, calls to DR_HOOK can also be removed, including their
    enclosing inline-conditional.

    Parameters
    ----------
    suffix : str, optional
        String suffix to append to DrHook labels, giving ``<label>_<suffix>``
    rename : dict of str, optional
        Dict with explicit label rename mappings; takes precedence over ``suffix``
    remove : bool
        Flag to explicitly remove calls to ``DR_HOOK``; default: ``False``
    kernel_only : bool
        If set, routines with role ``"driver"`` are left untouched;
        default: ``True``
    """

    # NOTE: presumably also applies the transformation to internal (contained) procedures
    recurse_to_internal_procedures = True

    def __init__(
            self, suffix=None, rename=None, remove=False, kernel_only=True
    ):
        self.suffix = suffix
        self.rename = rename
        self.remove = remove
        self.kernel_only = kernel_only

    def _relabel_call(self, call):
        """
        Return a clone of the ``DR_HOOK`` ``call`` with its label rewritten,
        or ``None`` if the label is to be left unchanged.

        Only literal labels (a first positional argument with a ``value``)
        can be rewritten. Calls whose label is a variable or expression, or
        that have no positional arguments, are skipped instead of raising.
        """
        if not (self.rename or self.suffix) or not call.arguments:
            return None

        label = getattr(call.arguments[0], 'value', None)
        if label is None:
            return None

        if self.rename and label in self.rename:
            # Replace explicitly mapped label directly
            new_label = self.rename[label]
        elif self.suffix:
            # Otherwise append a given suffix
            new_label = f'{label}_{self.suffix}'
        else:
            return None

        new_args = (Literal(value=new_label),) + call.arguments[1:]
        return call.clone(arguments=new_args)

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply transformation to subroutine object

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to transform.
        role : str, optional
            Role of the routine; ``"driver"`` routines are skipped when
            ``kernel_only`` is set.
        """
        role = kwargs.get('role')

        # Leave DR_HOOK annotations in driver routine
        if self.kernel_only and role == 'driver':
            return

        mapper = {}
        for call in FindNodes(CallStatement).visit(routine.body):
            # Lazily changing the DrHook label in-place
            if call.name == 'DR_HOOK':
                if self.remove:
                    mapper[call] = None
                else:
                    new_call = self._relabel_call(call)
                    if new_call is not None:
                        mapper[call] = new_call

        if self.remove:
            # Also drop one-line conditionals guarded by `LHOOK`, e.g. `IF (LHOOK) CALL DR_HOOK(...)`
            for cond in FindNodes(Conditional).visit(routine.body):
                if cond.inline and 'LHOOK' in as_tuple(cond.condition):
                    mapper[cond] = None

        routine.body = Transformer(mapper).visit(routine.body)

        # Get rid of unused import and variable
        if self.remove:
            remove_unused_drhook_import(routine)
loki-ecmwf-0.3.6/loki/transformations/temporaries/0000775000175000017500000000000015167130205022462 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/temporaries/__init__.py0000664000175000017500000000145415167130205024577 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Transformations sub-package that provides various transformations
handling temporaries.
"""

# from loki.transformations.temporaries import * # noqa
from loki.transformations.temporaries.hoist_variables import * # noqa
from loki.transformations.temporaries.pool_allocator import * # noqa
from loki.transformations.temporaries.stack_allocator import * # noqa
from loki.transformations.temporaries.raw_stack_allocator import * # noqa
loki-ecmwf-0.3.6/loki/transformations/temporaries/tests/0000775000175000017500000000000015167130205023624 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/temporaries/tests/test_raw_stack_allocator.py0000664000175000017500000004537615167130205031272 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki.backend import fgen
from loki.batch import Scheduler, SchedulerConfig
from loki.dimension import Dimension
from loki.expression import DeferredTypeSymbol, InlineCall, IntLiteral, ProcedureSymbol
from loki.frontend import available_frontends, OMNI
from loki.ir import FindNodes, CallStatement, Assignment, Pragma
from loki.sourcefile import Sourcefile
from loki.types import BasicType

from loki.transformations.pragma_model import PragmaModelTransformation
from loki.transformations.temporaries.raw_stack_allocator import TemporariesRawStackTransformation


@pytest.fixture(scope='module', name='block_dim')
def fixture_block_dim():
    """Blocking dimension with size ``nb`` and loop index ``b``."""
    block = Dimension(index='b', size='nb', name='block_dim')
    return block

@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """Horizontal dimension with size ``nlon``, index ``jl`` and bounds ``jstart:jend``."""
    loop_bounds = ('jstart', 'jend')
    return Dimension(index='jl', size='nlon', bounds=loop_bounds, name='horizontal')

@pytest.mark.parametrize('directive', ['openacc', 'omp-gpu'])
@pytest.mark.parametrize('frontend', available_frontends())
def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, directive, tmp_path):

    fcode_parkind_mod = """
module parkind1
  implicit none
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, parameter :: jpim = selected_int_kind(9)
  integer, parameter :: jplm = jpim
end module parkind1
    """.strip()

    fcode_yomphy_mod = """
module yomphy
  use parkind1, only: jpim
  implicit none
  type tphy
    integer(kind=jpim) :: n_spband
  end type tphy
end module yomphy
    """.strip()

    fcode_mf_phys_mod = """
module model_physics_mf_mod
  use yomphy, only: tphy
  implicit none
  type model_physics_mf_type
    type(tphy) :: yrphy
  end type model_physics_mf_type
end module model_physics_mf_mod
    """.strip()

    fcode_driver = """
module driver_mod
  contains
  subroutine driver(nlon, klev, nb, ydml_phy_mf)

    use parkind1, only: jpim, jprb

    use model_physics_mf_mod, only: model_physics_mf_type
    use kernel1_mod, only: kernel1

    implicit none

    type(model_physics_mf_type), intent(in) :: ydml_phy_mf

    integer(kind=jpim), intent(in) :: nlon
    integer(kind=jpim), intent(in) :: klev
    integer(kind=jpim), intent(in) :: nb

    integer(kind=jpim) :: jstart
    integer(kind=jpim) :: jend

    integer(kind=jpim) :: b

    real(kind=jprb), dimension(nlon, klev) :: zzz

    jstart = 1
    jend = nlon

    do b = 1, nb

        call kernel1(ydml_phy_mf, nlon, klev, jstart, jend, zzz)

    enddo

  end subroutine driver
end module driver_mod
    """.strip()

    fcode_kernel1 = """
module kernel1_mod
  contains
  subroutine kernel1(ydml_phy_mf, nlon, klev, jstart, jend, pzz)

    use parkind1, only: jpim, jprb

    use model_physics_mf_mod, only: model_physics_mf_type
    use kernel2_mod, only: kernel2
    use kernel3_mod, only: kernel3

    implicit none

    type(model_physics_mf_type), intent(in) :: ydml_phy_mf

    integer(kind=jpim), intent(in) :: nlon
    integer(kind=jpim), intent(in) :: klev

    integer(kind=jpim), intent(in) :: jstart
    integer(kind=jpim), intent(in) :: jend

    real(kind=jprb), intent(in), dimension(nlon, klev) :: pzz

    real(kind=jprb), dimension(nlon, klev) :: zzx
    real(kind=selected_real_kind(13,300)), dimension(nlon, klev) :: zzy
    logical, dimension(nlon, klev) :: zzl

    integer(kind=jpim) :: testint
    integer(kind=jpim) :: jl, jlev

    zzl = .false.
    do jl =1, nlon
      do jlev = 1, klev
        zzx(jl, jlev) = pzz(jl, jlev)
        zzy(jl, jlev) = pzz(jl, jlev)
      enddo
    enddo

    call kernel2(ydml_phy_mf%yrphy, nlon, klev, jstart, jend, testint)
    call kernel3(ydml_phy_mf%yrphy, nlon, klev, jstart, jend, pzz)

  end subroutine kernel1
end module kernel1_mod
    """.strip()

    fcode_kernel2 = """
module kernel2_mod
  contains
  subroutine kernel2(ydphy, nlon, klev, jstart, jend, testint)

      use parkind1, only: jpim, jprb

      use yomphy, only:  tphy

      implicit none

      type(tphy), intent(in) :: ydphy

      integer(kind=jpim), intent(in) :: nlon
      integer(kind=jpim), intent(in) :: klev
      integer(kind=jpim), intent(in) :: jstart
      integer(kind=jpim), intent(in) :: jend
      integer(kind=jpim), optional, intent(in) :: testint

      integer(kind=jpim) :: jb, jlev, jl

      real(kind=jprb) :: zde1(nlon, 0:klev, ydphy%n_spband)
      real(kind=jprb) :: zde2(nlon, klev, ydphy%n_spband)

      do jb = 1, ydphy%n_spband
        do jlev = 1, klev
          do jl = jstart, jend

            zde1(jl, jlev, jb) = 0._jprb
            zde2(jl, jlev, jb) = 0._jprb

          enddo
        enddo
      enddo

  end subroutine kernel2
end module kernel2_mod
    """.strip()

    fcode_kernel3 = """
module kernel3_mod
  contains
  subroutine kernel3(ydphy, nlon, klev, jstart, jend, pzz)

      use parkind1, only: jpim, jprb

      use yomphy, only:  tphy

      implicit none

      type(tphy), intent(in) :: ydphy

      integer(kind=jpim), intent(in) :: nlon
      integer(kind=jpim), intent(in) :: klev
      integer(kind=jpim), intent(in) :: jstart
      integer(kind=jpim), intent(in) :: jend

      real(kind=jprb), intent(in), dimension(nlon, klev) :: pzz

      integer(kind=jpim) :: jb, jlev, jl

      real(kind=jprb) :: zde1(nlon, 0:klev, ydphy%n_spband)
      real(kind=jprb) :: zde2(nlon, klev, ydphy%n_spband)
      real(kind=jprb) :: zde3(nlon, 1:klev)

!$loki device-present vars(pzz)

      do jb = 1, ydphy%n_spband
        zde1(:, 0, jb) = 0._jprb
        zde2(:, :, jb) = 0._jprb
        do jlev = 1, klev
          do jl = jstart, jend

            zde1(jl, jlev, jb) = 1._jprb
            zde2(jl, jlev, jb) = 0._jprb

          enddo
        enddo
      enddo

      zde3 = pzz
      zde3(1:nlon,1:klev) = pzz

!$loki end device-present

  end subroutine kernel3
end module kernel3_mod
    """.strip()

    (tmp_path/'driver.F90').write_text(fcode_driver)
    (tmp_path/'kernel1_mod.F90').write_text(fcode_kernel1)
    (tmp_path/'kernel2_mod.F90').write_text(fcode_kernel2)
    (tmp_path/'kernel3_mod.F90').write_text(fcode_kernel3)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'ignore': ['parkind1', 'model_physics_mf_mod', 'yomphy'],
        },
        'routines': {
            'driver': {'role': 'driver'}
        }
    }

    if frontend == OMNI:
        (tmp_path/'parkind_mod.F90').write_text(fcode_parkind_mod)
        parkind_mod = Sourcefile.from_file(tmp_path/'parkind_mod.F90', frontend=frontend, xmods=[tmp_path])
        (tmp_path/'yomphy_mod.F90').write_text(fcode_yomphy_mod)
        yomphy_mod = Sourcefile.from_file(tmp_path/'yomphy_mod.F90', frontend=frontend, xmods=[tmp_path])
        (tmp_path/'mf_phys_mod.F90').write_text(fcode_mf_phys_mod)
        mf_phys_mod = Sourcefile.from_file(tmp_path/'mf_phys_mod.F90', frontend=frontend, xmods=[tmp_path])
        definitions = parkind_mod.definitions + yomphy_mod.definitions + mf_phys_mod.definitions
    else:
        definitions = ()

    scheduler = Scheduler(paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend,
                          definitions=definitions, xmods=[tmp_path])

    transformation = TemporariesRawStackTransformation(block_dim=block_dim, horizontal=horizontal)
    scheduler.process(transformation=transformation)
    pragma_model_trafo = PragmaModelTransformation(directive=directive)
    scheduler.process(transformation=pragma_model_trafo)

    driver_item  = scheduler['driver_mod#driver']
    kernel1_item = scheduler['kernel1_mod#kernel1']
    kernel2_item = scheduler['kernel2_mod#kernel2']
    kernel3_item = scheduler['kernel3_mod#kernel3']

    assert transformation._key in kernel1_item.trafo_data

    jprb_stack_size = 'MAX(klev + ydml_phy_mf%yrphy%n_spband + 2*klev*ydml_phy_mf%yrphy%n_spband, '\
                        '2*klev + ydml_phy_mf%yrphy%n_spband + 2*klev*ydml_phy_mf%yrphy%n_spband)'
    srk_stack_size = 'MAX(2*klev + ydml_phy_mf%yrphy%n_spband + 2*klev*ydml_phy_mf%yrphy%n_spband, '\
                         '3*klev + ydml_phy_mf%yrphy%n_spband + 2*klev*ydml_phy_mf%yrphy%n_spband)'
    klev_stack_size = 'klev'

    real = BasicType.REAL
    logical = BasicType.LOGICAL
    jprb = DeferredTypeSymbol('JPRB')
    srk = InlineCall(function = ProcedureSymbol(name = 'SELECTED_REAL_KIND'),
                     parameters = (IntLiteral(13), IntLiteral(300)))

    stack_dict = kernel1_item.trafo_data[transformation._key]['stack_dict']

    assert real in stack_dict

    if frontend == OMNI:
        assert srk in stack_dict[real]
        assert fgen(stack_dict[real][srk]) == srk_stack_size
    else:
        assert jprb in stack_dict[real]
        assert fgen(stack_dict[real][jprb]) == jprb_stack_size
        assert srk in stack_dict[real]
        assert fgen(stack_dict[real][srk]) == klev_stack_size

    assert logical in stack_dict
    assert None in stack_dict[logical]
    assert fgen(stack_dict[logical][None]) == klev_stack_size

    driver = driver_item.ir
    kernel1 = kernel1_item.ir
    kernel2 = kernel2_item.ir
    kernel3 = kernel3_item.ir

    assert 'j_ll_stack_size' in driver.variable_map
    assert 'll_stack' in driver.variable_map

    assert 'j_z_selected_real_kind_13_300_stack_size' in driver.variable_map
    assert 'z_selected_real_kind_13_300_stack' in driver.variable_map

    if not frontend == OMNI:
        assert 'j_z_jprb_stack_size' in driver.variable_map
        assert 'z_jprb_stack' in driver.variable_map

    assert 'j_p_selected_real_kind_13_300_stack_used' in kernel1.variable_map
    assert 'k_p_selected_real_kind_13_300_stack_size' in kernel1.variable_map
    assert 'p_selected_real_kind_13_300_stack' in kernel1.variable_map

    assert 'j_ld_stack_used' in kernel1.variable_map
    assert 'k_ld_stack_size' in kernel1.variable_map
    assert 'ld_stack' in kernel1.variable_map

    if not frontend == OMNI:
        assert 'j_p_jprb_stack_used' in kernel1.variable_map
        assert 'k_p_jprb_stack_size' in kernel1.variable_map
        assert 'p_jprb_stack' in kernel1.variable_map

    assert 'jd_zzx' in kernel1.variable_map
    assert 'jd_zzy' in kernel1.variable_map
    assert 'jd_zzl' in kernel1.variable_map

    calls = FindNodes(CallStatement).visit(driver.body)

    if frontend == OMNI:
        assert fgen(calls[0].arguments).lower() == 'ydml_phy_mf\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'zzz\n'\
        'j_z_selected_real_kind_13_300_stack_size\n'\
        'z_selected_real_kind_13_300_stack(:, :, b)\n'\
        'j_ll_stack_size\n'\
        'll_stack(:, :, b)'
    else:
        assert fgen(calls[0].arguments).lower() == 'ydml_phy_mf\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'zzz\n'\
        'j_z_jprb_stack_size\n'\
        'z_jprb_stack(:, :, b)\n'\
        'j_z_selected_real_kind_13_300_stack_size\n'\
        'z_selected_real_kind_13_300_stack(:, :, b)\n'\
        'j_ll_stack_size\n'\
        'll_stack(:, :, b)'

    if frontend == OMNI:
        assert fgen(kernel1.arguments).lower() == 'ydml_phy_mf\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'pzz(nlon, klev)\n'\
        'k_p_selected_real_kind_13_300_stack_size\n'\
        'p_selected_real_kind_13_300_stack(nlon, k_p_selected_real_kind_13_300_stack_size)\n'\
        'k_ld_stack_size\n'\
        'ld_stack(nlon, k_ld_stack_size)'
    else:
        assert fgen(kernel1.arguments).lower() == 'ydml_phy_mf\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'pzz(nlon, klev)\n'\
        'k_p_jprb_stack_size\n'\
        'p_jprb_stack(nlon, k_p_jprb_stack_size)\n'\
        'k_p_selected_real_kind_13_300_stack_size\n'\
        'p_selected_real_kind_13_300_stack(nlon, k_p_selected_real_kind_13_300_stack_size)\n'\
        'k_ld_stack_size\n'\
        'ld_stack(nlon, k_ld_stack_size)'

    calls = FindNodes(CallStatement).visit(kernel1.body)

    if frontend == OMNI:
        assert fgen(calls[0].arguments).lower() == 'ydml_phy_mf%yrphy\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'k_p_selected_real_kind_13_300_stack_size - j_p_selected_real_kind_13_300_stack_used\n'\
        'p_selected_real_kind_13_300_stack'\
        '(1:nlon, j_p_selected_real_kind_13_300_stack_used + 1:k_p_selected_real_kind_13_300_stack_size)\n'\
        'testint'
    else:
        assert fgen(calls[0].arguments).lower() == 'ydml_phy_mf%yrphy\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'k_p_jprb_stack_size - j_p_jprb_stack_used\n'\
        'p_jprb_stack(1:nlon, j_p_jprb_stack_used + 1:k_p_jprb_stack_size)\n'\
        'testint'

    if frontend == OMNI:
        assert fgen(kernel2.arguments).lower() == 'ydphy\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'k_p_selected_real_kind_13_300_stack_size\n'\
        'p_selected_real_kind_13_300_stack(nlon, k_p_selected_real_kind_13_300_stack_size)\n'\
        'testint'
    else:
        assert fgen(kernel2.arguments).lower() == 'ydphy\n'\
        'nlon\n'\
        'klev\n'\
        'jstart\n'\
        'jend\n'\
        'k_p_jprb_stack_size\n'\
        'p_jprb_stack(nlon, k_p_jprb_stack_size)\n'\
        'testint'

    assignments = FindNodes(Assignment).visit(driver.body)

    lhs = [fgen(a.lhs).lower() for a in assignments]

    assert 'j_z_selected_real_kind_13_300_stack_size' in lhs
    assert 'j_ll_stack_size' in lhs
    if not frontend == OMNI:
        assert 'j_z_jprb_stack_size' in lhs

    for a in assignments:

        if fgen(a.lhs).lower() == 'j_z_selected_real_kind_13_300_stack_size':
            if frontend == OMNI:
                assert fgen(a.rhs).lower() == srk_stack_size.lower()
            else:
                assert fgen(a.rhs).lower() == klev_stack_size.lower()

        if fgen(a.lhs).lower() == 'j_ll_stack_size':
            assert fgen(a.rhs).lower() == klev_stack_size.lower()

        if fgen(a.lhs).lower() == 'j_z_jprb_stack_size':
            assert fgen(a.rhs).lower() == jprb_stack_size.lower()

    assignments = FindNodes(Assignment).visit(kernel3.body)

    assert assignments[0].lhs == 'jd_zde1'
    assert assignments[0].rhs == '0'

    assert assignments[1].lhs == 'jd_zde2'
    assert assignments[1].rhs == 'jd_zde1 + ydphy%n_spband + klev*ydphy%n_spband'

    assert assignments[2].lhs == 'jd_zde3'
    assert assignments[2].rhs == 'jd_zde2 + klev*ydphy%n_spband'

    if frontend == OMNI:
        assert assignments[3].lhs == 'j_p_selected_real_kind_13_300_stack_used'
        assert assignments[3].rhs == 'jd_zde3 + klev'

        assert assignments[4].lhs == 'p_selected_real_kind_13_300_stack(:, jd_zde1 + jb - klev + jb*klev)'
        assert fgen(assignments[4].rhs) == '0._jprb'  # Need fgen for kind specified

        assert assignments[5].lhs == 'p_selected_real_kind_13_300_stack'\
            '(:, jd_zde2 + 1 - klev + jb*klev:jd_zde2 + jb*klev)'
        assert fgen(assignments[5].rhs) == '0._jprb'

        assert assignments[6].lhs == 'p_selected_real_kind_13_300_stack'\
            '(jl, jd_zde1 + jlev + jb - klev + jb*klev)'
        assert fgen(assignments[6].rhs) == '1._jprb'

        assert assignments[7].lhs == 'p_selected_real_kind_13_300_stack'\
            '(jl, jd_zde2 + jlev - klev + jb*klev)'
        assert fgen(assignments[7].rhs) == '0._jprb'

        assert assignments[8].lhs == 'p_selected_real_kind_13_300_stack'\
            '(1:nlon, jd_zde3 + 1:jd_zde3 + klev)'
        assert assignments[8].rhs == 'pzz'

        assert assignments[9].lhs == 'p_selected_real_kind_13_300_stack'\
            '(1:nlon, jd_zde3 + 1:jd_zde3 + klev)'
        assert assignments[9].rhs == 'pzz'
    else:
        assert assignments[3].lhs == 'j_p_jprb_stack_used'
        assert assignments[3].rhs == 'jd_zde3 + klev'

        assert assignments[4].lhs == 'p_jprb_stack(:, jd_zde1 + jb - klev + jb*klev)'
        assert fgen(assignments[4].rhs) == '0._jprb'  # Need fgen for kind specified

        assert assignments[5].lhs == 'p_jprb_stack(:, jd_zde2 + 1 - klev + jb*klev:jd_zde2 + jb*klev)'
        assert fgen(assignments[5].rhs) == '0._jprb'

        assert assignments[6].lhs == 'p_jprb_stack(jl, jd_zde1 + jlev + jb - klev + jb*klev)'
        assert fgen(assignments[6].rhs) == '1._jprb'

        assert assignments[7].lhs == 'p_jprb_stack(jl, jd_zde2 + jlev - klev + jb*klev)'
        assert fgen(assignments[7].rhs) == '0._jprb'

        assert assignments[8].lhs == 'p_jprb_stack(1:nlon, jd_zde3 + 1:jd_zde3 + klev)'
        assert assignments[8].rhs == 'pzz'

        assert assignments[9].lhs == 'p_jprb_stack(1:nlon, jd_zde3 + 1:jd_zde3 + klev)'
        assert assignments[9].rhs == 'pzz'

    if directive in ['openacc', 'omp-gpu']:
        pragmas = FindNodes(Pragma).visit(driver.body)

        if directive == 'openacc':
            if frontend == OMNI:
                assert pragmas[0].content.lower() == 'data create(z_selected_real_kind_13_300_stack, ll_stack)'
            else:
                assert pragmas[0].content.lower() == 'data create(z_jprb_stack, '\
                                                     'z_selected_real_kind_13_300_stack, ll_stack)'

        if directive == 'openmp':
            if frontend == OMNI:
                assert pragmas[0].content.lower() == 'target allocate(z_selected_real_kind_13_300_stack, ll_stack)'
            else:
                assert pragmas[0].content.lower() == 'target allocate(z_jprb_stack, '\
                                                     'z_selected_real_kind_13_300_stack, ll_stack)'

        if directive == 'omp-gpu':
            if frontend == OMNI:
                assert pragmas[0].content.lower() == \
                        'target data map(alloc: z_selected_real_kind_13_300_stack, ll_stack)'
            else:
                assert pragmas[0].content.lower() == 'target data map(alloc: z_jprb_stack, '\
                                                     'z_selected_real_kind_13_300_stack, ll_stack)'

    if directive == 'openacc':
        pragmas = FindNodes(Pragma).visit(kernel1.body)
        if frontend == OMNI:
            assert pragmas[0].content.lower() == 'data present(p_selected_real_kind_13_300_stack, ld_stack)'
        else:
            assert pragmas[0].content.lower() == 'data present(p_jprb_stack, '\
                                                 'p_selected_real_kind_13_300_stack, ld_stack)'

        pragmas = FindNodes(Pragma).visit(kernel3.body)
        if frontend == OMNI:
            assert pragmas[0].content.lower() == 'data present(p_selected_real_kind_13_300_stack, pzz)'
        else:
            assert pragmas[0].content.lower() == 'data present(p_jprb_stack, pzz)'
loki-ecmwf-0.3.6/loki/transformations/temporaries/tests/__init__.py0000664000175000017500000000057015167130205025737 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/temporaries/tests/test_stack_allocator.py0000664000175000017500000003526515167130205030415 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki.batch import Scheduler, SchedulerConfig
from loki.dimension import Dimension
from loki.expression import parse_expr
from loki.frontend import available_frontends, OMNI
from loki.ir import FindNodes, nodes as ir
from loki.sourcefile import Sourcefile
from loki.transformations.pragma_model import PragmaModelTransformation

from loki.transformations.temporaries import FtrPtrStackTransformation, DirectIdxStackTransformation

@pytest.fixture(scope='module', name='block_dim')
def fixture_block_dim():
    """Blocking dimension of size ``nb`` with loop index ``b`` (as used by the test driver)."""
    block = Dimension(name='block_dim', index='b', size='nb')
    return block

@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """Horizontal dimension of size ``nlon``, index ``jl`` and loop bounds ``jstart``/``jend``."""
    loop_bounds = ('jstart', 'jend')
    return Dimension(name='horizontal', size='nlon', index='jl', bounds=loop_bounds)

@pytest.mark.parametrize('directive', ['openacc', 'omp-gpu'])
@pytest.mark.parametrize('stack_trafo', [FtrPtrStackTransformation, DirectIdxStackTransformation])
@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'Inlines kind parameters.')]))
@pytest.mark.parametrize('stack_insert_loc_pragma', [True, False])
def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, directive, stack_trafo,
                                         tmp_path, stack_insert_loc_pragma):
    """
    Apply a stack allocator transformation (``stack_trafo``) followed by
    :class:`PragmaModelTransformation` to a ``driver -> kernel1 -> (kernel2, kernel3)``
    call tree and check the generated stacks.

    The driver must declare, allocate, offload and pass one stack per kind
    (``jprb``, ``selected_real_kind(13,300)`` and logical). The kernels must receive
    only the stack arguments they need. ``kernel1`` must initialise its stack-used and
    increment variables from the incoming arguments.
    """

    fcode_parkind_mod = """
module parkind1
  implicit none
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, parameter :: jpim = selected_int_kind(9)
  integer, parameter :: jplm = jpim
end module parkind1
    """.strip()

    fcode_yomphy_mod = """
module yomphy
  use parkind1, only: jpim
  implicit none
  type tphy
    integer(kind=jpim) :: n_spband
  end type tphy
end module yomphy
    """.strip()

    fcode_mf_phys_mod = """
module model_physics_mf_mod
  use yomphy, only: tphy
  implicit none
  integer, parameter :: tlen = 2
  type model_physics_mf_type
    type(tphy) :: yrphy
  end type model_physics_mf_type
end module model_physics_mf_mod
    """.strip()

    fcode_driver = f"""
module driver_mod
  contains
  subroutine driver(nlon, klev, nb, ydml_phy_mf)

    use parkind1, only: jpim, jprb

    use model_physics_mf_mod, only: model_physics_mf_type
    use kernel1_mod, only: kernel1

    implicit none

    type(model_physics_mf_type), intent(in) :: ydml_phy_mf

    integer(kind=jpim), intent(in) :: nlon
    integer(kind=jpim), intent(in) :: klev
    integer(kind=jpim), intent(in) :: nb

    integer(kind=jpim) :: jstart
    integer(kind=jpim) :: jend

    integer(kind=jpim) :: b

    real(kind=jprb), dimension(nlon, klev) :: zzz


    !$loki sep

    {'!$loki stack-insert' if stack_insert_loc_pragma else ''}

    jstart = 1
    jend = nlon

    do b = 1, nb

        call kernel1(ydml_phy_mf, nlon, klev, jstart, jend, zzz)

    enddo

  end subroutine driver
end module driver_mod
    """.strip()

    fcode_kernel1 = """
module kernel1_mod
  contains
  subroutine kernel1(ydml_phy_mf, nlon, klev, jstart, jend, pzz)

    use parkind1, only: jpim, jprb

    use model_physics_mf_mod, only: model_physics_mf_type, tlen
    use kernel2_mod, only: kernel2
    use kernel3_mod, only: kernel3

    implicit none

    type(model_physics_mf_type), intent(in) :: ydml_phy_mf

    integer(kind=jpim), intent(in) :: nlon
    integer(kind=jpim), intent(in) :: klev

    integer(kind=jpim), intent(in) :: jstart
    integer(kind=jpim), intent(in) :: jend

    real(kind=jprb), intent(in), dimension(nlon, klev) :: pzz

    real(kind=jprb), dimension(nlon, klev) :: zzx
    real(kind=selected_real_kind(13,300)), dimension(nlon, klev) :: zzy
    logical, dimension(nlon, klev) :: zzl
    logical, dimension(nlon, tlen) :: zzl2

    integer(kind=jpim) :: testint
    integer(kind=jpim) :: jl, jlev

    zzl = .false.
    zzl2 = .true.
    do jl =1, nlon
      do jlev = 1, klev
        zzx(jl, jlev) = pzz(jl, jlev)
        zzy(jl, jlev) = pzz(jl, jlev)
      enddo
    enddo

    call kernel2(ydml_phy_mf%yrphy, nlon, klev, jstart, jend, testint)
    call kernel3(ydml_phy_mf%yrphy, nlon, klev, jstart, jend, pzz)

  end subroutine kernel1
end module kernel1_mod
    """.strip()

    fcode_kernel2 = """
module kernel2_mod
  contains
  subroutine kernel2(ydphy, nlon, klev, jstart, jend, testint)

      use parkind1, only: jpim, jprb

      use yomphy, only:  tphy

      implicit none

      type(tphy), intent(in) :: ydphy

      integer(kind=jpim), intent(in) :: nlon
      integer(kind=jpim), intent(in) :: klev
      integer(kind=jpim), intent(in) :: jstart
      integer(kind=jpim), intent(in) :: jend
      integer(kind=jpim), optional, intent(in) :: testint

      integer(kind=jpim) :: jb, jlev, jl

      real(kind=jprb) :: zde1(nlon, 0:klev, ydphy%n_spband)
      real(kind=jprb) :: zde2(nlon, klev, ydphy%n_spband)

      do jb = 1, ydphy%n_spband
        do jlev = 1, klev
          do jl = jstart, jend

            zde1(jl, jlev, jb) = 0._jprb
            zde2(jl, jlev, jb) = 0._jprb

          enddo
        enddo
      enddo

  end subroutine kernel2
end module kernel2_mod
    """.strip()

    fcode_kernel3 = """
module kernel3_mod
  contains
  subroutine kernel3(ydphy, nlon, klev, jstart, jend, pzz)

      use parkind1, only: jpim, jprb

      use yomphy, only:  tphy

      implicit none

      type(tphy), intent(in) :: ydphy

      integer(kind=jpim), intent(in) :: nlon
      integer(kind=jpim), intent(in) :: klev
      integer(kind=jpim), intent(in) :: jstart
      integer(kind=jpim), intent(in) :: jend

      real(kind=jprb), intent(in), dimension(nlon, klev) :: pzz

      integer(kind=jpim) :: jb, jlev, jl

      real(kind=jprb) :: zde1(nlon, 0:klev, ydphy%n_spband)
      real(kind=jprb) :: zde2(nlon, klev, ydphy%n_spband)
      real(kind=jprb) :: zde3(nlon, 1:klev)

!$loki device-present vars(pzz)

      do jb = 1, ydphy%n_spband
        zde1(:, 0, jb) = 0._jprb
        zde2(:, :, jb) = 0._jprb
        do jlev = 1, klev
          do jl = jstart, jend

            zde1(jl, jlev, jb) = 1._jprb
            zde2(jl, jlev, jb) = 0._jprb

          enddo
        enddo
      enddo

      zde3 = pzz
      zde3(1:nlon,1:klev) = pzz

!$loki end device-present

  end subroutine kernel3
end module kernel3_mod
    """.strip()

    (tmp_path/'driver.F90').write_text(fcode_driver)
    (tmp_path/'kernel1_mod.F90').write_text(fcode_kernel1)
    (tmp_path/'kernel2_mod.F90').write_text(fcode_kernel2)
    (tmp_path/'kernel3_mod.F90').write_text(fcode_kernel3)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'ignore': ['parkind1', 'model_physics_mf_mod', 'yomphy'],
        },
        'routines': {
            'driver': {'role': 'driver'}
        }
    }

    # The ignored modules are parsed up front and handed to the scheduler as definitions
    (tmp_path/'parkind_mod.F90').write_text(fcode_parkind_mod)
    parkind_mod = Sourcefile.from_file(tmp_path/'parkind_mod.F90', frontend=frontend, xmods=[tmp_path])
    (tmp_path/'yomphy_mod.F90').write_text(fcode_yomphy_mod)
    yomphy_mod = Sourcefile.from_file(tmp_path/'yomphy_mod.F90', frontend=frontend, xmods=[tmp_path])
    (tmp_path/'mf_phys_mod.F90').write_text(fcode_mf_phys_mod)
    mf_phys_mod = Sourcefile.from_file(tmp_path/'mf_phys_mod.F90', frontend=frontend, xmods=[tmp_path])
    definitions = parkind_mod.definitions + yomphy_mod.definitions + mf_phys_mod.definitions

    scheduler = Scheduler(paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend,
                          definitions=definitions, xmods=[tmp_path])

    transformation = stack_trafo(block_dim=block_dim, horizontal=horizontal)
    scheduler.process(transformation=transformation)
    pragma_model_trafo = PragmaModelTransformation(directive=directive)
    scheduler.process(transformation=pragma_model_trafo)

    driver_item  = scheduler['driver_mod#driver']
    kernel1_item = scheduler['kernel1_mod#kernel1']
    kernel2_item = scheduler['kernel2_mod#kernel2']
    kernel3_item = scheduler['kernel3_mod#kernel3']

    driver = driver_item.ir
    kernel1 = kernel1_item.ir
    kernel2 = kernel2_item.ir
    kernel3 = kernel3_item.ir

    directive_keyword_map = {'openacc': 'acc', 'omp-gpu': 'omp'}
    driver_var_map = driver.variable_map
    stack_size_vars = ['j_z_jprb_stack_size', 'j_z_selected_real_kind_13_300_stack_size', 'j_ll_stack_size']
    stack_used_vars = ['j_z_jprb_stack_used', 'j_z_selected_real_kind_13_300_stack_used', 'j_ll_stack_used']
    stack_vars = ['z_jprb_stack', 'll_stack', 'z_selected_real_kind_13_300_stack']
    # Expected first extent of each stack allocation (the second extent is ``nb``)
    stack_vars_size = {
        'z_jprb_stack': ('MAX(klev*nlon + nlon*ydml_phy_mf%yrphy%n_spband + '
                         '2*klev*nlon*ydml_phy_mf%yrphy%n_spband, 2*klev*nlon + '
                         'nlon*ydml_phy_mf%yrphy%n_spband + 2*klev*nlon*ydml_phy_mf%yrphy%n_spband)'),
        'll_stack': 'klev*nlon + nlon*tlen',
        'z_selected_real_kind_13_300_stack': 'klev*nlon'
    }

    # Driver: one rank-2 allocatable stack per kind
    for stack_var in stack_vars:
        assert stack_var in driver_var_map
        assert driver_var_map[stack_var].type.allocatable
        assert len(driver_var_map[stack_var].dimensions) == 2

    # ``tlen`` appears in the logical stack size, so it must be imported into the driver
    driver_imports = FindNodes(ir.Import).visit(driver.spec)
    driver_imported_symbols = []
    for _import in driver_imports:
        driver_imported_symbols.extend([sym.name.lower() for sym in _import.symbols])
    assert 'tlen' in driver_imported_symbols

    # Every stack must be allocated, with shape (<stack size>, nb). Checking the set of
    # allocated names keeps this from passing vacuously when no ALLOCATE is generated.
    driver_allocs = FindNodes(ir.Allocation).visit(driver.body)
    allocated_vars = [var for driver_alloc in driver_allocs for var in driver_alloc.variables]
    assert {var.name.lower() for var in allocated_vars} == set(stack_vars)
    for var in allocated_vars:
        assert var.dimensions[1] == 'nb'
        assert var.dimensions[0] == parse_expr(stack_vars_size[var.name.lower()])

    # Every stack must also be deallocated, referenced without dimensions
    driver_deallocs = FindNodes(ir.Deallocation).visit(driver.body)
    deallocated_vars = [var for driver_dealloc in driver_deallocs for var in driver_dealloc.variables]
    assert {var.name.lower() for var in deallocated_vars} == set(stack_vars)
    for var in deallocated_vars:
        assert var.dimensions == ()

    # The enter-data pragma is the first driver pragma, or the second one when a
    # ``!$loki stack-insert`` location is requested. The exit-data pragma is always at index 2.
    driver_pragmas = FindNodes(ir.Pragma).visit(driver.body)
    relevant_pragma = 0 if not stack_insert_loc_pragma else 1
    assert len(driver_pragmas) == 3
    assert driver_pragmas[relevant_pragma].keyword.lower() == directive_keyword_map[directive]
    assert driver_pragmas[2].keyword.lower() == directive_keyword_map[directive]
    if directive == 'openacc':
        assert 'enter data create' in driver_pragmas[relevant_pragma].content.lower()
        assert 'exit data delete' in driver_pragmas[2].content.lower()
    if directive == 'omp-gpu':
        # e.g. target enter data map(alloc: z_jprb_stack, z_selected_real_kind_13_300_stack, ll_stack)
        assert 'target enter data map(alloc:' in driver_pragmas[relevant_pragma].content.lower()
        assert 'target exit data map(delete:' in driver_pragmas[2].content.lower()
    for stack_var in stack_vars:
        assert stack_var in driver_pragmas[relevant_pragma].content.lower()
        assert stack_var in driver_pragmas[2].content.lower()

    # The single kernel1 call receives every stack plus its size and used-counter variables
    driver_calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(driver_calls) == 1
    driver_arg_map = {v.name.lower(): k for k,v in driver_calls[0].arg_map.items()}
    for stack_var in stack_vars:
        assert stack_var in driver_arg_map
    for stack_size_var in stack_size_vars:
        assert stack_size_var in driver_arg_map
    for stack_used_var in stack_used_vars:
        assert stack_used_var in driver_arg_map

    jprb_stack = {'size': 'k_p_jprb_stack_size', 'stack': 'p_jprb_stack', 'used': 'jd_p_jprb_stack_used'}
    selected_real_kind_stack = {'size': 'k_p_selected_real_kind_13_300_stack_size',
                                'stack': 'p_selected_real_kind_13_300_stack',
                                'used': 'jd_p_selected_real_kind_13_300_stack_used'}
    l_stack = {'size': 'k_ld_stack_size', 'stack': 'ld_stack', 'used': 'jd_ld_stack_used'}

    # kernel1 needs all three stacks; kernel2/kernel3 only use jprb temporaries
    kernel1_args = [arg.name.lower() for arg in kernel1.arguments]
    kernel2_args = [arg.name.lower() for arg in kernel2.arguments]
    kernel3_args = [arg.name.lower() for arg in kernel3.arguments]
    for var in list(jprb_stack.values()) + list(selected_real_kind_stack.values()) + list(l_stack.values()):
        assert var in kernel1_args
    for var in jprb_stack.values():
        assert var in kernel2_args
        assert var in kernel3_args
    for var in list(selected_real_kind_stack.values()) + list(l_stack.values()):
        assert var not in kernel2_args
        assert var not in kernel3_args

    kernel1_pragmas = FindNodes(ir.Pragma).visit(kernel1.body)
    assert len(kernel1_pragmas) == 2
    if directive == 'openacc':
        assert kernel1_pragmas[0].keyword == 'acc'
        assert 'data present(' in kernel1_pragmas[0].content
    else:
        assert kernel1_pragmas[0].keyword == 'loki'
        assert 'device-present vars(' in kernel1_pragmas[0].content
        assert 'p_jprb_stack, p_selected_real_kind_13_300_stack, ld_stack' in kernel1_pragmas[0].content.lower()

    # Local increment/offset variables differ between the two stack transformations
    if stack_trafo == FtrPtrStackTransformation:
        kernel1_incr_vars = {'jprb': ('jd_incr_jprb',),
                             'selected_real': ('jd_incr_selected_real_kind_13_300',), 'l': ('jd_incr',)}
        kernel2_incr_vars = {'jprb': ('jd_incr_jprb',)}
        kernel3_incr_vars = {'jprb': ('jd_incr_jprb',)}
    else: # DirectIdxStackTransformation
        kernel1_incr_vars = {'jprb': ('jd_zzx',), 'selected_real': ('jd_zzy',), 'l': ('jd_zzl',)}
        kernel2_incr_vars = {'jprb': ('jd_zde1', 'jd_zde2')}
        kernel3_incr_vars = {'jprb': ('jd_zde1', 'jd_zde2', 'jd_zde3')}
    for kernel1_incr_var in kernel1_incr_vars['jprb'] + kernel1_incr_vars['selected_real'] + kernel1_incr_vars['l']:
        assert kernel1_incr_var in kernel1.variables
    for var in kernel2_incr_vars['jprb']:
        assert var in kernel2.variables
    for var in kernel3_incr_vars['jprb']:
        assert var in kernel3.variables

    # In kernel1, the first assignment to each local "stack used" variable copies the
    # incoming argument, and the first increment variable is then set from it
    kernels_stack_used_args = {'jprb': 'jd_p_jprb_stack_used',
                               'selected_real': 'jd_p_selected_real_kind_13_300_stack_used',
                               'l': 'jd_ld_stack_used'}
    kernels_stack_used_vars = {'jprb': 'j_p_jprb_stack_used',
                               'selected_real': 'j_p_selected_real_kind_13_300_stack_used',
                               'l': 'j_ld_stack_used'}
    kernel1_assignments_map = {}
    for assign in FindNodes(ir.Assignment).visit(kernel1.body):
        kernel1_assignments_map.setdefault(assign.lhs.name.lower(), []).append(assign.rhs)
    for key, kernels_stack_used_var in kernels_stack_used_vars.items():
        assert kernel1_assignments_map[kernels_stack_used_var][0] == kernels_stack_used_args[key]
        assert kernel1_assignments_map[kernel1_incr_vars[key][0]][0] == kernels_stack_used_var
loki-ecmwf-0.3.6/loki/transformations/temporaries/tests/test_hoist_variables.py0000664000175000017500000007601215167130205030421 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
A selection of tests for the (generic) hoist variables functionalities.
"""
from pathlib import Path
import pytest
import numpy as np

from loki import (
    Scheduler, SchedulerConfig, is_iterable, FindInlineCalls
)
from loki.jit_build import jit_compile_lib, Builder
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
from loki.transformations.temporaries.hoist_variables import (
    HoistVariablesAnalysis, HoistVariablesTransformation,
    HoistTemporaryArraysAnalysis, HoistTemporaryArraysTransformationAllocatable
)


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory containing this test module."""
    this_module = Path(__file__)
    return this_module.parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """
    The ``tests`` directory located three levels above this module's directory.
    """
    # here -> temporaries/tests, so three parents up is the package root
    package_root = here.parent.parent.parent
    return package_root / 'tests'


@pytest.fixture(name='config')
def fixture_config():
    """
    Default configuration dict with basic options.

    By default every routine is an expanded, strict kernel in ``idem`` mode. The four
    driver routines are overridden to have the ``driver`` role (still expanded).
    """
    driver_names = ('driver', 'another_driver', 'yet_another_driver', 'inline_driver')
    default_options = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': True,
    }
    # A fresh dict per driver so that tests mutating one entry do not affect the others
    driver_routines = {name: {'role': 'driver', 'expand': True} for name in driver_names}
    return {'default': default_options, 'routines': driver_routines}


def compile_and_test(scheduler, path, a=(5,), frontend="",  test_name="", items=None, inline=False):
    """
    Compile the source code and call the driver function in order to test the results for correctness.

    Parameters
    ----------
    scheduler : :any:`Scheduler`
        Scheduler holding the items; only used to look up the default ``items``.
    path : str or :any:`pathlib.Path`
        Directory the item sources are redirected to and where the library is built.
    a : tuple of int
        Problem sizes; the driver is called once per value.
    frontend : optional
        Frontend identifier; only used to make the library name unique.
    test_name : str
        Name fragment used to make the library name unique.
    items : list, optional
        Scheduler items whose sources are compiled. Defaults to
        ``transformation_module_hoist#driver`` and ``subroutines_mod#kernel1``.
    inline : bool
        Call ``Transformation_Module_Hoist_Inline.inline_driver`` instead of
        ``Transformation_Module_Hoist.driver``.

    The output arrays ``b`` and ``c`` must be filled with 42 and 11 respectively.
    The build directory is cleaned afterwards, also when compilation or a check fails.
    """
    assert is_iterable(a) and all(isinstance(_a, int) for _a in a)
    path = Path(path)
    if not items:
        items = [scheduler["transformation_module_hoist#driver"], scheduler["subroutines_mod#kernel1"]]

    # Point each item's source file into ``path`` (keeping its stem) before building
    suffix = '.F90'
    for item in items:
        item.source.path = (path/f"{item.source.path.stem}").with_suffix(suffix=suffix)

    libname = f'lib_{test_name}_{frontend}'
    builder = Builder(source_dirs=path, build_dir=path)
    # Second extent of ``c``; presumably matches a parameter in the Fortran test sources (TODO confirm)
    parameter_length = 3
    try:
        lib = jit_compile_lib([item.source for item in items], path=path, name=libname, builder=builder)
        for _a in a:
            b = np.zeros((_a,), dtype=np.int32, order='F')
            c = np.zeros((_a, parameter_length), dtype=np.int32, order='F')
            if inline:
                lib.Transformation_Module_Hoist_Inline.inline_driver(_a, b, c)
            else:
                lib.Transformation_Module_Hoist.driver(_a, b, c)
            assert (b == 42).all()
            assert (c == 11).all()
    finally:
        builder.clean()


def check_arguments(scheduler, subroutine_arguments, call_arguments, call_kwarguments, driver_item=None,
                    driver_name=None, include_device_functions=False, include_another_driver=True,
                    subroutine_mod=None):
    """
    Check the subroutine and call arguments of each subroutine.

    ``subroutine_arguments`` maps a routine key to its expected list of dummy argument
    names. ``call_arguments``/``call_kwarguments`` map a callee key to the expected
    positional/keyword arguments of every call whose name contains that key.
    ``driver_item`` defaults to ``transformation_module_hoist#driver`` (checked under key
    ``driver_name``, default ``"driver"``). The kernels and device routines are looked up
    in ``subroutine_mod`` (default ``'subroutines_mod'``).
    """
    def _assert_signature(routine_item, key):
        # Dummy argument names must match exactly, including their order
        assert [arg.name for arg in routine_item.ir.arguments] == subroutine_arguments[key]

    def _assert_calls(calls, callee_keys):
        # Each call is checked against the first key contained in its name (if any)
        for call in calls:
            for key in callee_keys:
                if key in call.name:
                    assert call.arguments == call_arguments[key]
                    assert call.kwarguments == call_kwarguments[key]
                    break

    driver_item = driver_item or scheduler['transformation_module_hoist#driver']
    driver_name = driver_name or "driver"

    _assert_signature(driver_item, driver_name)
    _assert_calls(FindNodes(ir.CallStatement).visit(driver_item.ir.body), ("kernel1", "kernel2"))

    if include_another_driver:
        other_driver = scheduler['transformation_module_hoist#another_driver']
        _assert_signature(other_driver, "another_driver")
        _assert_calls(FindNodes(ir.CallStatement).visit(other_driver.ir.body), ("kernel1",))

    prefix = (subroutine_mod or 'subroutines_mod') + '#'

    kernel1_item = scheduler[prefix + 'kernel1']
    _assert_signature(kernel1_item, "kernel1")
    _assert_calls(FindInlineCalls().visit(kernel1_item.ir.body), ('func1',))

    kernel2_item = scheduler[prefix + 'kernel2']
    _assert_signature(kernel2_item, "kernel2")
    _assert_calls(FindNodes(ir.CallStatement).visit(kernel2_item.ir.body), ("device1", "device2"))

    if include_device_functions:
        device1_item = scheduler[prefix + 'device1']
        _assert_signature(device1_item, "device1")
        _assert_calls(FindNodes(ir.CallStatement).visit(device1_item.ir.body), ("device2",))

        device2_item = scheduler[prefix + 'device2']
        _assert_signature(device2_item, "device2")
        _assert_calls(FindInlineCalls().visit(device2_item.ir.body), ('init_int',))


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('as_kwarguments', [False, True])
def test_hoist(tmp_path, testdir, frontend, config, as_kwarguments):
    """
    Basic testing of the non-modified Hoist functionality, thus hoisting all (non-parameter) local variables.
    """
    scheduler = Scheduler(
        paths=[testdir/'sources/projHoist'], config=config, seed_routines=['driver', 'another_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # The untransformed sources must already compile and produce the expected results
    compile_and_test(scheduler=scheduler, path=tmp_path, frontend=frontend, a=(5, 10, 100), test_name="source")

    # Analysis pass first, then the synthesis pass that rewrites signatures and calls
    scheduler.process(transformation=HoistVariablesAnalysis())
    scheduler.process(transformation=HoistVariablesTransformation(as_kwarguments=as_kwarguments))

    subroutine_arguments = {
        "driver": ['a', 'b', 'c'],
        "another_driver": ['a', 'b', 'c'],
        "kernel1": ['a', 'b', 'c', 'x', 'y', 'k1_tmp'],
        "kernel2": ['a1', 'b', 'x', 'y', 'z', 'k2_tmp', 'device1_z', 'device1_d1_tmp', 'device2_z', 'device2_d2_tmp'],
        "device1": ['a1', 'b', 'x', 'y', 'z', 'd1_tmp', 'device2_z', 'device2_d2_tmp'],
        "device2": ['a2', 'b', 'x', 'z', 'd2_tmp'],
    }

    # Positional call arguments present before hoisting
    call_arguments = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b'),
        "device1": ('a1', 'b', 'x', 'k2_tmp'),
        "device2": ('a1', 'b', 'x')
    }
    # Hoisted (dummy name, actual argument) pairs added to each call
    hoisted = {
        "kernel1": (('x', 'kernel1_x'), ('y', 'kernel1_y'), ('k1_tmp', 'kernel1_k1_tmp')),
        "kernel2": (('x', 'kernel2_x'), ('y', 'kernel2_y'), ('z', 'kernel2_z'), ('k2_tmp', 'kernel2_k2_tmp'),
                    ('device1_z', 'device1_z'), ('device1_d1_tmp', 'device1_d1_tmp'),
                    ('device2_z', 'device2_z'), ('device2_d2_tmp', 'device2_d2_tmp')),
        "device1": (('z', 'device1_z'), ('d1_tmp', 'device1_d1_tmp'), ('device2_z', 'device2_z'),
                    ('device2_d2_tmp', 'device2_d2_tmp')),
        "device2": (('z', 'device2_z'), ('d2_tmp', 'device2_d2_tmp')),
    }
    if as_kwarguments:
        call_kwarguments = hoisted
    else:
        # Hoisted arguments are appended positionally and no keyword arguments remain
        call_kwarguments = {callee: () for callee in hoisted}
        for callee, pairs in hoisted.items():
            call_arguments[callee] += tuple(actual for _, actual in pairs)

    check_arguments(scheduler=scheduler, subroutine_arguments=subroutine_arguments, call_arguments=call_arguments,
            call_kwarguments=call_kwarguments)
    compile_and_test(scheduler=scheduler, path=tmp_path, a=(5, 10, 100), frontend=frontend, test_name="all_hoisted")


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('as_kwarguments', [False, True])
def test_hoist_disable(tmp_path, testdir, frontend, config, as_kwarguments):
    """
    Basic testing of the non-modified Hoist functionality excluding/disabling some subroutines,
    thus hoisting all (non-parameter) local variables for the non-disabled subroutines.

    ``device1`` and ``device2`` are blocked from ``kernel2``, so their signatures and the
    calls to them inside ``kernel2`` must remain unchanged.
    """

    disable = ("device1", "device2")
    config['routines']['kernel2'] = {'role': 'kernel', 'block': disable}
    proj = testdir/'sources/projHoist'
    scheduler = Scheduler(
        paths=[proj], config=config, seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
    )

    # Transformation: Analysis
    scheduler.process(transformation=HoistVariablesAnalysis())
    # Transformation: Synthesis
    scheduler.process(transformation=HoistVariablesTransformation(as_kwarguments=as_kwarguments))

    # check generated source code
    subroutine_arguments = {
        "driver": ['a', 'b', 'c'],
        "another_driver": ['a', 'b', 'c'],
        "kernel1": ['a', 'b', 'c', 'x', 'y', 'k1_tmp'],
        "kernel2": ['a1', 'b', 'x', 'y', 'z', 'k2_tmp'],
        "device1": ['a1', 'b', 'x', 'y'],
        "device2": ['a2', 'b', 'x'],
    }

    # Positional call arguments before hoisting; the blocked device calls stay as they are
    call_arguments = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b'),
        "device1": ('a1', 'b', 'x', 'k2_tmp'),
        "device2": ('a1', 'b', 'x')
    }
    if not as_kwarguments:
        call_arguments["kernel1"] += ('kernel1_x', 'kernel1_y', 'kernel1_k1_tmp')
        call_arguments["kernel2"] += ('kernel2_x', 'kernel2_y', 'kernel2_z', 'kernel2_k2_tmp')

    call_kwarguments = {
        "kernel1": (('x', 'kernel1_x'), ('y', 'kernel1_y'), ('k1_tmp', 'kernel1_k1_tmp')) if as_kwarguments else (),
        "kernel2": (('x', 'kernel2_x'), ('y', 'kernel2_y'), ('z', 'kernel2_z'),
            ('k2_tmp', 'kernel2_k2_tmp')) if as_kwarguments else (),
        "device1": (),
        "device2": ()
    }

    check_arguments(
        scheduler=scheduler, subroutine_arguments=subroutine_arguments,
        call_arguments=call_arguments, call_kwarguments=call_kwarguments,
        include_device_functions=False
    )
    compile_and_test(
        scheduler=scheduler, path=tmp_path, a=(5, 10, 100),
        frontend=frontend, test_name="all_hoisted_disable"
    )

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('as_kwarguments', [False, True])
def test_hoist_arrays_inline(tmp_path, testdir, frontend, config, as_kwarguments):
    """
    Test hoisting of local arrays for a call tree whose kernels contain inline function calls.

    The *Analysis* is done by :class:`HoistTemporaryArraysAnalysis`, the *Synthesis* by
    :class:`HoistVariablesTransformation`. Hoisted arrays are appended to the call as
    positional arguments, or passed as keyword arguments when ``as_kwarguments`` is set.
    """
    scheduler = Scheduler(
        paths=[testdir/'sources/projHoist'], config=config, seed_routines=['inline_driver',],
        frontend=frontend, xmods=[tmp_path]
    )

    # Analysis pass first, then the synthesis pass that rewrites signatures and calls
    scheduler.process(transformation=HoistTemporaryArraysAnalysis())
    scheduler.process(transformation=HoistVariablesTransformation(as_kwarguments=as_kwarguments))

    # Expected dummy arguments of every routine after hoisting
    subroutine_arguments = {
        "inline_driver": ['a', 'b', 'c'],
        "kernel1": ['a', 'b', 'c', 'x', 'y', 'k1_tmp'],
        "kernel2": ['a1', 'b', 'x', 'k2_tmp', 'device2_z', 'init_int_tmp0'],
        "device1": ['a1', 'b', 'x', 'y', 'device2_z', 'init_int_tmp0'],
        "device2": ['a2', 'b', 'x', 'z', 'init_int_tmp0'],
        "init_int": ['a2', 'tmp0'],
        "func1": ['a']
    }

    # Call arguments present independently of how the hoisted arrays are passed
    original_args = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b'),
        "device1": ('a1', 'b', 'x', 'k2_tmp'),
        "device2": ('a1', 'b', 'x'),
        "init_int": ('a2',),
        "func1": ('a',)
    }
    # Hoisted arrays per callee, as (dummy name, actual argument) pairs
    hoisted_pairs = {
        "kernel1": (('x', 'kernel1_x'), ('y', 'kernel1_y'), ('k1_tmp', 'kernel1_k1_tmp')),
        "kernel2": (('x', 'kernel2_x'), ('k2_tmp', 'kernel2_k2_tmp'),
                    ('device2_z', 'device2_z'), ('init_int_tmp0', 'init_int_tmp0')),
        "device1": (('device2_z', 'device2_z'), ('init_int_tmp0', 'init_int_tmp0')),
        "device2": (('z', 'device2_z'), ('init_int_tmp0', 'init_int_tmp0')),
        "init_int": (('tmp0', 'init_int_tmp0'),),
    }

    if as_kwarguments:
        call_arguments = dict(original_args)
        call_kwarguments = {callee: hoisted_pairs.get(callee, ()) for callee in original_args}
    else:
        # Positional mode: the actual arguments are appended in declaration order
        call_arguments = {
            callee: args + tuple(actual for _, actual in hoisted_pairs.get(callee, ()))
            for callee, args in original_args.items()
        }
        call_kwarguments = {callee: () for callee in original_args}

    check_arguments(
        scheduler=scheduler, subroutine_arguments=subroutine_arguments, call_arguments=call_arguments,
        call_kwarguments=call_kwarguments, driver_item=scheduler['transformation_module_hoist_inline#inline_driver'],
        driver_name='inline_driver', include_another_driver=False, subroutine_mod='subroutines_inline_mod',
        include_device_functions=True
    )
    compile_and_test(
        scheduler=scheduler, path=tmp_path, a=(5, 10, 100), frontend=frontend,
        test_name="hoisted_arrays_inline",
        items=[
            scheduler["transformation_module_hoist_inline#inline_driver"],
            scheduler["subroutines_inline_mod#kernel1"]
        ],
        inline=True
    )

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('as_kwarguments', [False, True])
def test_hoist_arrays(tmp_path, testdir, frontend, config, as_kwarguments):
    """
    Test hoisting of local arrays.

    The *Analysis* is done by :class:`HoistTemporaryArraysAnalysis`, the *Synthesis* by
    :class:`HoistVariablesTransformation`, for the ``driver`` and ``another_driver`` call trees.
    """
    scheduler = Scheduler(
        paths=[testdir/'sources/projHoist'], config=config, seed_routines=['driver', 'another_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # Analysis pass first, then the synthesis pass that rewrites signatures and calls
    scheduler.process(transformation=HoistTemporaryArraysAnalysis())
    scheduler.process(transformation=HoistVariablesTransformation(as_kwarguments=as_kwarguments))

    # Expected dummy arguments of every routine after hoisting
    subroutine_arguments = {
        "driver": ['a', 'b', 'c'],
        "another_driver": ['a', 'b', 'c'],
        "kernel1": ['a', 'b', 'c', 'x', 'y', 'k1_tmp'],
        "kernel2": ['a1', 'b', 'x', 'k2_tmp', 'device2_z'],
        "device1": ['a1', 'b', 'x', 'y', 'device2_z'],
        "device2": ['a2', 'b', 'x', 'z'],
    }

    # Call arguments present independently of how the hoisted arrays are passed
    original_args = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b'),
        "device1": ('a1', 'b', 'x', 'k2_tmp'),
        "device2": ('a1', 'b', 'x')
    }
    # Hoisted arrays per callee, as (dummy name, actual argument) pairs
    hoisted_pairs = {
        "kernel1": (('x', 'kernel1_x'), ('y', 'kernel1_y'), ('k1_tmp', 'kernel1_k1_tmp')),
        "kernel2": (('x', 'kernel2_x'), ('k2_tmp', 'kernel2_k2_tmp'), ('device2_z', 'device2_z')),
        "device1": (('device2_z', 'device2_z'),),
        "device2": (('z', 'device2_z'),),
    }

    if as_kwarguments:
        call_arguments = dict(original_args)
        call_kwarguments = {callee: hoisted_pairs[callee] for callee in original_args}
    else:
        # Positional mode: the actual arguments are appended in declaration order
        call_arguments = {
            callee: args + tuple(actual for _, actual in hoisted_pairs[callee])
            for callee, args in original_args.items()
        }
        call_kwarguments = {callee: () for callee in original_args}

    check_arguments(
        scheduler=scheduler, subroutine_arguments=subroutine_arguments, call_arguments=call_arguments,
        call_kwarguments=call_kwarguments
    )
    compile_and_test(
        scheduler=scheduler, path=tmp_path, a=(5, 10, 100), frontend=frontend, test_name="hoisted_arrays"
    )


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('as_kwarguments', [False, True])
def test_hoist_specific_variables(tmp_path, testdir, frontend, config, as_kwarguments):
    """
    Test hoisting of local arrays restricted to those with ``a``, ``a1`` or ``a2`` in their
    dimensions.

    The restriction is the ``dim_vars`` argument of :class:`HoistTemporaryArraysAnalysis`
    (the *Analysis* part).
    """
    scheduler = Scheduler(
        paths=[testdir/'sources/projHoist'], config=config, seed_routines=['driver', 'another_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # Analysis pass (restricted to the given dimension variables), then synthesis
    scheduler.process(transformation=HoistTemporaryArraysAnalysis(dim_vars=('a', 'a1', 'a2')))
    scheduler.process(transformation=HoistVariablesTransformation(as_kwarguments=as_kwarguments))

    # Expected dummy arguments of every routine after hoisting
    subroutine_arguments = {
        "driver": ['a', 'b', 'c'],
        "another_driver": ['a', 'b', 'c'],
        "kernel1": ['a', 'b', 'c', 'x', 'y', 'k1_tmp'],
        "kernel2": ['a1', 'b', 'x', 'k2_tmp', 'device2_z'],
        "device1": ['a1', 'b', 'x', 'y', 'device2_z'],
        "device2": ['a2', 'b', 'x', 'z'],
    }

    # Hoisted arrays per callee, as (dummy name, actual argument) pairs
    hoisted_pairs = {
        "kernel1": (('x', 'kernel1_x'), ('y', 'kernel1_y'), ('k1_tmp', 'kernel1_k1_tmp')),
        "kernel2": (('x', 'kernel2_x'), ('k2_tmp', 'kernel2_k2_tmp'), ('device2_z', 'device2_z')),
        "device1": (('device2_z', 'device2_z'),),
        "device2": (('z', 'device2_z'),),
    }
    call_arguments = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b'),
        "device1": ('a1', 'b', 'x', 'k2_tmp'),
        "device2": ('a1', 'b', 'x')
    }
    call_kwarguments = {}
    for callee, pairs in hoisted_pairs.items():
        if as_kwarguments:
            call_kwarguments[callee] = pairs
        else:
            # Positional mode: the actual arguments are appended in declaration order
            call_arguments[callee] += tuple(actual for _, actual in pairs)
            call_kwarguments[callee] = ()

    check_arguments(
        scheduler=scheduler, subroutine_arguments=subroutine_arguments, call_arguments=call_arguments,
        call_kwarguments=call_kwarguments
    )

    compile_and_test(
        scheduler=scheduler, path=tmp_path, a=(5, 10, 100), frontend=frontend,
        test_name="hoisted_specific_arrays"
    )


def check_variable_declaration(item, key):
    """
    Assert that the hoisted arrays recorded for ``item`` are declared, allocated and deallocated.

    :param item: scheduler item (a driver) processed by an allocatable hoisting synthesis
    :param key: key under which the hoisting data (with a ``"to_hoist"`` entry) is stored
        in ``item.trafo_data``
    """
    routine = item.ir
    # Look at every symbol/variable of each node, not only the first one: a single
    # declaration or (de)allocation statement may list several variables
    declarations = [
        symbol.name
        for decl in FindNodes(ir.VariableDeclaration).visit(routine.spec)
        for symbol in decl.symbols
    ]
    allocations = [
        var.name
        for alloc in FindNodes(ir.Allocation).visit(routine.body)
        for var in alloc.variables
    ]
    de_allocations = [
        var.name
        for dealloc in FindNodes(ir.Deallocation).visit(routine.body)
        for var in dealloc.variables
    ]
    assert allocations
    assert de_allocations
    to_hoist_vars = [var.name for var in item.trafo_data[key]["to_hoist"]]
    assert all(name in declarations for name in to_hoist_vars)
    assert all(name in to_hoist_vars for name in allocations)
    assert all(name in to_hoist_vars for name in de_allocations)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('as_kwarguments', [False, True])
def test_hoist_allocatable(tmp_path, testdir, frontend, config, as_kwarguments):
    """
    Testing hoist functionality for local arrays with variable ``a`` in the array dimensions using the
    :class:`HoistTemporaryArraysAnalysis` for the *Analysis* part **and** a *Synthesis* implementation declaring
    hoisted arrays as *allocatable*, including allocation and de-allocation using
    :class:`HoistTemporaryArraysTransformationAllocatable`.
    """

    proj = testdir/'sources/projHoist'
    scheduler = Scheduler(
        paths=[proj], config=config, seed_routines=['driver', 'another_driver'], frontend=frontend, xmods=[tmp_path]
    )

    # Key under which the hoisting data is looked up in ``item.trafo_data``
    key = "HoistVariablesTransformation"
    # Transformation: Analysis
    scheduler.process(
        transformation=HoistTemporaryArraysAnalysis(dim_vars=('a', 'a1', 'a2'))
    )
    # Transformation: Synthesis
    scheduler.process(
        transformation=HoistTemporaryArraysTransformationAllocatable(
            as_kwarguments=as_kwarguments
        )
    )

    # check generated source code: both drivers must declare, allocate and deallocate the hoisted arrays
    for item in scheduler.items:
        is_driver = "driver" in item.name and "another" not in item.name
        if is_driver or "another_driver" in item.name:
            check_variable_declaration(item, key)

    subroutine_arguments = {
        "driver": ['a', 'b', 'c'],
        "another_driver": ['a', 'b', 'c'],
        "kernel1": ['a', 'b', 'c', 'x', 'y', 'k1_tmp'],
        "kernel2": ['a1', 'b', 'x', 'k2_tmp', 'device2_z'],
        "device1": ['a1', 'b', 'x', 'y', 'device2_z'],
        "device2": ['a2', 'b', 'x', 'z'],
    }

    # Arguments present in every call; hoisted arrays are appended below in positional mode
    call_arguments = {
        "kernel1": ('a', 'b', 'c'),
        "kernel2": ('a', 'b'),
        "device1": ('a1', 'b', 'x', 'k2_tmp'),
        "device2": ('a1', 'b', 'x')
    }
    if not as_kwarguments:
        call_arguments["kernel1"] += ('kernel1_x', 'kernel1_y', 'kernel1_k1_tmp')
        call_arguments["kernel2"] += ('kernel2_x', 'kernel2_k2_tmp', 'device2_z')
        call_arguments["device1"] += ('device2_z',)
        call_arguments["device2"] += ('device2_z',)

    call_kwarguments = {
        "kernel1": (('x', 'kernel1_x'), ('y', 'kernel1_y'), ('k1_tmp', 'kernel1_k1_tmp')) if as_kwarguments else (),
        "kernel2": (('x', 'kernel2_x'), ('k2_tmp', 'kernel2_k2_tmp'),
            ('device2_z', 'device2_z')) if as_kwarguments else (),
        "device1": (('device2_z', 'device2_z'),) if as_kwarguments else (),
        "device2": (('z', 'device2_z'),) if as_kwarguments else ()
    }

    check_arguments(scheduler=scheduler, subroutine_arguments=subroutine_arguments, call_arguments=call_arguments,
            call_kwarguments=call_kwarguments)
    compile_and_test(scheduler=scheduler, path=tmp_path, a=(5, 10, 100), frontend=frontend, test_name="allocatable")


@pytest.mark.parametrize('frontend', available_frontends())
def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
    """
    Test allocatable hoisting when hoisted and non-hoisted arrays share one declaration.

    The analysis is :class:`HoistTemporaryArraysAnalysis` with ``dim_vars=('klev',)``; the
    synthesis is :class:`HoistTemporaryArraysTransformationAllocatable`. In the kernel,
    ``tmp2(klon, klev)`` is declared together with ``tmp3(nclv)``, and ``tmp5(klon, nclv, klev)``
    together with ``tmp4(2)``.

    The call tree is ``driver -> kernel -> another_kernel``. ``another_kernel`` declares
    ``another_tmp(klev,n)``, which depends on the module variable ``n`` from ``size_mod``.

    The test checks:
    * the hoisted declarations in the driver;
    * the extended kernel signature and driver call;
    * that the kernel can still be written out as Fortran;
    * that an import of ``n`` from ``size_mod`` is added to both kernel and driver.
    """

    fcode_driver = """
subroutine driver(NLON, NZ, NB, FIELD1, FIELD2)
    use kernel_mod, only: kernel
    implicit none
    INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
    INTEGER, INTENT(IN) :: NLON, NZ, NB
    integer :: b
    real(kind=jprb), intent(inout) :: field1(nlon, nb)
    real(kind=jprb), intent(inout) :: field2(nlon, nz, nb)
    do b=1,nb
        call KERNEL(1, nlon, nlon, nz, 2, field1(:,b), field2(:,:,b))
    end do
end subroutine driver
    """.strip()
    # NOTE(review): the kernel reads tmp1(jl)/field1(jl) outside the jl loop; this test only
    # transforms and regenerates the code, it never compiles or runs it.
    fcode_kernel = """
module kernel_mod
    implicit none
contains
    subroutine kernel(start, end, klon, klev, nclv, field1, field2)
        use, intrinsic :: iso_c_binding, only : c_size_t
        implicit none
        interface
           subroutine another_kernel(klev)
               integer, intent(in) :: klev
           end subroutine another_kernel
        end interface
        integer, parameter :: jprb = selected_real_kind(13,300)
        integer, intent(in) :: nclv
        integer, intent(in) :: start, end, klon, klev
        real(kind=jprb), intent(inout) :: field1(klon)
        real(kind=jprb), intent(inout) :: field2(klon,klev)
        real(kind=jprb) :: tmp1(klon)
        real(kind=jprb) :: tmp2(klon, klev), tmp3(nclv)
        real(kind=jprb) :: tmp4(2), tmp5(klon, nclv, klev)
        integer :: jk, jl, jm

        do jk=1,klev
            tmp1(jl) = 0.0_jprb
            do jl=start,end
                tmp2(jl, jk) = field2(jl, jk)
                tmp1(jl) = field2(jl, jk)
            end do
            field1(jl) = tmp1(jl)
        end do

        do jm=1,nclv
           tmp3(jm) = 0._jprb
           do jl=start,end
             tmp5(jl, jm, :) = field1(jl)
           enddo
        enddo

        call another_kernel(klev)
    end subroutine kernel
end module kernel_mod
    """.strip()
    # Module providing ``n``, used in the shape of another_kernel's temporary
    fcode_mod = """
module size_mod
   implicit none
   integer :: n
end module size_mod
""".strip()
    fcode_another_kernel = """
subroutine another_kernel(klev)
    use size_mod, only : n
    implicit none
    integer, intent(in) :: klev
    real :: another_tmp(klev,n)
end subroutine another_kernel
""".strip()

    (tmp_path/'driver.F90').write_text(fcode_driver)
    (tmp_path/'kernel_mod.F90').write_text(fcode_kernel)
    (tmp_path/'size_mod.F90').write_text(fcode_mod)
    (tmp_path/'another_kernel.F90').write_text(fcode_another_kernel)

    # Local scheduler configuration; this replaces the value of the ``config`` fixture.
    # Everything is a kernel except ``driver``.
    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True
        },
        'routines': {
            'driver': {'role': 'driver'}
        }
    }

    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
    )
    # Only arrays with ``klev`` in their dimensions are hoisted
    scheduler.process(transformation=HoistTemporaryArraysAnalysis(dim_vars=('klev',)))
    scheduler.process(transformation=HoistTemporaryArraysTransformationAllocatable())

    # The driver gains deferred-shape declarations for the hoisted arrays
    driver_variables = (
        'jprb', 'nlon', 'nz', 'nb', 'b',
        'field1(nlon, nb)', 'field2(nlon, nz, nb)',
        'kernel_tmp2(:,:)', 'kernel_tmp5(:,:,:)', 'another_kernel_another_tmp(:,:)'
    )
    # tmp2 and tmp5 become dummy arguments, split off their shared declarations;
    # another_kernel's temporary is passed through the kernel
    kernel_arguments = (
       'start', 'end', 'klon', 'klev', 'nclv',
        'field1(klon)', 'field2(klon,klev)', 'tmp2(klon,klev)', 'tmp5(klon,nclv,klev)',
        'another_kernel_another_tmp(klev,n)'
    )

    # Check hoisting and declaration in driver
    assert scheduler['#driver'].ir.variables == driver_variables
    assert scheduler['kernel_mod#kernel'].ir.arguments == kernel_arguments

    # Check updated call signature
    calls = FindNodes(ir.CallStatement).visit(scheduler['#driver'].ir.body)
    assert len(calls) == 1
    assert calls[0].arguments == (
        '1', 'nlon', 'nlon', 'nz', '2', 'field1(:,b)', 'field2(:,:,b)',
        'kernel_tmp2', 'kernel_tmp5', 'another_kernel_another_tmp'
    )

    # Check that fgen works
    assert scheduler['kernel_mod#kernel'].source.to_fortran()

    # Check that imports were updated: ``n`` (needed for the shape of another_kernel_another_tmp)
    # is imported from size_mod, and that import comes first
    imports = FindNodes(ir.Import).visit(scheduler['kernel_mod#kernel'].ir.spec)
    assert len(imports) == 2
    assert 'n' in scheduler['kernel_mod#kernel'].ir.imported_symbols
    assert imports[0].module.lower() == 'size_mod'

    imports = FindNodes(ir.Import).visit(scheduler['#driver'].ir.spec)
    assert len(imports) == 2
    assert 'n' in scheduler['#driver'].ir.imported_symbols
    assert imports[0].module.lower() == 'size_mod'


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('remap_dimensions', (False, True))
def test_hoist_dim_mapping(tmp_path, frontend, config, remap_dimensions):
    """
    Test the ``remap_dimensions`` option of :class:`HoistVariablesTransformation`.

    The driver passes a local copy ``local_nlon`` of its argument ``nlon`` as the kernel's
    ``klon``. The hoisted temporary ``kernel_tmp1`` is expected to be dimensioned by ``nlon``
    when dimensions are remapped, and by ``local_nlon`` otherwise.
    """

    fcode_driver = """
subroutine driver(NLON, NB, FIELD1)
    use kernel_mod, only: kernel
    implicit none
    INTEGER, INTENT(IN) :: NLON, NB
    integer :: b
    integer, intent(inout) :: field1(nlon, nb)
    integer :: local_nlon
    local_nlon = nlon
    do b=1,nb
        call KERNEL(local_nlon, field1(:,b))
    end do
end subroutine driver
    """.strip()
    fcode_kernel = """
module kernel_mod
    implicit none
contains
    subroutine kernel(klon, field1)
        implicit none
        integer, intent(in) :: klon
        integer, intent(inout) :: field1(klon)
        integer :: tmp1(klon)
        integer :: jl

        do jl=1,klon
            tmp1(jl) = 0
            field1(jl) = tmp1(jl)
        end do

    end subroutine kernel
end module kernel_mod
    """.strip()

    for filename, fcode in (('driver.F90', fcode_driver), ('kernel_mod.F90', fcode_kernel)):
        (tmp_path/filename).write_text(fcode)

    # Every routine is a kernel except ``driver``
    default_options = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': True,
        'enable_imports': True
    }
    config = {'default': default_options, 'routines': {'driver': {'role': 'driver'}}}

    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
    )

    # Analysis pass, then the synthesis pass under test
    scheduler.process(transformation=HoistTemporaryArraysAnalysis())
    scheduler.process(transformation=HoistVariablesTransformation(remap_dimensions=remap_dimensions))

    expected_dimensions = ('nlon',) if remap_dimensions else ('local_nlon',)
    driver_var_map = scheduler['#driver'].ir.variable_map
    assert 'kernel_tmp1' in driver_var_map
    assert driver_var_map['kernel_tmp1'].dimensions == expected_dimensions


@pytest.mark.parametrize('frontend', available_frontends())
def test_hoist_dim_alias(tmp_path, testdir, frontend, config):
    """
    Test that temporaries declared with an aliased dimension aren't repeated.

    ``device3``'s temporary ``x`` must be hoisted into ``kernel3`` as a single
    ``device3_x`` argument. It must be passed in both calls found in ``kernel3``
    and in the driver's first call.
    """
    scheduler = Scheduler(
        paths=[testdir/'sources/projHoist'], config=config, seed_routines=['yet_another_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    scheduler.process(transformation=HoistTemporaryArraysAnalysis())
    scheduler.process(transformation=HoistVariablesTransformation())

    kernel3 = scheduler['subroutines_mod#kernel3'].ir
    device3 = scheduler['subroutines_mod#device3'].ir
    driver = scheduler['transformation_module_hoist#yet_another_driver'].ir

    # The temporary became a dummy of device3 and was hoisted exactly once into kernel3
    assert 'x' in device3._dummies
    assert 'device3_x' in kernel3._dummies
    assert len(kernel3.arguments) == 3
    assert 'a1' in kernel3.variable_map['device3_x'].shape

    # Both calls inside kernel3 pass the hoisted array
    kernel3_calls = FindNodes(ir.CallStatement).visit(kernel3.body)
    assert len(kernel3_calls) == 2
    assert all('device3_x' in call.arguments for call in kernel3_calls)

    # The driver layer passes it on as well
    first_driver_call = FindNodes(ir.CallStatement).visit(driver.body)[0]
    assert len(first_driver_call.arguments) == 3
    assert 'device3_x' in first_driver_call.arguments
loki-ecmwf-0.3.6/loki/transformations/temporaries/tests/test_pool_allocator.py0000664000175000017500000015620115167130205030253 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki.expression.parser import parse_expr
from loki import Dimension
from loki.batch import Scheduler, SchedulerConfig
from loki.expression import (
        InlineCall, RangeIndex, simplify, Sum,
        Product
)
from loki.frontend import available_frontends, OMNI, FP
from loki.ir import (
    FindNodes, CallStatement, Assignment, Allocation, Deallocation,
    Loop, Pragma, get_pragma_parameters, FindVariables, FindInlineCalls,
    Intrinsic
)

from loki.transformations.pragma_model import PragmaModelTransformation
from loki.transformations.temporaries.pool_allocator import (
    TemporariesPoolAllocatorTransformation, EcstackPoolAllocatorTransformation
)


@pytest.fixture(scope='module', name='block_dim')
def fixture_block_dim():
    """Block dimension of size ``nb`` iterated with index ``b``."""
    dim = Dimension(name='block_dim', index='b', size='nb')
    return dim

@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
    """Horizontal dimension ``nlon``/``jl`` with bounds ``start:end`` and size aliases ``klon``, ``columns``."""
    dim = Dimension(
        name='horizontal',
        size='nlon',
        index='jl',
        bounds=('start', 'end'),
        aliases=('klon', 'columns'),
    )
    return dim

@pytest.fixture(scope='module', name='block_dim_alt')
def fixture_block_dim_alt():
    """Block dimension whose size is the derived-type member ``geom%blk_dim%nb``."""
    dim = Dimension(name='block_dim_alt', index='b', size='geom%blk_dim%nb')
    return dim

def check_c_sizeof_import(routine):
    """Assert that ``routine`` has an ``iso_c_binding`` import and imports ``c_sizeof``."""
    imported_modules = (imprt.module.lower() for imprt in routine.imports)
    assert 'iso_c_binding' in imported_modules
    assert 'c_sizeof' in routine.imported_symbols

def check_real64_import(routine):
    """Assert that ``routine`` has an ``iso_fortran_env`` import and imports ``real64``."""
    imported_modules = (imprt.module.lower() for imprt in routine.imports)
    assert 'iso_fortran_env' in imported_modules
    assert 'real64' in routine.imported_symbols


def check_stack_created_in_driver(
        driver, stack_size, first_kernel_call, num_block_loops,
        check_bounds=True, cray_ptr_loc_rhs=False
):
    """
    Assert that the pool-allocator stack is set up correctly in ``driver``.

    :param driver: the driver routine IR
    :param stack_size: expected expression assigned to ``istsz`` (compared after ``simplify``)
    :param first_kernel_call: call node that must come after the stack pointer assignments
    :param num_block_loops: expected number of loops in the driver body
    :param check_bounds: whether an upper stack bound ``ylstack_u`` is expected
    :param cray_ptr_loc_rhs: whether ``ylstack_l`` is set to ``1`` instead of ``loc(zstack(1, b))``
    """
    # Stack size, stack storage and stack pointer must be declared
    driver_vars = driver.variables
    for declared in ('istsz', 'zstack(:,:)', 'ylstack_l'):
        assert declared in driver_vars

    # Exactly one allocation and one deallocation of the stack storage
    allocs = FindNodes(Allocation).visit(driver.body)
    assert len(allocs) == 1 and 'zstack(istsz,nb)' in allocs[0].variables
    deallocs = FindNodes(Deallocation).visit(driver.body)
    assert len(deallocs) == 1 and 'zstack' in deallocs[0].variables

    # Every assignment to the stack size must match the expected size expression
    size_assigns = (a for a in FindNodes(Assignment).visit(driver.body) if a.lhs == 'istsz')
    for size_assign in size_assigns:
        assert simplify(size_assign.rhs) == simplify(stack_size)

    # The stack pointer(s) are set at the top of the first block loop
    block_loops = FindNodes(Loop).visit(driver.body)
    assert len(block_loops) == num_block_loops
    first_loop = block_loops[0]
    loop_assigns = FindNodes(Assignment).visit(first_loop.body)
    lower = loop_assigns[0]
    assert lower.lhs == 'ylstack_l'
    if not cray_ptr_loc_rhs:
        assert isinstance(lower.rhs, InlineCall) and lower.rhs.function == 'loc'
        assert 'zstack(1, b)' in lower.rhs.parameters
    else:
        assert lower.rhs == '1'
    if check_bounds:
        upper = loop_assigns[1]
        if cray_ptr_loc_rhs:
            expected_upper = 'ylstack_l + istsz'
        else:
            expected_upper = 'ylstack_l + istsz * c_sizeof(real(1, kind=real64))'
        assert upper.lhs == 'ylstack_u' and upper.rhs == expected_upper

    # All stack assignments precede the kernel call
    assert all(
        first_loop.body.index(assign) < first_loop.body.index(first_kernel_call)
        for assign in loop_assigns
    )


@pytest.mark.parametrize('generate_driver_stack', [True, False])
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('check_bounds', [True, False])
@pytest.mark.parametrize('nclv_param', [False, True])
@pytest.mark.parametrize('cray_ptr_loc_rhs', [False, True])
@pytest.mark.parametrize('trafo', [EcstackPoolAllocatorTransformation, TemporariesPoolAllocatorTransformation])
def test_pool_allocator_temporaries(tmp_path, frontend, generate_driver_stack, block_dim, check_bounds,
                                    nclv_param, cray_ptr_loc_rhs, trafo, horizontal):
    """
    Test the pool allocator transformation ``trafo`` on a driver/kernel pair.

    With ``generate_driver_stack`` the transformation must create the stack in the driver.
    Otherwise the driver already contains the stack declaration, the allocation (or
    ``ECSTACK%GET_STACK_PTR`` call), the per-block stack pointer assignments and the
    deallocation.

    The test checks:
    * the imports and call signature in the driver;
    * the stack set-up in the driver (``TemporariesPoolAllocatorTransformation`` only);
    * in the kernel: the Cray pointer assignments, stack increments and optional bounds checks.

    The kernel temporaries expected on the stack are ``tmp1``, ``tmp2``, ``tmp5``, plus ``tmp3``
    when ``nclv`` is a dummy argument rather than a parameter.
    """
    fcode_iso_c_binding = "use, intrinsic :: iso_c_binding, only: c_sizeof"
    fcode_iso_env = "use iso_fortran_env, only: real64"
    fcode_nclv_param = 'integer, parameter :: nclv = 2'
    # Expected spellings of ``nclv`` and of the real kind in the generated code (they differ for OMNI)
    nclv_var = '2' if frontend == OMNI else 'nclv'
    if frontend == OMNI:
        kind_real = 'selected_real_kind(13, 300)'
    else:
        kind_real = 'jprb'
    kind_stack = 'real64'
    # Stack size written into the pre-existing driver stack set-up (Fortran source text)
    if nclv_param:
        stack_size_str = (
            f'ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon*{nclv_var}, -3)'
        )
    else:
        stack_size_str = (
            f'ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon, -3) &\n'
            f' & + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3) &\n'
            f' & + ishft(7 + c_sizeof(real(1, kind={kind_real}))*2, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon*2, -3)'
        )

    fcode_stack_decl = f"""
        integer :: istsz
        REAL(KIND=REAL64), ALLOCATABLE :: ZSTACK(:, :)
        integer(kind=8) :: ylstack_l
        integer(kind=8) :: ylstack_u

        istsz = {stack_size_str}

        {'ALLOCATE(ZSTACK(ISTSZ, nb))' if trafo == TemporariesPoolAllocatorTransformation else 'CALL ECSTACK%GET_STACK_PTR(ZSTACK, ISTSZ, nb)'}
    """
    if cray_ptr_loc_rhs:
        fcode_stack_assign = """
            ylstack_l = 1
            ylstack_u = ylstack_l + istsz
        """
    else:
        fcode_stack_assign = f"""
            ylstack_l = loc(zstack(1, b))
            ylstack_u = ylstack_l + istsz * c_sizeof(real(1, kind={kind_stack}))
        """
    fcode_stack_dealloc = "DEALLOCATE(ZSTACK)" if trafo == TemporariesPoolAllocatorTransformation else ''

    fcode_ecstack = """
    module ecstack_mod
    implicit none
     type tecstack
       integer :: size
       contains
       PROCEDURE :: GET_STACK_PTR
     end type tecstack 

     type(tecstack) :: ecstack

     contains
       SUBROUTINE GET_STACK_PTR(SELF, PTR, KSIZE, NGPBLKS)
          CLASS(TECSTACK) :: SELF
          REAL, POINTER, CONTIGUOUS, INTENT(INOUT) :: PTR(:, :)
          INTEGER, INTENT(IN) :: KSIZE
          INTEGER, INTENT(IN) :: NGPBLKS
       
       END SUBROUTINE GET_STACK_PTR
    end module ecstack_mod
    """
    fcode_driver = f"""
subroutine driver(NLON, NZ, NB, FIELD1, FIELD2)
    {fcode_iso_c_binding if not generate_driver_stack else ''}
    {fcode_iso_env if not generate_driver_stack else ''}
    {'use ecstack_mod, only: ecstack' if (trafo == EcstackPoolAllocatorTransformation and not generate_driver_stack) else ''}
    use kernel_mod, only: kernel
    implicit none
    INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
    INTEGER, INTENT(IN) :: NLON, NZ, NB
    real(kind=jprb), intent(inout) :: field1(nlon, nb)
    real(kind=jprb), intent(inout) :: field2(nlon, nz, nb)
    integer :: b
    {fcode_stack_decl if not generate_driver_stack else ''}
    do b=1,nb
        {fcode_stack_assign if not generate_driver_stack else ''}
        call KERNEL(1, nlon, nlon, nz, {'2, ' if not nclv_param else ''} field1(:,b), field2(:,:,b))
    end do
    {fcode_stack_dealloc if not generate_driver_stack else ''}
end subroutine driver
    """.strip()
    fcode_kernel = f"""
module kernel_mod
    implicit none
contains
    subroutine kernel(start, end, klon, klev, {'nclv, ' if not nclv_param else ''} field1, field2)
        ! use, intrinsic :: iso_c_binding, only : c_size_t
        implicit none
        integer, parameter :: jprb = selected_real_kind(13,300)
        {fcode_nclv_param if nclv_param else 'integer, intent(in) :: nclv'}
        integer, intent(in) :: start, end, klon, klev
        real(kind=jprb), intent(inout), target :: field1(klon)
        real(kind=jprb), intent(inout) :: field2(klon,klev)
        real(kind=jprb) :: tmp1(klon)
        real(kind=jprb) :: tmp2(klon, klev)
        real(kind=jprb) :: tmp3(nclv), tmp4(2), tmp5(klon, nclv)
        real(kind=jprb), pointer :: tmp3_ptr(:)
        integer :: jk, jl, jm

        tmp3_ptr => field1

        do jk=1,klev
            tmp1(jl) = 0.0_jprb
            do jl=start,end
                tmp2(jl, jk) = field2(jl, jk)
                tmp1(jl) = field2(jl, jk)
            end do
            field1(jl) = tmp1(jl)
        end do

        do jm=1,nclv
           tmp3(jm) = 0._jprb
           do jl=start,end
             tmp5(jl, jm) = field1(jl)
           enddo
        enddo
    end subroutine kernel
end module kernel_mod
    """.strip()


    (tmp_path/'driver.F90').write_text(fcode_driver)
    (tmp_path/'kernel_mod.F90').write_text(fcode_kernel)
    (tmp_path/'ecstack_mod.F90').write_text(fcode_ecstack)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True,
            'ignore': ['iso_fortran_env']
        },
        'routines': {
            'driver': {'role': 'driver'}
        }
    }

    if frontend == FP and not generate_driver_stack:
        # Patch "LOC" intrinsic into fparser. This is not strictly needed (it will just represent it
        # as Array instead of an InlineCall) but makes for a more coherent check further down
        from fparser.two import Fortran2003  # pylint: disable=import-outside-toplevel
        Fortran2003.Intrinsic_Name.other_inquiry_names.update({"LOC": {'min': 1, 'max': 1}})
        Fortran2003.Intrinsic_Name.generic_function_names.update({"LOC": {'min': 1, 'max': 1}})
        Fortran2003.Intrinsic_Name.function_names += ["LOC"]

    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )

    transformation = trafo(
        block_dim=block_dim, horizontal=horizontal, check_bounds=check_bounds,
        cray_ptr_loc_rhs=cray_ptr_loc_rhs
    )
    scheduler.process(transformation=transformation)
    pragma_model_trafo = PragmaModelTransformation()
    scheduler.process(transformation=pragma_model_trafo)
    kernel_item = scheduler['kernel_mod#kernel']

    assert transformation._key in kernel_item.trafo_data

    # a few driver checks
    driver = scheduler['#driver'].ir
    check_c_sizeof_import(driver)
    check_real64_import(driver)
    calls = FindNodes(CallStatement).visit(driver.body)
    # The Ecstack variant has an extra GET_STACK_PTR call ahead of the kernel call.
    # The parentheses are required: without them the Ecstack case reduces to ``assert 2``,
    # which always passes.
    assert len(calls) == (1 if trafo == TemporariesPoolAllocatorTransformation else 2)
    if nclv_param:
        expected_args = ('1', 'nlon', 'nlon', 'nz', 'field1(:,b)', 'field2(:,:,b)')
    else:
        expected_args = ('1', 'nlon', 'nlon', 'nz', '2', 'field1(:,b)', 'field2(:,:,b)')
    if check_bounds:
        expected_kwargs = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u'))
    else:
        expected_kwargs = (('YDSTACK_L', 'ylstack_l'),)
    if cray_ptr_loc_rhs:
        if frontend == OMNI and not generate_driver_stack:
            # If the stack exists already in the driver, that variable is used. And because
            # OMNI lower-cases everything, this will result in a lower-case name for the
            # argument for that particular case...
            expected_kwargs += (('zstack', 'zstack(:,b)'),)
        else:
            expected_kwargs += (('ZSTACK', 'zstack(:,b)'),)
    relevant_call = calls[0] if trafo == TemporariesPoolAllocatorTransformation else calls[1]
    assert relevant_call.arguments == expected_args
    assert relevant_call.kwarguments == expected_kwargs
    if trafo == EcstackPoolAllocatorTransformation:
        assert calls[0].arguments == ('ZSTACK', 'ISTSZ', 'nb')

    # Expected stack size as a single-line expression (no Fortran line continuations)
    if nclv_param:
        stack_size_str = (
            f'ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon*{nclv_var}, -3)'
        )
    else:
        stack_size_str = (
            f'ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*2, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon*2, -3)'
        )

    stack_size = parse_expr(stack_size_str)
    if trafo == TemporariesPoolAllocatorTransformation:
        check_stack_created_in_driver(driver, stack_size, calls[0], 1, check_bounds=check_bounds,
                cray_ptr_loc_rhs=cray_ptr_loc_rhs)

    # a few kernel checks
    kernel = kernel_item.ir
    check_c_sizeof_import(kernel)
    check_real64_import(kernel)
    # Has the stack been added to the arguments?
    assert 'ydstack_l' in kernel.arguments
    if check_bounds:
        assert 'ydstack_u' in kernel.arguments

    # Is it being assigned to a local variable?
    assert 'ylstack_l' in kernel.variables
    if check_bounds:
        assert 'ylstack_u' in kernel.variables

    # Let's check for the relevant "allocations" happening in the right order
    if nclv_param:
        tmp_indices = (1, 2, 5)
    else:
        tmp_indices = (1, 2, 3, 5)
    assign_idx = {}
    for idx, assign in enumerate(FindNodes(Assignment).visit(kernel.body)):
        if assign.lhs == 'ylstack_l' and assign.rhs == 'ydstack_l':
            # Local copy of stack status
            assign_idx['stack_assign'] = idx
        elif str(assign.lhs).lower().startswith('ip_tmp'):
            # Assign Cray pointer for tmp1, tmp2, tmp5 (and tmp3, tmp4 if no alloc_dims provided)
            for tmp_index in tmp_indices:
                if f'ip_tmp{tmp_index}' == assign.lhs:
                    assign_idx[f'tmp{tmp_index}_ptr_assign'] = idx
        elif assign.lhs == 'ylstack_l' and 'ylstack_l' in assign.rhs:
            # Stack increment for tmp1, tmp2, tmp5 (and tmp3, tmp4 if no alloc_dims provided)
            for tmp_index in tmp_indices:
                dim = f"{kernel.variable_map[f'tmp{tmp_index}'].shape[0]}"
                for v in kernel.variable_map[f'tmp{tmp_index}'].shape[1:]:
                    dim += f'*{v}'
                if cray_ptr_loc_rhs:
                    exp_rhs_str = f'ylstack_l + ishft({dim}*C_SIZEOF(REAL(1, kind={kind_real})) + 7, -3)'
                else:
                    exp_rhs_str = f'ylstack_l + ishft(ishft({dim}*C_SIZEOF(REAL(1, kind={kind_real})) + 7, -3), 3)'
                expected_rhs = parse_expr(exp_rhs_str)
                if expected_rhs == assign.rhs:
                    assign_idx[f'tmp{tmp_index}_stack_incr'] = idx

    expected_assign_in_order = ['stack_assign']
    for tmp_index in tmp_indices:
        expected_assign_in_order += [f'tmp{tmp_index}_ptr_assign', f'tmp{tmp_index}_stack_incr']
    assert set(expected_assign_in_order) == set(assign_idx.keys())

    for assign1, assign2 in zip(expected_assign_in_order, expected_assign_in_order[1:]):
        assert assign_idx[assign2] > assign_idx[assign1]

    # Check for pointer declarations in generated code
    fcode = kernel.to_fortran()
    for tmp_index in tmp_indices:
        assert f'pointer(ip_tmp{tmp_index}, tmp{tmp_index})' in fcode.lower()

    # Check for stack size safeguards in generated code
    if check_bounds:
        assert fcode.lower().count('if (ylstack_l > ylstack_u)') == len(tmp_indices)
        assert fcode.lower().count('stop') == len(tmp_indices)
    else:
        assert 'if (ylstack_l > ylstack_u)' not in fcode.lower()
        assert 'stop' not in fcode.lower()


@pytest.mark.parametrize('frontend', available_frontends())
def test_pool_allocator_unused_temporaries(tmp_path, frontend, horizontal, block_dim):
    """
    Only temporaries that are referenced in the kernel and have at least one
    non-constant dimension should be placed on the pool-allocator stack.

    ``tmp1``/``tmp2`` (written), ``tmp6`` (only read) and ``tmp7`` (only passed to a
    call) must receive a Cray pointer; ``tmp3`` (unused) and ``tmp4``/``tmp5``
    (constant dimensions only) must not.
    """
    driver_src = """
subroutine driver(NLON, NZ, NB, FIELD1, FIELD2)
    use kernel_mod, only: kernel
    implicit none
    INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
    INTEGER, INTENT(IN) :: NLON, NZ, NB
    real(kind=jprb), intent(inout) :: field1(nlon, nb)
    real(kind=jprb), intent(inout) :: field2(nlon, nz, nb)
    integer :: b
    do b=1,nb
        call KERNEL(1, nlon, nlon, nz, field1(:,b), field2(:,:,b))
    end do
end subroutine driver
    """.strip()
    kernel_src = """
module kernel_mod
    implicit none
contains
    subroutine kernel(start, end, klon, klev, field1, field2)
        implicit none
        integer, parameter :: jprb = selected_real_kind(13,300)
        integer, parameter :: nclv = 2
        integer, intent(in) :: start, end, klon, klev
        real(kind=jprb), intent(inout) :: field1(klon)
        real(kind=jprb), intent(inout) :: field2(klon,klev)
        real(kind=jprb) :: tmp1(klon)
        real(kind=jprb) :: tmp2(klon, klev)
        ! shouldn't end up on stack since not used at all
        real(kind=jprb) :: tmp3(klon, klev)
        ! shouldn't end up on stack since all dimension are constant
        real(kind=jprb) :: tmp4(nclv)
        ! shouldn't end up on stack since all dimension are constant
        real(kind=jprb) :: tmp5(2)
        ! should be on the stack although only read and not written to
        real(kind=jprb) :: tmp6(klon, nclv)
        ! should be on the stack although only used as an argument in a call
        real(kind=jprb) :: tmp7(klon, nclv)
        integer :: jk, jl, jm

        tmp5(1) = 0.0_jprb
        do jk=1,klev
            tmp1(jl) = 0.0_jprb
            do jl=start,end
                tmp2(jl, jk) = field2(jl, jk)
                tmp1(jl) = field2(jl, jk)
            end do
            field1(jl) = tmp1(jl)
        end do

        do jm=1,nclv
           tmp4(jm) = 0._jprb
           do jl=start,end
             field1(jl) = tmp6(jl, jm)
           enddo
        enddo
        
        call foo(tmp7)

    end subroutine kernel
end module kernel_mod
    """.strip()

    default_cfg = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': False,
        'enable_imports': True,
    }
    config = {'default': default_cfg, 'routines': {'driver': {'role': 'driver'}}}

    for filename, source in (('driver.F90', driver_src), ('kernel_mod.F90', kernel_src)):
        (tmp_path/filename).write_text(source)

    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )
    pool_trafo = TemporariesPoolAllocatorTransformation(
        block_dim=block_dim, horizontal=horizontal, check_bounds=False,
        cray_ptr_loc_rhs=False
    )
    scheduler.process(transformation=pool_trafo)

    kernel_item = scheduler['kernel_mod#kernel']
    assert pool_trafo._key in kernel_item.trafo_data

    # Collect the pointee names of all 'POINTER(IP_tmp<...>, tmp<...>)' intrinsics in the spec
    on_stack = []
    for intr in FindNodes(Intrinsic).visit(kernel_item.ir.spec):
        if 'pointer' not in intr.text.lower():
            continue
        pointee = intr.text.split(',')[1]
        on_stack.append(pointee.replace(')', '').replace(' ', ''))

    for name in ('tmp1', 'tmp2', 'tmp6', 'tmp7'):
        assert name in on_stack
    for name in ('tmp3', 'tmp4', 'tmp5'):
        assert name not in on_stack

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('directive', [False, 'openmp', 'openacc', 'openmp-manual'])
@pytest.mark.parametrize('stack_insert_pragma', [False, True])
@pytest.mark.parametrize('cray_ptr_loc_rhs', [False, True])
def test_pool_allocator_temporaries_kernel_sequence(tmp_path, frontend, block_dim, directive,
                                                    stack_insert_pragma, cray_ptr_loc_rhs, horizontal):
    """
    Test the pool allocator for a driver that calls two kernels in two separate
    block loops.

    Checks the per-kernel stack sizes and the combined (``max``) stack size in the
    driver. Checks the data-sharing and data-region pragmas in the driver and the
    ordering of the Cray-pointer "allocations" inside both kernels.
    """

    if directive == 'openmp-manual':
        driver_loop_pragma1 = '!$omp parallel default(shared) private(b) firstprivate(a)\n    !$omp do'
        driver_end_loop_pragma1 = '!$omp end do\n    !$omp end parallel'
        driver_loop_pragma2 = '!$omp parallel do firstprivate(a)'
        driver_end_loop_pragma2 = '!$omp end parallel do'
        kernel_pragma = ''
        # from here on continue as directive is 'openmp'
        directive = 'openmp'
    else:
        driver_loop_pragma1 = '!$loki loop gang default(shared) private(b) firstprivate(a)'
        driver_end_loop_pragma1 = '!$loki end loop gang'
        driver_loop_pragma2 = '!$loki loop gang firstprivate(a)'
        driver_end_loop_pragma2 = '!$loki end loop gang'
        kernel_pragma = '!$loki routine vector'

    if stack_insert_pragma:
        stack_size_location_pragma = '!$loki stack-insert'
    else:
        stack_size_location_pragma = ''

    fcode_parkind_mod = """
module parkind1
implicit none
integer, parameter :: jprb = selected_real_kind(13,300)
integer, parameter :: jpim = selected_int_kind(9)
integer, parameter :: jplm = jpim
end module parkind1
    """.strip()

    fcode_driver = f"""
subroutine driver(NLON, NZ, NB, FIELD1, FIELD2)
    use kernel_mod, only: kernel, kernel2
    implicit none
    INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
    INTEGER, INTENT(IN) :: NLON, NZ, NB
    real(kind=jprb), intent(inout) :: field1(nlon, nb)
    real(kind=jprb), intent(inout) :: field2(nlon, nz, nb)
    integer :: a,b

    ! a = 1, necessary to check loki stack-insert pragma
    a = 1
    {stack_size_location_pragma}

    {driver_loop_pragma1}
    do b=1,nb
        call KERNEL(1, nlon, nlon, nz, field1(:,b), field2(:,:,b))
    end do
    {driver_end_loop_pragma1}

    {driver_loop_pragma2}
    do b=1,nb
        call KERNEL2(1, nlon, nlon, nz, field2(:,:,b))
    end do
    {driver_end_loop_pragma2}
end subroutine driver
    """.strip()

    fcode_kernel = f"""
module kernel_mod
    implicit none
contains
    subroutine kernel(start, end, klon, klev, field1, field2)
        use parkind1, only: jprb, jpim, jplm
        implicit none
        integer, intent(in) :: start, end, klon, klev
        real(kind=jprb), intent(inout) :: field1(klon)
        real(kind=jprb), intent(inout) :: field2(klon,klev)
        real(kind=jprb) :: tmp1(klon)
        real(kind=jprb) :: tmp2(klon, klev)
        integer(kind=jpim) :: tmp3(klon*2)
        logical(kind=jplm) :: tmp4(klev)
        integer :: jk, jl
        {kernel_pragma}

        do jk=1,klev
            tmp1(jl) = 0.0_jprb
            do jl=start,end
                tmp2(jl, jk) = field2(jl, jk)
                tmp1(jl) = field2(jl, jk)
            end do
            field1(jl) = tmp1(jl)
            tmp4(jk) = .true.
        end do

        do jl=start,end
           tmp3(jl) = 1_jpim
           tmp3(jl+klon) = 1_jpim
        enddo
    end subroutine kernel

    subroutine kernel2(start, end, klon, klev, field2)
        implicit none
        integer, parameter :: jprb = selected_real_kind(13,300)
        integer, intent(in) :: start, end, klon, klev
        real(kind=jprb), intent(inout) :: field2(klon,klev)
        real(kind=jprb) :: tmp1(2*klon, klev), tmp2(0:klon, 0:klev)
        integer :: jk, jl

        do jk=1,klev
            do jl=start,end
                tmp1(jl, jk) = field2(jl, jk)
                tmp1(jl+klon, jk) = field2(jl, jk)*2._jprb
                tmp2(jl, jk) = tmp1(jl, jk) + 1._jprb
                field2(jl, jk) = tmp2(jl, jk)
            end do
        end do
    end subroutine kernel2

end module kernel_mod
    """.strip()

    (tmp_path/'driver.F90').write_text(fcode_driver)
    (tmp_path/'kernel_mod.F90').write_text(fcode_kernel)
    (tmp_path/'parkind_mod.F90').write_text(fcode_parkind_mod)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True,
        },
        'routines': {
            'driver': {'role': 'driver'}
        }
    }
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )

    transformation = TemporariesPoolAllocatorTransformation(
        block_dim=block_dim, horizontal=horizontal, directive=directive, cray_ptr_loc_rhs=cray_ptr_loc_rhs
    )
    scheduler.process(transformation=transformation)
    pragma_model_trafo = PragmaModelTransformation(directive=directive)
    scheduler.process(transformation=pragma_model_trafo)

    kernel_item = scheduler['kernel_mod#kernel']
    kernel2_item = scheduler['kernel_mod#kernel2']

    # set kind comparison string
    if frontend == OMNI:
        kind_real = 'selected_real_kind(13, 300)'
        kind_int = '4'
        kind_log = '4'
    else:
        kind_real = 'jprb'
        kind_int = 'jpim'
        kind_log = 'jplm'

    assert transformation._key in kernel_item.trafo_data
    exp_stack_size_str = (
            f'ishft(7 + c_sizeof(real(1, kind={kind_real}))*klon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*klev*klon, -3)'
            f' + ishft(7 + 2*c_sizeof(int(1, kind={kind_int}))*klon, -3)'
            f' + ishft(7 + c_sizeof(logical(true, kind={kind_log}))*klev, -3)'
    )
    exp_stack_size = parse_expr(exp_stack_size_str)
    assert kernel_item.trafo_data[transformation._key]['stack_size'] == exp_stack_size
    exp_stack_size_str = (
            f'ishft(7 + 2*c_sizeof(real(1, kind={kind_real}))*klev*klon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))'
            f' + c_sizeof(real(1, kind={kind_real}))*klon'
            f' + c_sizeof(real(1, kind={kind_real}))*klev'
            f' + c_sizeof(real(1, kind={kind_real}))*klev*klon, -3)'
    )

    exp_stack_size = parse_expr(exp_stack_size_str)
    assert kernel2_item.trafo_data[transformation._key]['stack_size'] == exp_stack_size
    assert all(
        v.scope is None
        for v in FindVariables().visit(kernel_item.trafo_data[transformation._key]['stack_size'])
    )
    assert all(
        v.scope is None
        for v in FindVariables().visit(kernel2_item.trafo_data[transformation._key]['stack_size'])
    )

    #
    # A few checks on the driver
    #
    driver = scheduler['#driver'].ir

    stack_order = FindNodes(Assignment).visit(driver.body)
    if stack_insert_pragma:
        assert stack_order[0].lhs == "a"
    else:
        assert stack_order[0].lhs == "ISTSZ"

    # Check if allocation type symbols have been imported
    if frontend != OMNI:
        assert 'jpim' in driver.imported_symbols
        assert 'jplm' in driver.imported_symbols
        assert driver.import_map['jpim'] == driver.import_map['jplm']
        assert 'jprb' not in driver.import_map['jpim'].symbols

    # Has the stack been added to the call statements?
    calls = FindNodes(CallStatement).visit(driver.body)
    expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u'))
    if cray_ptr_loc_rhs:
        expected_kwarguments += (('ZSTACK', 'zstack(:,b)'),)
    assert len(calls) == 2
    assert calls[0].arguments == ('1', 'nlon', 'nlon', 'nz', 'field1(:,b)', 'field2(:,:,b)')
    assert calls[0].kwarguments == expected_kwarguments
    assert calls[1].arguments == ('1', 'nlon', 'nlon', 'nz', 'field2(:,:,b)')
    assert calls[1].kwarguments == expected_kwarguments

    # The driver allocates the maximum of the two kernels' stack requirements
    stack_size_str = (
            f'max(ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3)'
            f' + ishft(7 + 2*c_sizeof(int(1, kind={kind_int}))*nlon, -3)'
            f' + ishft(7 + c_sizeof(logical(true, kind={kind_log}))*nz, -3),'
            f' ishft(7 + 2*c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))'
            f' + c_sizeof(real(1, kind={kind_real}))*nlon'
            f' + c_sizeof(real(1, kind={kind_real}))*nz'
            f' + c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3))'
    )
    stack_size = parse_expr(stack_size_str)

    check_stack_created_in_driver(driver, stack_size, calls[0], 2, cray_ptr_loc_rhs=cray_ptr_loc_rhs)

    # Has the data sharing been updated?
    if directive in ['openmp', 'openacc']:
        keyword = {'openmp': 'omp', 'openacc': 'acc'}[directive]
        pragmas = [
            p for p in FindNodes(Pragma).visit(driver.body)
            if p.keyword.lower() == keyword and p.content.startswith('parallel')
        ]
        assert len(pragmas) == 2
        for pragma in pragmas:
            parameters = get_pragma_parameters(pragma, starts_with='parallel', only_loki_pragmas=False)
            assert 'private' in parameters and 'ylstack' in parameters['private'].lower()
            assert 'ylstack' not in parameters['firstprivate'].lower()

    # Are there data regions for the stack?
    # NOTE: compare against the string 'openacc'; comparing a string to the list
    # ['openacc'] is never true and silently disabled these checks.
    if directive == 'openacc':
        pragmas = [
            p for p in FindNodes(Pragma).visit(driver.body)
            if p.keyword.lower() == 'acc' and 'data' in p.content
        ]
        assert len(pragmas) == 2
        parameters = get_pragma_parameters(pragmas[0], starts_with='data', only_loki_pragmas=False)
        assert parameters['create'] == 'zstack'

    #
    # A few checks on the kernel
    #
    for count, item in enumerate([kernel_item, kernel2_item]):
        kernel = item.ir

        # Has the stack been added to the arguments?
        assert 'ydstack_l' in kernel.arguments
        assert 'ydstack_u' in kernel.arguments

        # Is it being assigned to a local variable?
        assert 'ylstack_l' in kernel.variables
        assert 'ylstack_u' in kernel.variables

        dim1 = f"{kernel.variable_map['tmp1'].shape[0]}"
        for v in kernel.variable_map['tmp1'].shape[1:]:
            dim1 += f'*{v}'
        # tmp2 has the shape "0:klon, 0:klev" in kernel2, so ranges become (upper - lower + 1)
        dim2 = kernel.variable_map['tmp2'].shape[0]
        if isinstance(dim2, RangeIndex):
            dim2 = Sum((dim2.upper, Product((-1, dim2.lower)), 1))
        for v in kernel.variable_map['tmp2'].shape[1:]:
            _dim = v
            if isinstance(_dim, RangeIndex):
                _dim = Sum((_dim.upper, Product((-1, _dim.lower)), 1))
            dim2 = Product((dim2, _dim))

        if cray_ptr_loc_rhs:
            exp_rhs_1 =  parse_expr(f'ylstack_l + ishft({dim1}*C_SIZEOF(REAL(1, kind={kind_real})) + 7, -3)')
            exp_rhs_2 =  parse_expr(f'ylstack_l + ishft({dim2}*C_SIZEOF(REAL(1, kind={kind_real})) + 7, -3)')
        else:
            exp_rhs_1 = parse_expr(f'ylstack_l + ishft(ishft({dim1}*C_SIZEOF(REAL(1, kind={kind_real})) + 7, -3), 3)')
            exp_rhs_2 = parse_expr(f'ylstack_l + ishft(ishft({dim2}*C_SIZEOF(REAL(1, kind={kind_real})) + 7, -3), 3)')

        # Let's check for the relevant "allocations" happening in the right order
        assign_idx = {}
        for idx, ass in enumerate(FindNodes(Assignment).visit(kernel.body)):

            if ass.lhs == 'ylstack_l' and ass.rhs == 'ydstack_l':
                # Local copy of stack status
                assign_idx['stack_assign'] = idx
            elif ass.lhs == 'ylstack_u' and ass.rhs == 'ydstack_u':
                # Local copy of stack status
                assign_idx['stack_assign_end'] = idx
            elif ass.lhs == 'ip_tmp1':
                # assign Cray pointer for tmp1
                assign_idx['tmp1_ptr_assign'] = idx
            elif ass.lhs == 'ip_tmp2':
                # assign Cray pointer for tmp2
                assign_idx['tmp2_ptr_assign'] = idx
            elif ass.lhs == 'ylstack_l' and 'ylstack_l' in ass.rhs and ass.rhs == exp_rhs_1:
                # Stack increment for tmp1
                assign_idx['tmp1_stack_incr'] = idx
            elif ass.lhs == 'ylstack_l' and 'ylstack_l' in ass.rhs and ass.rhs == exp_rhs_2:
                # Stack increment for tmp2
                assign_idx['tmp2_stack_incr'] = idx

        expected_assign_in_order = [
            'stack_assign', 'stack_assign_end', 'tmp1_ptr_assign', 'tmp1_stack_incr', 'tmp2_ptr_assign',
            'tmp2_stack_incr'
        ]
        assert set(expected_assign_in_order) == set(assign_idx.keys())

        for assign1, assign2 in zip(expected_assign_in_order, expected_assign_in_order[1:]):
            assert assign_idx[assign2] > assign_idx[assign1]

        # Check for pointer declarations in generated code
        fcode = kernel.to_fortran()
        assert 'pointer(ip_tmp1, tmp1)' in fcode.lower()
        assert 'pointer(ip_tmp2, tmp2)' in fcode.lower()

        # Check for stack size safeguards in generated code
        if count == 0:
            assert fcode.lower().count('if (ylstack_l > ylstack_u)') == 4
            assert fcode.lower().count('stop') == 4
        else:
            assert fcode.lower().count('if (ylstack_l > ylstack_u)') == 2
            assert fcode.lower().count('stop') == 2


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('directive', [False, 'openmp', 'openacc'])
@pytest.mark.parametrize('cray_ptr_loc_rhs', [False, True])
def test_pool_allocator_temporaries_kernel_nested(tmp_path, frontend, block_dim, directive, cray_ptr_loc_rhs):
    """
    Test the pool allocator for a kernel that calls a nested kernel (``kernel`` -> ``kernel2``).

    Checks that the outer kernel's stack size includes the nested kernel's
    temporaries and that the stack arguments are forwarded to the nested call. Also
    checks that the driver's stack uses the ``real64`` kind irrespective of the
    ``real_kind`` config entry.
    """
    driver_pragma = f'!$loki loop gang{" private(b)" if directive == "openmp" else ""}'
    driver_end_pragma = '!$loki end loop gang'
    kernel_pragma = '!$loki routine vector'

    fcode_parkind_mod = """
module parkind1
implicit none
integer, parameter :: jwrb = selected_real_kind(13,300)
integer, parameter :: jpim = selected_int_kind(9)
integer, parameter :: jplm = jpim
end module parkind1
    """.strip()

    fcode_driver = f"""
subroutine driver(NLON, NZ, NB, FIELD1, FIELD2)
    use kernel_mod, only: kernel
    use parkind1, only : jpim
    implicit none
    INTEGER, PARAMETER :: JWRB = SELECTED_REAL_KIND(13,300)
    INTEGER, INTENT(IN) :: NLON, NZ, NB
    real(kind=jwrb), intent(inout) :: field1(nlon, nb)
    real(kind=jwrb), intent(inout) :: field2(nlon, nz, nb)
    integer :: b
    {driver_pragma}
    do b=1,nb
        call KERNEL(1, nlon, nlon, nz, field1(:,b), field2(:,:,b))
    end do
    {driver_end_pragma}
end subroutine driver
    """.strip()
    fcode_kernel = f"""
module kernel_mod
    implicit none
contains
    subroutine kernel(start, end, klon, klev, field1, field2)
        use parkind1, only : jpim, jplm
        implicit none
        integer, parameter :: jwrb = selected_real_kind(13,300)
        integer, intent(in) :: start, end, klon, klev
        real(kind=jwrb), intent(inout) :: field1(klon)
        real(kind=jwrb), intent(inout) :: field2(klon,klev)
        real(kind=jwrb) :: tmp1(klon)
        real(kind=jwrb) :: tmp2(klon, klev)
        integer(kind=jpim) :: tmp3(klon*2)
        logical(kind=jplm) :: tmp4(klev)
        integer :: jk, jl
        {kernel_pragma}

        do jk=1,klev
            tmp1(jl) = 0.0_jwrb
            do jl=start,end
                tmp2(jl, jk) = field2(jl, jk)
                tmp1(jl) = field2(jl, jk)
            end do
            field1(jl) = tmp1(jl)
            tmp4(jk) = .true.
        end do

        do jl=start,end
           tmp3(jl) = 1_jpim
           tmp3(jl+klon) = 1_jpim
        enddo

        call kernel2(start, end, klon, klev, field2)
    end subroutine kernel

    subroutine kernel2(start, end, columns, levels, field2)
        implicit none
        integer, parameter :: jwrb = selected_real_kind(13,300)
        integer, intent(in) :: start, end, columns, levels
        real(kind=jwrb), intent(inout) :: field2(columns,levels)
        real(kind=jwrb) :: tmp1(2*columns, levels), tmp2(columns, levels)
        integer :: jk, jl
        {kernel_pragma}

        do jk=1,levels
            do jl=start,end
                tmp1(jl, jk) = field2(jl, jk)
                tmp1(jl+columns, jk) = field2(jl, jk)*2._jwrb
                tmp2(jl, jk) = tmp1(jl, jk) + 1._jwrb
                field2(jl, jk) = tmp2(jl, jk)
            end do
        end do
    end subroutine kernel2

end module kernel_mod
    """.strip()

    (tmp_path/'driver.F90').write_text(fcode_driver)
    (tmp_path/'kernel_mod.F90').write_text(fcode_kernel)
    (tmp_path/'parkind_mod.F90').write_text(fcode_parkind_mod)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True,
        },
        'routines': {
            # real_kind = jwrb has no (and should have no) effect anymore as
            #  'real64' as kind of the stack is baked into the recipe
            'driver': {'role': 'driver', 'real_kind': 'jwrb'}
        }
    }

    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )

    transformation = TemporariesPoolAllocatorTransformation(block_dim=block_dim,
                                                            directive=directive, cray_ptr_loc_rhs=cray_ptr_loc_rhs)
    scheduler.process(transformation=transformation)
    pragma_model_trafo = PragmaModelTransformation(directive=directive)
    scheduler.process(transformation=pragma_model_trafo)

    kernel_item = scheduler['kernel_mod#kernel']
    kernel2_item = scheduler['kernel_mod#kernel2']

    # set kind comparison string
    if frontend == OMNI:
        kind_real = 'selected_real_kind(13, 300)'
        kind_int = '4'
        kind_log = '4'
    else:
        kind_real = 'jwrb'
        kind_int = 'jpim'
        kind_log = 'jplm'

    # The outer kernel's stack size also accounts for the temporaries of kernel2
    assert transformation._key in kernel_item.trafo_data
    exp_stack_size_str = (
            f'ishft(7 + c_sizeof(real(1, kind={kind_real}))*klon, -3)'
            f' + 2*ishft(7 + c_sizeof(real(1, kind={kind_real}))*klev*klon, -3)'
            f' + ishft(7 + 2*c_sizeof(int(1, kind={kind_int}))*klon, -3)'
            f' + ishft(7 + c_sizeof(logical(true, kind={kind_log}))*klev, -3)'
            f' + ishft(7 + 2*c_sizeof(real(1, kind={kind_real}))*klev*klon, -3)'
    )
    exp_stack_size = parse_expr(exp_stack_size_str)
    assert kernel_item.trafo_data[transformation._key]['stack_size'] == exp_stack_size
    exp_stack_size_str = (
            f'ishft(7 + 2*c_sizeof(real(1, kind={kind_real}))*columns*levels, -3)'
            f' + ishft(7 + c_sizeof(real(1, kind={kind_real}))*columns*levels, -3)'
    )
    exp_stack_size = parse_expr(exp_stack_size_str)
    assert kernel2_item.trafo_data[transformation._key]['stack_size'] == exp_stack_size
    assert all(
        v.scope is None
        for v in FindVariables().visit(kernel_item.trafo_data[transformation._key]['stack_size'])
    )
    assert all(
        v.scope is None
        for v in FindVariables().visit(kernel2_item.trafo_data[transformation._key]['stack_size'])
    )

    #
    # A few checks on the driver
    #
    driver = scheduler['#driver'].ir

    # Check if allocation type symbols have been imported
    if frontend != OMNI:
        assert 'jpim' in driver.imported_symbols
        assert 'jplm' in driver.imported_symbols
        assert driver.import_map['jpim'] == driver.import_map['jplm']

    # Has the stack been added to the call statements?
    calls = FindNodes(CallStatement).visit(driver.body)
    expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u'))
    if cray_ptr_loc_rhs:
        expected_kwarguments += (('ZSTACK', 'zstack(:,b)'),)
    assert len(calls) == 1
    assert calls[0].arguments == ('1', 'nlon', 'nlon', 'nz', 'field1(:,b)', 'field2(:,:,b)')
    assert calls[0].kwarguments == expected_kwarguments

    stack_size_str = (
            f'ishft(7 + c_sizeof(real(1, kind={kind_real}))*nlon, -3)'
            f' + 2*ishft(7 + c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3)'
            f' + ishft(7 + 2*c_sizeof(int(1, kind={kind_int}))*nlon, -3)'
            f' + ishft(7 + c_sizeof(logical(true, kind={kind_log}))*nz, -3)'
            f' + ishft(7 + 2*c_sizeof(real(1, kind={kind_real}))*nz*nlon, -3)'
    )
    stack_size = parse_expr(stack_size_str)
    check_stack_created_in_driver(
        driver, stack_size, calls[0], 1,
        cray_ptr_loc_rhs=cray_ptr_loc_rhs
    )

    # check if stack allocatable in the driver has the correct kind parameter
    if frontend != OMNI:
        assert driver.symbol_map['zstack'].type.kind == 'real64'

    # Has the data sharing been updated?
    if directive in ['openmp', 'openacc']:
        keyword = {'openmp': 'omp', 'openacc': 'acc'}[directive]
        pragmas = [
            p for p in FindNodes(Pragma).visit(driver.body)
            if p.keyword.lower() == keyword and p.content.startswith('parallel')
        ]
        assert len(pragmas) == 1
        for pragma in pragmas:
            parameters = get_pragma_parameters(pragma, starts_with='parallel', only_loki_pragmas=False)
            assert 'private' in parameters and 'ylstack' in parameters['private'].lower()
            if directive == 'openmp':
                assert 'b' in parameters['private']

    # Are there data regions for the stack?
    # NOTE: compare against the string 'openacc'; comparing a string to the list
    # ['openacc'] is never true and silently disabled these checks.
    if directive == 'openacc':
        pragmas = [
            p for p in FindNodes(Pragma).visit(driver.body)
            if p.keyword.lower() == 'acc' and 'data' in p.content
        ]
        assert len(pragmas) == 2
        parameters = get_pragma_parameters(pragmas[0], starts_with='data', only_loki_pragmas=False)
        assert parameters['create'] == 'zstack'

    #
    # A few checks on the kernels
    #
    # The nested call to kernel2 must forward the stack arguments
    calls = FindNodes(CallStatement).visit(kernel_item.ir.body)
    expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u'))
    if cray_ptr_loc_rhs:
        expected_kwarguments += (('ZSTACK', 'zstack'),)
    assert len(calls) == 1
    assert calls[0].arguments == ('start', 'end', 'klon', 'klev', 'field2')
    assert calls[0].kwarguments == expected_kwarguments

    for count, item in enumerate([kernel_item, kernel2_item]):
        kernel = item.ir

        # Has the stack been added to the arguments?
        assert 'ydstack_l' in kernel.arguments
        assert 'ydstack_u' in kernel.arguments

        # Is it being assigned to a local variable?
        assert 'ylstack_l' in kernel.variables
        assert 'ylstack_u' in kernel.variables

        dim1 = f"{kernel.variable_map['tmp1'].shape[0]}"
        for v in kernel.variable_map['tmp1'].shape[1:]:
            dim1 += f'*{v}'
        dim2 = f"{kernel.variable_map['tmp2'].shape[0]}"
        for v in kernel.variable_map['tmp2'].shape[1:]:
            dim2 += f'*{v}'

        if cray_ptr_loc_rhs:
            exp_rhs_1 = parse_expr(f'ylstack_l + ishft({dim1}*c_sizeof(real(1, kind={kind_real})) + 7, -3)')
            exp_rhs_2 = parse_expr(f'ylstack_l + ishft({dim2}*c_sizeof(real(1, kind={kind_real})) + 7, -3)')
        else:
            exp_rhs_1 = parse_expr(f'ylstack_l + ishft(ishft({dim1}*c_sizeof(real(1, kind={kind_real})) + 7, -3), 3)')
            exp_rhs_2 = parse_expr(f'ylstack_l + ishft(ishft({dim2}*c_sizeof(real(1, kind={kind_real})) + 7, -3), 3)')

        # Let's check for the relevant "allocations" happening in the right order
        assign_idx = {}
        for idx, ass in enumerate(FindNodes(Assignment).visit(kernel.body)):

            if ass.lhs == 'ylstack_l' and ass.rhs == 'ydstack_l':
                # Local copy of stack status
                assign_idx['stack_assign'] = idx
            elif ass.lhs == 'ylstack_u' and ass.rhs == 'ydstack_u':
                # Local copy of stack status
                assign_idx['stack_assign_end'] = idx
            elif ass.lhs == 'ip_tmp1':
                # assign Cray pointer for tmp1
                assign_idx['tmp1_ptr_assign'] = idx
            elif ass.lhs == 'ip_tmp2':
                # assign Cray pointer for tmp2
                assign_idx['tmp2_ptr_assign'] = idx
            elif ass.lhs == 'ylstack_l' and 'ylstack_l' in ass.rhs and simplify(ass.rhs) == simplify(exp_rhs_1):
                # Stack increment for tmp1
                assign_idx['tmp1_stack_incr'] = idx
            elif ass.lhs == 'ylstack_l' and 'ylstack_l' in ass.rhs and simplify(ass.rhs) == simplify(exp_rhs_2):
                # Stack increment for tmp2
                assign_idx['tmp2_stack_incr'] = idx

        expected_assign_in_order = [
            'stack_assign', 'stack_assign_end', 'tmp1_ptr_assign', 'tmp1_stack_incr', 'tmp2_ptr_assign',
            'tmp2_stack_incr'
        ]
        assert set(expected_assign_in_order) == set(assign_idx.keys())

        for assign1, assign2 in zip(expected_assign_in_order, expected_assign_in_order[1:]):
            assert assign_idx[assign2] > assign_idx[assign1]

        # Check for pointer declarations in generated code
        fcode = kernel.to_fortran()
        assert 'pointer(ip_tmp1, tmp1)' in fcode.lower()
        assert 'pointer(ip_tmp2, tmp2)' in fcode.lower()

        # Check for stack size safeguards in generated code
        if count == 0:
            assert fcode.lower().count('if (ylstack_l > ylstack_u)') == 4
            assert fcode.lower().count('stop') == 4
        else:
            assert fcode.lower().count('if (ylstack_l > ylstack_u)') == 2
            assert fcode.lower().count('stop') == 2


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('cray_ptr_loc_rhs', [False, True])
def test_pool_allocator_more_call_checks(tmp_path, frontend, block_dim, caplog, cray_ptr_loc_rhs, horizontal):
    """
    Test that the stack arguments are appended to the subroutine call (one that
    passes an optional argument) and to the inline function call in a kernel.

    Also checks that a derived-type temporary is reported as unsupported by the
    pool allocator.
    """
    fcode = """
    module kernel_mod
      type point
         real :: x
         real :: y
         real :: z
      end type point
    contains
      real function inline_kernel(jl)
          integer, intent(in) :: jl
      end function inline_kernel
      subroutine optional_arg(klon, temp1, temp2)
          integer, intent(in) :: klon
          real, intent(inout) :: temp1
          real, intent(out), optional :: temp2
      end subroutine optional_arg
      subroutine kernel(start, end, klon, field1)
          implicit none

          interface
             real function inline_kernel(jl)
                 integer, intent(in) :: jl
             end function inline_kernel
          end interface

          integer, intent(in) :: start, end, klon
          real, intent(inout) :: field1(klon)
          real :: temp1(klon)
          real :: temp2(klon)
          type(point) :: p(klon)

          integer :: jl

          do jl=start,end
              field1(jl) = inline_kernel(jl)
              p(jl)%x = 0.
              p(jl)%y = 0.
              p(jl)%z = 0.
          end do

          call optional_arg(klon, temp1, temp2)
      end subroutine kernel
    end module kernel_mod
    """.strip()

    (tmp_path/'kernel.F90').write_text(fcode)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': True,
        },
        'routines': {
            'kernel': {}
        }
    }
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
    )

    transformation = TemporariesPoolAllocatorTransformation(block_dim=block_dim, horizontal=horizontal,
                                                            cray_ptr_loc_rhs=cray_ptr_loc_rhs)
    scheduler.process(transformation=transformation)
    pragma_model_trafo = PragmaModelTransformation()
    scheduler.process(transformation=pragma_model_trafo)

    item = scheduler['kernel_mod#kernel']
    kernel = item.ir

    # Has the stack been added to the arguments?
    assert 'ydstack_l' in kernel.arguments
    assert 'ydstack_u' in kernel.arguments

    # Is it being assigned to a local variable?
    assert 'ylstack_l' in kernel.variables
    assert 'ylstack_u' in kernel.variables

    # Has the stack been added to the call statement at the correct location?
    calls = FindNodes(CallStatement).visit(kernel.body)
    expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u'))
    if cray_ptr_loc_rhs:
        expected_kwarguments += (('ZSTACK', 'zstack'),)
    assert len(calls) == 1
    assert calls[0].arguments == ('klon', 'temp1', 'temp2')
    assert calls[0].kwarguments == expected_kwarguments

    # Now repeat the checks for the inline call, filtering out in a single pass the
    # intrinsic inline calls that belong to the stack size/increment calculation
    ignored_inline_calls = ('max', 'c_sizeof', 'real', 'ishft')
    calls = [
        call for call in FindInlineCalls().visit(kernel.body)
        if str(call.name).lower() not in ignored_inline_calls
    ]
    # With cray_ptr_loc_rhs one additional inline call survives the filter; do not
    # rely on the order in which FindInlineCalls returns them and select by name
    assert len(calls) == (2 if cray_ptr_loc_rhs else 1)
    relevant_calls = [call for call in calls if str(call.name).lower() == 'inline_kernel']
    assert len(relevant_calls) == 1
    relevant_call = relevant_calls[0]
    assert relevant_call.arguments == ('jl',)
    assert relevant_call.kwarguments == expected_kwarguments

    assert 'Derived-type vars in Subroutine:: kernel not supported in pool allocator' in caplog.text


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('cray_ptr_loc_rhs', [False, True])
def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_ptr_loc_rhs, horizontal):
    """
    Check that the pool allocator appends the stack arguments as keyword
    arguments to driver-level calls, irrespective of whether the original
    call used positional arguments, keyword arguments, or a mix of both,
    and that the stack size allocation and imported array sizes are
    set up in the driver.
    """
    fcode_parkind_mod = """
module parkind1
implicit none
integer, parameter :: jwrb = selected_real_kind(13,300)
integer, parameter :: jpim = selected_int_kind(9)
integer, parameter :: jplm = jpim
end module parkind1
    """.strip()

    fcode_module = """
module geom_mod
    implicit none
    type dim_type
       integer :: nb
    end type dim_type

    type geom_type
       type(dim_type) :: blk_dim
    end type geom_type

    integer :: n
end module geom_mod
"""

    fcode_driver = """
subroutine driver(NLON, NZ, GEOM, FIELD1, FIELD2)
    use kernel_mod, only: kernel, kernel2
    use parkind1, only : jpim
    use geom_mod, only : geom_type
    implicit none
    INTEGER, PARAMETER :: JWRB = SELECTED_REAL_KIND(13,300)
    INTEGER, INTENT(IN) :: NLON, NZ
    type(geom_type), intent(in) :: geom
    real(kind=jwrb), intent(inout) :: field1(nlon, geom%blk_dim%nb)
    real(kind=jwrb), intent(inout) :: field2(nlon, nz, geom%blk_dim%nb)
    integer :: b
    real(kind=jwrb) :: opt
    do b=1,geom%blk_dim%nb
        call KERNEL(start=1, end=nlon, klon=nlon, klev=nz, field1=field1(:,b), field2=field2(:,:,b))
        call KERNEL2(1, nlon, nlon, nz, field2=field2(:,:,b))
        call KERNEL2(1, nlon, nlon, nz, field2(:,:,b))
        call KERNEL2(1, nlon, nlon, nz, field2=field2(:,:,b), opt_arg=opt)
        call KERNEL2(1, nlon, nlon, nz, field2(:,:,b), opt)
    end do
end subroutine driver
    """.strip()

    fcode_kernel = """
module kernel_mod
    implicit none
contains
    subroutine kernel(start, end, klon, klev, field1, field2)
        use parkind1, only : jpim, jplm
        use geom_mod, only : n
        implicit none
        integer, parameter :: jwrb = selected_real_kind(13,300)
        integer, intent(in) :: start, end, klon, klev
        real(kind=jwrb), intent(inout) :: field1(klon)
        real(kind=jwrb), intent(inout) :: field2(klon,klev)
        real(kind=jwrb) :: tmp1(klon)
        real(kind=jwrb) :: tmp2(klon, klev)
        integer(kind=jpim) :: tmp3(klon*2)
        logical(kind=jplm) :: tmp4(klev)
        logical(kind=jplm) :: tmp5(klev,n)
        integer :: jk, jl

        do jk=1,klev
            tmp1(jl) = 0.0_jwrb
            do jl=start,end
                tmp2(jl, jk) = field2(jl, jk)
                tmp1(jl) = field2(jl, jk)
            end do
            field1(jl) = tmp1(jl)
            tmp4(jk) = .true.
            tmp5(jk,1:n) = .true.
        end do

        do jl=start,end
           tmp3(jl) = 1_jpim
           tmp3(jl+klon) = 1_jpim
        enddo

        call kernel2(start, end, klon, klev, field2)
    end subroutine kernel
    subroutine kernel2(start, end, columns, levels, field2, opt_arg)
        implicit none
        integer, parameter :: jwrb = selected_real_kind(13,300)
        integer, intent(in) :: start, end, columns, levels
        real(kind=jwrb), intent(inout) :: field2(columns,levels)
        real(kind=jwrb) :: tmp1(2*columns, levels), tmp2(columns, levels)
        real(kind=jwrb), optional :: opt_arg
        integer :: jk, jl

        do jk=1,levels
            do jl=start,end
                tmp1(jl, jk) = field2(jl, jk)
                tmp1(jl+columns, jk) = field2(jl, jk)*2._jwrb
                tmp2(jl, jk) = tmp1(jl, jk) + 1._jwrb
                field2(jl, jk) = tmp2(jl, jk)
            end do
        end do
    end subroutine kernel2

end module kernel_mod
    """.strip()

    (tmp_path / 'parkind1.F90').write_text(fcode_parkind_mod)
    (tmp_path / 'driver.F90').write_text(fcode_driver)
    (tmp_path / 'kernel.F90').write_text(fcode_kernel)
    (tmp_path / 'module.F90').write_text(fcode_module)

    config = {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'ignore': ['parkind1'],
            'enable_imports': True,
        },
        'routines': {
            'driver': {'role': 'driver'}
        }
    }
    scheduler = Scheduler(
        paths=[tmp_path], config=SchedulerConfig.from_dict(config),
        frontend=frontend, xmods=[tmp_path]
    )

    transformation = TemporariesPoolAllocatorTransformation(block_dim=block_dim_alt, horizontal=horizontal,
                                                            cray_ptr_loc_rhs=cray_ptr_loc_rhs)
    scheduler.process(transformation=transformation)
    pragma_model_trafo = PragmaModelTransformation()
    scheduler.process(transformation=pragma_model_trafo)

    kernel = scheduler['kernel_mod#kernel'].ir
    kernel2 = scheduler['kernel_mod#kernel2'].ir
    driver = scheduler['#driver'].ir

    # Both kernels receive the stack bounds as dummy arguments
    assert 'ydstack_l' in kernel.arguments
    assert 'ydstack_u' in kernel.arguments
    assert 'ydstack_l' in kernel2.arguments
    assert 'ydstack_u' in kernel2.arguments

    # Every driver call gets the stack appended as keyword arguments, regardless
    # of how the original arguments were passed
    calls = FindNodes(CallStatement).visit(driver.body)
    additional_kwargs = (('ZSTACK', 'zstack(:,b)'),) if cray_ptr_loc_rhs else ()
    assert calls[0].arguments == ()
    assert calls[0].kwarguments == (
        ('start', 1), ('end', 'nlon'), ('klon', 'nlon'), ('klev', 'nz'),
        ('field1', 'field1(:, b)'), ('field2', 'field2(:, :, b)'),
        ('YDSTACK_L', 'YLSTACK_L'), ('YDSTACK_U', 'YLSTACK_U')
    ) + additional_kwargs
    assert calls[1].arguments == ('1', 'nlon', 'nlon', 'nz')
    assert calls[1].kwarguments == (
        ('field2', 'field2(:, :, b)'), ('YDSTACK_L', 'YLSTACK_L'), ('YDSTACK_U', 'YLSTACK_U')
    ) + additional_kwargs
    assert calls[2].arguments == ('1', 'nlon', 'nlon', 'nz', 'field2(:, :, b)')
    assert calls[2].kwarguments == (
        ('YDSTACK_L', 'YLSTACK_L'), ('YDSTACK_U', 'YLSTACK_U')
    ) + additional_kwargs
    assert calls[3].arguments == ('1', 'nlon', 'nlon', 'nz')
    assert calls[3].kwarguments == (
        ('field2', 'field2(:, :, b)'), ('opt_arg', 'opt'),
        ('YDSTACK_L', 'YLSTACK_L'), ('YDSTACK_U', 'YLSTACK_U')
    ) + additional_kwargs
    assert calls[4].arguments == ('1', 'nlon', 'nlon', 'nz', 'field2(:, :, b)', 'opt')
    assert calls[4].kwarguments == (
        ('YDSTACK_L', 'YLSTACK_L'), ('YDSTACK_U', 'YLSTACK_U')
    ) + additional_kwargs

    # check stack size allocation
    allocations = FindNodes(Allocation).visit(driver.body)
    assert len(allocations) == 1 and 'zstack(istsz,geom%blk_dim%nb)' in allocations[0].variables

    # check that array size was imported to the driver
    assert 'n' in driver.imported_symbols
loki-ecmwf-0.3.6/loki/transformations/temporaries/hoist_variables.py0000664000175000017500000005570515167130205026226 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Multiple **transformations to hoist variables** especially to hoist temporary arrays.

E.g., the following source code

.. code-block:: fortran

    subroutine driver(...)
        integer :: a
        a = 10
        call kernel(a)
    end subroutine driver

    subroutine kernel(a)
        integer, intent(in) :: a
        real :: array(a)
        ...
    end subroutine kernel

can be transformed/hoisted to

.. code-block:: fortran

    subroutine driver(...)
        integer :: a
        real :: kernel_array(a)
        a = 10
        call kernel(a, kernel_array)
    end subroutine driver

    subroutine kernel(a, array)
        integer, intent(in) :: a
        real, intent(inout) :: array(a)
        ...
    end subroutine kernel

using

.. code-block:: python

    # Transformation: Analysis
    scheduler.process(transformation=HoistTemporaryArraysAnalysis())
    # Transformation: Synthesis
    scheduler.process(transformation=HoistVariablesTransformation())


To achieve this, two transformations are necessary: the first one is responsible for the *Analysis* and the
second one for the *Synthesis*. Two base classes

* :class:`.HoistVariablesAnalysis` - *Analysis* part, to be processed in reverse
    * specialise/implement :func:`find_variables`
* :class:`.HoistVariablesTransformation` - *Synthesis* part
    * specialise/implement :func:`driver_variable_declaration`

are provided to create derived classes for specialisation of the actual hoisting.

.. warning::
    :class:`.HoistVariablesAnalysis` ensures that all local variables are hoisted!
    Please consider using a specialised class like :class:`.HoistTemporaryArraysAnalysis` or create a derived class
    yourself.

.. note::
    If several of these transformations are carried out in succession, provide a unique ``key`` for each corresponding
    *Analysis* and *Synthesis* step!

    .. code-block:: python

        key = "UniqueKey"
        scheduler.process(transformation=HoistTemporaryArraysAnalysis(dim_vars=('b',), key=key))
        scheduler.process(transformation=HoistTemporaryArraysTransformation(key=key))
        key = "AnotherUniqueKey"
        scheduler.process(transformation=HoistTemporaryArraysAnalysis(dim_vars=('a',), key=key))
        scheduler.process(transformation=HoistTemporaryArraysTransformationAllocatable(key=key))
"""

from collections import defaultdict

from loki.batch import Transformation, ProcedureItem
from loki.expression import symbols as sym, is_dimension_constant
from loki.ir import (
    CallStatement, Allocation, Deallocation, Transformer, FindNodes, Comment, Import,
    Assignment, FindVariables, FindInlineCalls, SubstituteExpressions
)
from loki.tools.util import (
    is_iterable, as_tuple, CaseInsensitiveDict, flatten, OrderedSet
)
from loki.types import BasicType
from loki.logging import warning

from loki.transformations.utilities import single_variable_declaration


# Public API of this module: the analysis/synthesis base classes and their
# temporary-array specialisations
__all__ = [
    'HoistVariablesAnalysis', 'HoistVariablesTransformation',
    'HoistTemporaryArraysAnalysis', 'HoistTemporaryArraysTransformationAllocatable'
]


class HoistVariablesAnalysis(Transformation):
    """
    **Base class** for the *Analysis* part of the hoist variables functionality/transformation.

    Traverses all subroutines to find the variables to be hoisted.
    Create a derived class and override :func:`find_variables`
    to define which variables to be hoisted.

    The results are stored per item in ``item.trafo_data[self._key]`` under the
    keys ``"to_hoist"``, ``"hoist_variables"``, ``"imported_sizes"`` and
    ``"imported_kinds"``.
    """

    _key = 'HoistVariablesTransformation'

    # Apply in reverse order to recursively find all variables to be hoisted.
    reverse_traversal = True

    process_ignored_items = True

    def transform_subroutine(self, routine, **kwargs):
        """
        Analysis applied to :any:`Subroutine` item.

        Collects all the variables to be hoisted, including renaming
        in order to grant for unique variable names.

        For non-driver routines the locally selected variables (see
        :meth:`find_variables`) are recorded, together with the imported
        symbols used as their dimensions or kinds. For every successor
        :any:`ProcedureItem`, the child's hoisted variables are added as
        well, with array dimensions substituted through the call's
        argument mapping.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to be transformed.
        **kwargs : optional
            Keyword arguments for the transformation.
        """

        role = kwargs.get('role', None)
        item = kwargs.get('item', None)
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()

        item.trafo_data[self._key] = {}

        if role != 'driver':
            variables = self.find_variables(routine)
            item.trafo_data[self._key]["to_hoist"] = variables
            dims = flatten([getattr(v, 'shape', []) for v in variables])
            kinds = [v.type.kind for v in variables if v.type.kind]
            import_map = routine.import_map
            # Only plain imported symbols can be re-imported in the caller. Matching on the
            # string form skips literal or compound dims/kinds (e.g. ``real(kind=8)``), which
            # have no ``name`` attribute and would otherwise raise an AttributeError.
            item.trafo_data[self._key]["imported_sizes"] = [(d.type.module.name, d) for d in dims
                                                            if str(d) in import_map]
            item.trafo_data[self._key]["imported_kinds"] = [(import_map[k].module, k) for k in kinds
                                                             if str(k) in import_map]
            # Prefix with the routine name to keep hoisted names unique in the caller
            item.trafo_data[self._key]["hoist_variables"] = [var.clone(name=f'{routine.name}_{var.name}')
                                                             for var in variables]
        else:
            item.trafo_data[self._key]["imported_sizes"] = []
            item.trafo_data[self._key]["imported_kinds"] = []
            item.trafo_data[self._key]["to_hoist"] = []
            item.trafo_data[self._key]["hoist_variables"] = []

        calls = FindNodes(CallStatement).visit(routine.body)
        calls += FindInlineCalls().visit(routine.body)
        call_map = CaseInsensitiveDict((str(call.name), call) for call in calls)

        for child in successors:
            if not isinstance(child, ProcedureItem):
                continue

            if call_map[child.local_name].routine is BasicType.DEFERRED:
                warning(
                    '[Loki::HoistVariablesAnalysis] '
                    f'call.routine is BasicType.DEFERRED for call to {child.local_name} in {routine.name}'
                )
                continue

            # We may call a subroutine again with aliased sizes, so we check hoisted
            # variables in children by name before adding them
            hoist_var_names = [v.name.lower() for v in item.trafo_data[self._key]["hoist_variables"]]

            # Map the child's dummy arguments to the actual arguments at the call site
            arg_map = dict(call_map[child.local_name].arg_iter())
            hoist_variables = []
            for var in child.trafo_data[self._key]["hoist_variables"]:
                if var.name.lower() in hoist_var_names:
                    continue

                if isinstance(var, sym.Array):
                    # Express the child's array dimensions in terms of this routine's variables
                    dimensions = SubstituteExpressions(arg_map).visit(var.dimensions)
                    hoist_variables.append(var.clone(dimensions=dimensions, type=var.type.clone(shape=dimensions)))
                else:
                    hoist_variables.append(var)
            item.trafo_data[self._key]["to_hoist"].extend(hoist_variables)
            item.trafo_data[self._key]["hoist_variables"].extend(hoist_variables)
            item.trafo_data[self._key]["imported_sizes"] += child.trafo_data[self._key]["imported_sizes"]
            item.trafo_data[self._key]["imported_kinds"] += child.trafo_data[self._key]["imported_kinds"]

    def find_variables(self, routine):
        """
        **Override**: Find/Select all the variables to be hoisted.

        Selects all local variables that are not ``parameter`` to be hoisted.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine find the variables.
        """
        arguments = routine.arguments
        return [var for var in routine.variables if var not in arguments and not var.type.parameter]


class HoistVariablesTransformation(Transformation):
    """
    **Base class** for the *Synthesis* part of the hoist variables functionality/transformation.

    Traverses all subroutines to hoist the variables.
    Create a derived class and override :func:`driver_variable_declaration`
    to define how the hoisted variables are declared in the driver. The selection
    of the variables to be hoisted is done by the corresponding *Analysis* class
    (see :class:`HoistVariablesAnalysis`).

    .. note::
        Needs the *Analysis* part to be processed first in order to hoist all already found variables.

    Parameters
    ----------
    as_kwarguments : boolean
        Whether to pass the hoisted arguments as `args` or `kwargs`.
    remap_dimensions : boolean
        Remap dimensions based on variables that are used for initializing
        other variables that could end up as dimensions for hoisted arrays.
        Thus, account for possibly uninitialized variables used as dimensions.
    """

    _key = 'HoistVariablesTransformation'

    def __init__(self, as_kwarguments=False, remap_dimensions=True):
        self.as_kwarguments = as_kwarguments
        self.remap_dimensions = remap_dimensions

    def transform_subroutine(self, routine, **kwargs):
        """
        Transformation applied to :any:`Subroutine` item.

        Hoists all to be hoisted variables which includes

        * appending the arguments for each subroutine
        * appending the arguments for each subroutine call
        * modifying the variable declaration in the subroutine
        * adding the variable declaration in the driver

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to be transformed.
        **kwargs : optional
            Keyword arguments for the transformation.

        Raises
        ------
        RuntimeError
            If the *Analysis* results are missing from ``item.trafo_data``.
        """
        role = kwargs.get('role', None)
        item = kwargs.get('item', None)
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()

        successor_map = CaseInsensitiveDict(
            (successor.local_name, successor) for successor in successors
        )

        if self._key not in item.trafo_data:
            raise RuntimeError(f'{self.__class__.__name__} requires key "{self._key}" in item.trafo_data!\n'
                               f'Make sure to call HoistVariablesAnalysis (or any derived class) before and to provide '
                               f'the correct key.')

        if role == 'driver':
            if self.remap_dimensions:
                to_hoist = self.driver_variable_declaration_dim_remapping(routine,
                        item.trafo_data[self._key]["to_hoist"])
            else:
                to_hoist = item.trafo_data[self._key]["to_hoist"]
            self.driver_variable_declaration(routine, to_hoist)
        else:
            # We build the list of temporaries that are hoisted to the calling routine
            # Because this requires adding an intent, we need to make sure they are not
            # declared together with non-hoisted variables
            hoisted_temporaries = tuple(
                var.clone(type=var.type.clone(intent='inout'), scope=routine)
                for var in item.trafo_data[self._key]['to_hoist']
            )
            single_variable_declaration(routine, variables=[var.clone(dimensions=None) for var in hoisted_temporaries])
            routine.arguments += hoisted_temporaries

        call_map = {}
        # Inline call remapping rewrites ``routine.body`` through expression substitution,
        # which can invalidate the CallStatement nodes used as keys in ``call_map`` (e.g. when
        # an inline call appears in a call statement's arguments). It is therefore deferred
        # until the call statements have been transformed.
        inline_call_remaps = []
        for call in FindNodes(CallStatement).visit(routine.body) + list(FindInlineCalls().visit(routine.body)):
            # Only process calls in this call tree
            if str(call.name) not in successor_map:
                continue

            if call.routine is BasicType.DEFERRED:
                warning(
                    '[Loki::HoistVariablesTransformation] '
                    f'call.routine is BasicType.DEFERRED for call to {call.name} in {routine.name}'
                )
                continue

            successor_item = successor_map[str(call.routine.name)]
            if self.as_kwarguments:
                # Pair each callee dummy (keyword) with the hoisted variable passed for it
                to_hoist = successor_item.trafo_data[self._key]["to_hoist"]
                _hoisted_variables = successor_item.trafo_data[self._key]["hoist_variables"]
                hoisted_variables = zip(to_hoist, _hoisted_variables)
            else:
                hoisted_variables = successor_item.trafo_data[self._key]["hoist_variables"]
            if role == "driver":
                call_map[call] = self.driver_call_argument_remapping(
                    routine=routine, call=call, variables=hoisted_variables
                )
            elif role == "kernel":
                if isinstance(call, CallStatement):
                    call_map[call] = self.kernel_call_argument_remapping(
                        routine=routine, call=call, variables=hoisted_variables
                    )
                else:
                    inline_call_remaps.append((call, hoisted_variables))

        # Add imports used to define hoisted
        missing_imports_map = defaultdict(OrderedSet)
        import_map = routine.import_map
        for module, var in item.trafo_data[self._key]["imported_sizes"]:
            if var.name not in import_map:
                missing_imports_map[module] |= {var}
        for module, var in item.trafo_data[self._key]["imported_kinds"]:
            if var.name not in import_map:
                missing_imports_map[module] |= {var}

        if missing_imports_map:
            # Prepending in reverse yields: header comment, imports, footer comment
            routine.spec.prepend(Comment(text=(
                '![Loki::HoistVariablesTransformation] ---------------------------------------'
            )))
            for module, variables in missing_imports_map.items():
                routine.spec.prepend(Import(module=module, symbols=variables))

            routine.spec.prepend(Comment(text=(
                '![Loki::HoistVariablesTransformation] '
                '-------- Added hoisted temporary size and kind imports -------------------------------'
            )))

        routine.body = Transformer(call_map).visit(routine.body)

        # Now that the call statements are in place, remap the inline calls
        for call, hoisted_variables in inline_call_remaps:
            self.kernel_inline_call_argument_remapping(
                routine=routine, call=call, variables=hoisted_variables
            )

    def driver_variable_declaration(self, routine, variables):
        """
        **Override**: Define the variable declaration (and possibly
        allocation, de-allocation, ...)  for each variable to be
        hoisted.

        Declares hoisted variables with a re-scope.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to add the variable declaration to.
        variables : tuple of :any:`Variable`
            The tuple of variables to be declared.
        """
        routine.variables += tuple(v.rescope(routine) for v in variables)

    @staticmethod
    def driver_variable_declaration_dim_remapping(routine, variables):
        """
        Take a list of variables and remap their dimensions for those being
        arrays to account for possibly uninitialized variables/dimensions.

        Every variable appearing in an array dimension that is the left-hand
        side of an assignment in ``routine.body`` is replaced by that
        assignment's right-hand side.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The relevant subroutine.
        variables : tuple of :any:`Variable`
            The tuple of variables for remapping.

        Returns
        -------
        list of :any:`Variable`
            The variables with remapped dimensions (non-arrays unchanged).
        """
        dim_vars = [
            dim_var
            for var in variables if isinstance(var, sym.Array)
            for dim_var in FindVariables().visit(var.dimensions)
        ]
        dim_map = {
            assignment.lhs: assignment.rhs
            for assignment in FindNodes(Assignment).visit(routine.body)
            if assignment.lhs in dim_vars
        }
        variables = [var.clone(dimensions=SubstituteExpressions(dim_map).visit(var.dimensions))
                if isinstance(var, sym.Array) else var for var in variables]
        return variables

    def driver_call_argument_remapping(self, routine, call, variables):
        """
        Callback method to re-map hoisted arguments for the driver-level routine.

        The callback will simply add all the hoisted variable arrays to the call
        without dimension range symbols.

        This callback is used to adjust the argument variable mapping, so that
        the call signature in the driver can be adjusted to the declaration
        scheme of subclassed variants of the basic hoisting transformation.
        Potentially, different variants of the hoist transformation can override
        the behaviour here to map to a different call invocation scheme.

        Whether the hoisted variables are passed as positional or keyword
        arguments is controlled by ``self.as_kwarguments``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to add the variable declaration to.
        call : :any:`CallStatement`
            Call object to which hoisted variables will be added.
        variables : tuple of :any:`Variable`
            The tuple of variables to be declared. If ``self.as_kwarguments``
            is set, an iterable of ``(dummy, actual)`` pairs instead.

        Returns
        -------
        :any:`CallStatement`
            A clone of ``call`` with the hoisted variables appended.
        """
        # pylint: disable=unused-argument
        if self.as_kwarguments:
            new_kwargs = tuple((a.name, v.clone(dimensions=None)) for (a, v) in variables)
            kwarguments = call.kwarguments if call.kwarguments is not None else ()
            return call.clone(kwarguments=kwarguments + new_kwargs)
        new_args = tuple(v.clone(dimensions=None) for v in variables)
        return call.clone(arguments=call.arguments + new_args)

    def kernel_call_argument_remapping(self, routine, call, variables):
        """
        Callback method to re-map hoisted arguments in kernel-to-kernel calls.

        The callback will simply add all the hoisted variable arrays to the call
        without dimension range symbols.
        This callback is used to adjust the argument variable mapping, so that
        the call signature can be adjusted to the declaration
        scheme of subclassed variants of the basic hoisting transformation.
        Potentially, different variants of the hoist transformation can override
        the behaviour here to map to a different call invocation scheme.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to add the variable declaration to.
        call : :any:`CallStatement`
            Call object to which hoisted variables will be added.
        variables : tuple of :any:`Variable`
            The tuple of variables to be declared. If ``self.as_kwarguments``
            is set, an iterable of ``(dummy, actual)`` pairs instead.

        Returns
        -------
        :any:`CallStatement`
            A clone of ``call`` with the hoisted variables appended.
        """
        # pylint: disable=unused-argument
        if self.as_kwarguments:
            new_kwargs = tuple((a.name, v.clone(dimensions=None)) for (a, v) in variables)
            kwarguments = call.kwarguments if call.kwarguments is not None else ()
            return call.clone(kwarguments=kwarguments + new_kwargs)
        new_args = tuple(v.clone(dimensions=None) for v in variables)
        return call.clone(arguments=call.arguments + new_args)

    def kernel_inline_call_argument_remapping(self, routine, call, variables):
        """
        Append hoisted temporaries to inline function call arguments.

        The call is substituted in place in ``routine.body``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to add the variable declaration to.
        call : :any:`InlineCall`
            Inline call expression to which hoisted variables will be added.
        variables : tuple of :any:`Variable`
            The tuple of variables to be declared. If ``self.as_kwarguments``
            is set, an iterable of ``(dummy, actual)`` pairs instead.
        """

        if self.as_kwarguments:
            # Copy so the keyword parameters of the original IR expression are not mutated
            kw_params = call.kw_parameters.copy()
            kw_params.update(dict((a.name, v.clone(dimensions=None)) for (a, v) in variables))
            _call_clone = call.clone(kw_parameters=kw_params)
            vmap = {call: _call_clone}
        else:
            new_args = tuple(v.clone(dimensions=None) for v in variables)
            vmap = {call: call.clone(parameters=call.parameters + new_args)}

        routine.body = SubstituteExpressions(vmap).visit(routine.body)

class HoistTemporaryArraysAnalysis(HoistVariablesAnalysis):
    """
    **Specialisation** for the *Analysis* part of the hoist variables
    functionality/transformation, to hoist only temporary arrays and
    if provided only temporary arrays with specific variables/variable
    names within the array dimensions.

    .. code-block:: python

        scheduler.process(transformation=HoistTemporaryArraysAnalysis(dim_vars=('a',)))
        scheduler.process(transformation=HoistVariablesTransformation())

    The reverse traversal required by the analysis is requested through the
    class attribute ``reverse_traversal``.

    Parameters
    ----------
    dim_vars: tuple of str, optional
        Variables to be within the dimensions of the arrays to be
        hoisted. If not provided, no checks will be done for the array
        dimensions.
    """

    # Apply in reverse order to recursively find all variables to be hoisted.
    reverse_traversal = True

    def __init__(self, dim_vars=None):
        self.dim_vars = dim_vars
        if self.dim_vars is not None:
            assert is_iterable(self.dim_vars)

    def find_variables(self, routine):
        """
        Selects temporary arrays to be hoisted.

        * if ``dim_vars`` is ``None`` (default) all temporary arrays will be hoisted
        * if ``dim_vars`` is defined, all arrays with the corresponding dimensions will be hoisted

        Arrays that are dummy arguments, have only constant dimensions, are
        the function result, or are ``pointer``/``allocatable`` are never selected.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine find the variables.
        """

        # Determine function result variable name
        result_name = routine.result_name if routine.is_function else ''

        # Look up the argument list once rather than once per candidate array
        arguments = routine.arguments

        variables = [var for var in routine.variables if isinstance(var, sym.Array)]
        return [var for var in variables
                if var not in arguments    # local variable
                and not all(is_dimension_constant(d) for d in var.shape)
                and var.name.lower() != result_name.lower()
                and not var.type.pointer and not var.type.allocatable
                and (self.dim_vars is None         # if dim_vars not empty check if at least one dim is within dim_vars
                     or any(dim_var in self.dim_vars for dim_var in FindVariables().visit(var.dimensions)))]


class HoistTemporaryArraysTransformationAllocatable(HoistVariablesTransformation):
    """
    **Specialisation** for the *Synthesis* part of the hoist variables
    functionality/transformation. Hoisted temporary arrays are declared
    ``allocatable`` in the driver, which also receives the matching
    *allocation* and *de-allocation* statements.
    """

    def driver_variable_declaration(self, routine, variables):
        """
        Declare each hoisted array as ``allocatable`` and add its
        *allocation* and *de-allocation*.

        Arrays are processed one at a time in the given order. Each one is
        declared with a deferred shape (``:`` in every dimension). Its
        ``ALLOCATE`` statement is prepended to the body and its
        ``DEALLOCATE`` statement is appended.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to add the variable declaration to.
        variables : tuple of :any:`Variable`
            The arrays to be declared, allocated and de-allocated.
        """
        for array in variables:
            rank = len(array.dimensions)
            deferred_shape = as_tuple([sym.RangeIndex((None, None))] * rank)
            allocatable_type = array.type.clone(allocatable=True)
            declared = array.clone(dimensions=deferred_shape, type=allocatable_type, scope=routine)
            routine.variables += as_tuple(declared)

            allocate_stmt = Allocation((array.clone(),))
            deallocate_stmt = Deallocation((array.clone(dimensions=None),))
            routine.body.prepend(allocate_stmt)
            routine.body.append(deallocate_stmt)
loki-ecmwf-0.3.6/loki/transformations/temporaries/raw_stack_allocator.py0000664000175000017500000010227015167130205027054 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re

from loki.analyse import dataflow_analysis_attached
from loki.backend.fgen import fgen
from loki.batch.item import ProcedureItem
from loki.batch.transformation import Transformation
from loki.expression.symbols import (
    Array, Scalar, Variable, Literal, Product, Sum, InlineCall,
    IntLiteral, RangeIndex, DeferredTypeSymbol
)
from loki.expression.symbolic import is_dimension_constant, simplify
from loki.expression.mappers import DetachScopesMapper
from loki.ir.expr_visitors import FindVariables, SubstituteExpressions
from loki.ir.nodes import Assignment, CallStatement, Pragma
from loki.ir.find import FindNodes
from loki.ir.transformer import Transformer
from loki.tools import as_tuple, OrderedSet
from loki.types import BasicType, SymbolAttributes


# Public API of this module
__all__ = ['TemporariesRawStackTransformation']

# Module-level integer literal expression ``1`` (presumably shared by the
# transformation code below for offsets/bounds — verify against usages)
one = IntLiteral(1)


class TemporariesRawStackTransformation(Transformation):
    """
    Transformation to inject stack arrays at the driver level. These, as well
    as corresponding sizes are passed on to the kernels. Any temporary arrays with
    the horizontal dimension as lead dimension are then allocated as offsets
    in the stack array.

    The transformation needs to be applied in reverse order, which will do the following for each **kernel**:

    * Add arguments to the kernel call signature to pass the stack arrays and their (free) size
    * Determine the combined size of all local arrays that are to be allocated on the stack,
      taking into account calls to nested kernels. This is reported in :any:`Item`'s ``trafo_data``.
    * Replace any access to temporary arrays with the corresponding offsets in the stack array
    * Mark the stack arrays as present on the device via a ``!$loki device-present`` pragma
    * Pass the stack arrays as arguments to any nested kernel calls

    In a **driver** routine, the transformation will:

    * Determine the required scratch space from ``trafo_data``
    * Declare the stack arrays with that size (per horizontal column and per block)
    * Wrap the body in a ``!$loki structured-data create(...)`` region for the stack arrays
    * Pass the stack arrays and sizes into the kernel calls

    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object to define the blocking dimension
    horizontal: :any:`Dimension`
        :any:`Dimension` object to define the horizontal dimension
    stack_name : str, optional
        Name of the scratch space variable that is allocated in the
        driver (default: ``'STACK'``)
    local_int_var_name_pattern : str, optional
        Python format string pattern for the name of the integer variable
        for each temporary (default: ``'JD_{name}'``)
    driver_horizontal : str, optional
        Override string if a separate variable name should be used for the horizontal
        when allocating the stack in the driver.
    key : str, optional
        Overwrite the key that is used to store analysis results in ``trafo_data``.
    """

    _key = 'TemporariesRawStackTransformation'

    # Traverse call tree in reverse when using Scheduler
    reverse_traversal = True

    # Name prefixes of the stack arrays per dtype, depending on the routine role
    type_name_dict = {
        BasicType.REAL: {'kernel': 'P', 'driver': 'Z'},
        BasicType.LOGICAL: {'kernel': 'LD', 'driver': 'LL'},
        BasicType.INTEGER: {'kernel': 'K', 'driver': 'I'}
    }

    def __init__(
            self, block_dim, horizontal, stack_name='STACK',
            local_int_var_name_pattern='JD_{name}',
            driver_horizontal=None, key=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.block_dim = block_dim
        self.horizontal = horizontal
        self.stack_name = stack_name
        self.local_int_var_name_pattern = local_int_var_name_pattern
        self.driver_horizontal = driver_horizontal
        if key:
            # Shadow the class-level default so that results go under the custom key
            self._key = key

    @property
    def int_type(self):
        """
        Type attributes (``INTEGER`` of kind ``JPIM``) used for all stack size,
        offset and block index variables created by this transformation.
        """
        return SymbolAttributes(
            dtype=BasicType.INTEGER, kind=DeferredTypeSymbol('JPIM')
        )

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the stack transformation to :data:`routine` according to its role.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to transform
        role : str
            ``'kernel'`` or ``'driver'`` (passed via ``kwargs``)
        item : :any:`Item`, optional
            Scheduler work item for :data:`routine`; its ``trafo_data`` receives
            the kind imports and (for kernels) the required stack sizes
        sub_sgraph : optional
            Scheduler sub-graph used to look up the successors of :data:`item`
        """

        role = kwargs['role']
        item = kwargs.get('item', None)

        if item:
            # Initialize set to store kind imports
            item.trafo_data[self._key] = {'kind_imports': {}}

        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()

        self.role = role

        if role == 'kernel':

            stack_dict = self.apply_raw_stack_allocator_to_temporaries(routine, item=item)
            if item:
                stack_dict = self._determine_stack_size(routine, successors, stack_dict, item=item)
                item.trafo_data[self._key]['stack_dict'] = stack_dict

            self.create_stacks_kernel(routine, stack_dict, successors)

        if role == 'driver':

            stack_dict = self._determine_stack_size(routine, successors, item=item)

            self.create_stacks_driver(routine, stack_dict, successors)


    def _get_stack_int_name(self, prefix, dtype, kind, suffix):
        """
        Construct the name string for stack used and size integers.
        Replace double underscore with single if kind is None
        """
        return (prefix + '_' + self.type_name_dict[dtype][self.role] + '_' +
                self._get_kind_name(kind) + '_' + suffix).replace('__', '_')


    def insert_stack_in_calls(self, routine, stack_arg_dict, successors):
        """
        Insert stack arguments into calls to successor routines.

        The stack arguments are placed in front of the callee's first optional
        argument (if any), otherwise they are appended.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine in which to transform call statements
        stack_arg_dict : dict
            dict that maps dtype and kind to the sets of stack size variables
            and their corresponding stack array variables
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        """
        successor_map = {
            successor.local_name: successor
            for successor in successors if isinstance(successor, ProcedureItem)
        }
        call_map = {}

        #Loop over calls and check if they call a successor routine and if the
        #transformation data is available
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.name in successor_map and self._key in successor_map[call.name].trafo_data:
                successor_stack_dict = successor_map[call.name].trafo_data[self._key]['stack_dict']

                call_stack_args = []

                #Loop over dtypes and kinds in successor arguments stacks
                #and construct list of stack arguments
                for dtype in successor_stack_dict:
                    for kind in successor_stack_dict[dtype]:
                        call_stack_args += list(stack_arg_dict[dtype][kind])

                #Get position of optional arguments so we can place the stacks in front
                arg_pos = [i for i, arg in enumerate(call.routine.arguments) if arg.type.optional]

                arguments = call.arguments
                if arg_pos:
                    #Stack arguments have already been added to the routine call signature
                    #so we have to subtract the number of stack arguments from the optional position
                    arg_pos = min(arg_pos) - len(call_stack_args)
                    arguments = arguments[:arg_pos] + as_tuple(call_stack_args) + arguments[arg_pos:]
                else:
                    arguments += as_tuple(call_stack_args)

                call_map[call] = call.clone(arguments=arguments)

        if call_map:
            routine.body = Transformer(call_map).visit(routine.body)


    def create_stacks_driver(self, routine, stack_dict, successors):
        """
        Create stack variables in the driver routine,
        add pragma directives to create the stacks on the device,
        and add the stack_variables to kernel call arguments.

        For each dtype/kind, a ``J_..._STACK_SIZE`` integer is declared and assigned
        the required size, and a stack array of shape (horizontal, size, blocks) is
        declared. Kernels receive the slice ``(:, :, <block index>)`` of each stack.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The driver subroutine to get the stack_variables
        stack_dict : dict
            dict that maps dtype and kind to an expression for the required stack size
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        """

        #Block variables
        kgpblock = Scalar(name=self.block_dim.size, scope=routine, type=self.int_type)
        jgpblock = Scalar(name=self.block_dim.index, scope=routine, type=self.int_type)

        #Full dimensions for arguments
        fulldim = (RangeIndex((None,None)), RangeIndex((None,None)))

        stack_vars = []
        stack_arg_dict = {}
        assignments = []
        pragma_string = ''
        for dtype in stack_dict:
            for kind in stack_dict[dtype]:

                #Start integer names in the driver with 'J'
                stack_size_name = self._get_stack_int_name('J', dtype, kind, 'STACK_SIZE')
                stack_size_var = Scalar(name=stack_size_name, scope=routine, type=self.int_type)

                #Create the stack variable and its type with the correct shape
                stack_var = self._get_stack_var(routine, dtype, kind)
                horizontal_size = self._get_horizontal_variable(routine)
                if self.driver_horizontal:
                    # If override is specified, use a separate horizontal in the driver
                    horizontal_size = Variable(
                        name=self.driver_horizontal, scope=routine, type=self.int_type
                    )

                stack_type = stack_var.type.clone(
                    shape=(horizontal_size, stack_dict[dtype][kind], kgpblock)
                )
                stack_var = stack_var.clone(type=stack_type)

                #Add the variables to the stack_arg_dict with dimensions (:,:,j_block)
                stack_arg_dict.setdefault(dtype, {})[kind] = (
                    stack_size_var, stack_var.clone(dimensions=fulldim+(jgpblock,))
                )
                stack_var = stack_var.clone(dimensions=stack_type.shape)

                #Create stack_vars pair and assignment of the size variable
                stack_vars += [stack_size_var, stack_var]
                assignments += [Assignment(lhs=stack_size_var, rhs=stack_dict[dtype][kind])]
                pragma_string += f'{stack_var.name}, '

        #Add to routine
        routine.variables = routine.variables + as_tuple(stack_vars)
        routine.body.prepend(assignments)

        if pragma_string:
            pragma_string = pragma_string[:-2].lower()

            pragma_data_start = Pragma(keyword='loki', content=f'structured-data create({pragma_string})')
            pragma_data_end = Pragma(keyword='loki', content='end structured-data')

            routine.body.prepend(pragma_data_start)
            routine.body.append(pragma_data_end)

        #Insert variables in successor calls
        self.insert_stack_in_calls(routine, stack_arg_dict, successors)


    def create_stacks_kernel(self, routine, stack_dict, successors):
        """
        Create stack variables in kernel routine,
        add pragma directives to create the stacks on the device,
        and add the stack_variables to kernel call arguments.

        For each dtype/kind, a ``K_..._STACK_SIZE`` (``intent(in)``) argument and a
        stack array argument of shape (horizontal, size) are added to the signature.
        Nested kernels receive the part of the stack not used locally. If this kernel
        has no local temporaries of a given dtype/kind (so no ``J_..._STACK_USED``
        counter was declared for it), the whole stack is forwarded.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The kernel subroutine to get the stack_variables
        stack_dict : dict
            dict that maps dtype and kind to an expression for the required stack size
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        """

        stack_vars = []
        stack_arg_dict = {}
        pragma_string = ''
        # Snapshot of the declared variables, used to check which stack_used counters exist
        declared_variables = routine.variable_map
        for dtype in stack_dict:
            for kind in stack_dict[dtype]:

                #Start arguments integer names in kernels with 'K'
                stack_size_name = self._get_stack_int_name('K', dtype, kind, 'STACK_SIZE')
                stack_size_var = Scalar(name=stack_size_name, scope=routine, type=self.int_type.clone(intent='IN'))

                #Create the stack variable and its type with the correct shape
                stack_var = self._get_stack_var(routine, dtype, kind)
                stack_type = stack_var.type.clone(shape=(self._get_horizontal_variable(routine), stack_size_var))
                stack_var = stack_var.clone(type=stack_type)

                #Local variables start with 'J'
                stack_used_name = self._get_stack_int_name('J', dtype, kind, 'STACK_USED')
                if stack_used_name in declared_variables:
                    #Pass on the stack variable from stack_used + 1 to stack_size
                    #Pass stack_size - stack_used to stack size in called kernel
                    stack_used_var = Scalar(name=stack_used_name, scope=routine, type=self.int_type)
                    free_start = Sum((stack_used_var, IntLiteral(1)))
                    free_size = Sum((stack_size_var, Product((-1, stack_used_var))))
                else:
                    #No local temporaries of this dtype/kind, so no stack_used counter has been
                    #declared: forward the entire stack and its full size to the called kernels
                    free_start = IntLiteral(1)
                    free_size = stack_size_var

                arg_dims = (self._get_horizontal_range(routine), RangeIndex((free_start, stack_size_var)))
                stack_arg_dict.setdefault(dtype, {})[kind] = (free_size, stack_var.clone(dimensions=arg_dims))

                #Create stack_vars pair
                stack_vars += [stack_size_var, stack_var.clone(dimensions=stack_type.shape)]
                pragma_string += f'{stack_var.name}, '

        if pragma_string:
            pragma_string = pragma_string[:-2].lower()

            present_pragma = None
            acc_pragmas = [p for p in FindNodes(Pragma).visit(routine.body) if p.keyword.lower() == 'loki'] # acc
            for pragma in acc_pragmas:
                if pragma.content.lower().startswith('device-present'):
                    present_pragma = pragma
                    break
            if present_pragma:
                #Reuse the existing pragma: remove it from its current position and
                #re-insert it at the top of the body with the stacks added to vars(...)
                routine.body = Transformer({present_pragma: None}).visit(routine.body)
                content = re.sub(r'\bvars\(', f'vars({pragma_string}, ', present_pragma.content.lower())
                present_pragma = present_pragma.clone(content = content)
                pragma_data_end = None
            else:
                present_pragma = Pragma(keyword='loki', content=f'device-present vars({pragma_string})')
                pragma_data_end = Pragma(keyword='loki', content='end device-present')

            routine.body.prepend(present_pragma)
            if pragma_data_end is not None:
                routine.body.append(pragma_data_end)


        # Keep optional arguments last; a workaround for the fact that keyword arguments are not supported
        # in device code
        first_optional = next(
            (i for i, arg in enumerate(routine.arguments) if arg.type.optional), None
        )
        if first_optional is not None:
            routine.arguments = (
                routine.arguments[:first_optional] + as_tuple(stack_vars) + routine.arguments[first_optional:]
            )
        else:
            routine.arguments += as_tuple(stack_vars)

        self.insert_stack_in_calls(routine, stack_arg_dict, successors)


    def apply_raw_stack_allocator_to_temporaries(self, routine, item=None):
        """
        Map local temporary arrays onto offsets in per-dtype/kind stack arrays.

        The temporaries selected by :meth:`_filter_temporary_arrays` are grouped by
        dtype and kind. For each array, an integer offset variable (named via
        ``local_int_var_name_pattern``) is declared and assigned the cumulative size of
        the preceding arrays of the same dtype/kind. Every access of the array in the
        body is replaced by the corresponding section of the stack array, and the array
        declaration is removed. For each dtype/kind, a ``J_..._STACK_USED`` integer is
        declared and assigned the number of stack elements used locally.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine object to apply transformation to
        item : :any:`Item`, optional
            Scheduler work item; if given, the modules providing imported kind
            parameters are recorded in its ``trafo_data``

        Returns
        -------
        stack_dict : :any:`dict`
            dict mapping dtype and kind to the number of elements required in the
            non-horizontal stack dimension by the local temporaries

        Raises
        ------
        RuntimeError
            If a temporary array is accessed in a way that cannot be mapped onto a
            contiguous section of the stack
        """

        #Get all temporary dicts and sort them according to dtype and kind
        temporary_arrays = self._filter_temporary_arrays(routine)
        temporary_array_dict = self._sort_arrays_by_type(temporary_arrays)

        integers = []
        allocations = []
        var_map = {}

        stack_dict = {}
        stack_set = OrderedSet()

        for (dtype, kind_dict) in temporary_array_dict.items():

            if dtype not in stack_dict:
                stack_dict[dtype] = {}

            for (kind, arrays) in kind_dict.items():

                if kind not in stack_dict[dtype]:
                    stack_dict[dtype][kind] = Literal(0)

                # Store type information of temporary allocation
                if item:
                    if kind in routine.imported_symbols:
                        item.trafo_data[self._key]['kind_imports'][kind] = routine.import_map[kind.name].module.lower()

                #Get the stack variable
                stack_var = self._get_stack_var(routine, dtype, kind)
                old_int_var = IntLiteral(0)
                old_array_size = ()

                #Loop over arrays
                for array in arrays:

                    int_var = Scalar(name=self.local_int_var_name_pattern.format(name=array.name),
                                     scope=routine, type=self.int_type)
                    integers += [int_var]

                    #Compute array size (product of all non-horizontal extents)
                    array_size = one
                    for d in array.shape[1:]:
                        if isinstance(d, RangeIndex):
                            d_extent = Sum((d.upper, Product((-1,d.lower)), one))
                        else:
                            d_extent = d
                        array_size = simplify(Product((array_size, d_extent)))

                    #Add to stack dict and list of allocations
                    stack_dict[dtype][kind] = simplify(Sum((stack_dict[dtype][kind], array_size)))
                    allocations += [Assignment(lhs=int_var, rhs=Sum((old_int_var,) + old_array_size))]

                    #Store the old int variable to calculate offset for next array
                    old_int_var = int_var
                    if isinstance(array_size, Sum):
                        old_array_size = array_size.children
                    else:
                        old_array_size = (array_size,)

                    #Map array instances to stack offsets
                    temp_map = self._map_temporary_array(array, int_var, routine, stack_var)
                    var_map = {**var_map, **temp_map}
                    stack_set.add(stack_var)

                #Compute stack used: offset of the last array plus its size
                stack_used = simplify(Sum((int_var, array_size)))
                stack_used_name = self._get_stack_int_name('J', dtype, kind, 'STACK_USED')
                stack_used_var = Scalar(name=stack_used_name, scope=routine, type=self.int_type)

                #List up integers and allocations generated
                integers += [stack_used_var]
                allocations += [Assignment(lhs=stack_used_var, rhs=stack_used)]

        #Substitute temporary arrays if any map
        if var_map:
            routine.body = SubstituteExpressions(var_map).visit(routine.body)

        #Add  variables to routines and allocations to body
        routine.variables = as_tuple(v for v in routine.variables if v not in temporary_arrays) + as_tuple(integers)
        routine.body.prepend(allocations)

        return stack_dict


    def _filter_temporary_arrays(self, routine):
        """
        Find all array variables in routine
        and filter out arguments, unused variables, fixed size arrays,
        and arrays whose lead dimension is not horizontal.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine object to get arrays from

        Returns
        -------
        list of :any:`Array`
            The local arrays that are to be placed on the stack
        """

        # Find all temporary arrays
        arguments = routine.arguments
        temporary_arrays = [
            var for var in routine.variables
            if isinstance(var, Array) and var not in arguments
        ]

        # Filter out unused vars
        with dataflow_analysis_attached(routine):
            temporary_arrays = [
                var for var in temporary_arrays
                if var.name.lower() in routine.body.defines_symbols
            ]

        # Filter out variables whose size is known at compile-time
        temporary_arrays = [
            var for var in temporary_arrays
            if not all(is_dimension_constant(d) for d in var.shape)
        ]

        # Filter out variables whose first dimension is not horizontal
        temporary_arrays = [
            var for var in temporary_arrays if (
            isinstance(var.shape[0], Scalar) and
            var.shape[0].name.lower() == self.horizontal.size.lower())
        ]

        return temporary_arrays


    def _get_kind_name(self, kind):
        """
        Return a string for :data:`kind` usable in variable names; for an
        inline-call kind, the call name and its parameters are joined by ``'_'``.
        """

        if isinstance(kind, InlineCall):
            kind_name = kind.name
            for p in kind.parameters:
                kind_name += '_' + fgen(p)
            return kind_name

        return fgen(kind)


    def _sort_arrays_by_type(self, arrays):
        """
        Go through list of arrays and map each array
        to its type and kind in the dict type_dict

        Parameters
        ----------
        arrays : List of array objects

        Returns
        -------
        dict
            Nested dict ``{dtype: {kind: [arrays]}}`` preserving the input order
        """

        type_dict = {}

        for a in arrays:
            type_dict.setdefault(a.type.dtype, {}).setdefault(a.type.kind, []).append(a)

        return type_dict


    def _map_temporary_array(self, temp_array, int_var, routine, stack_var):
        """
        Find all instances of temporary array, temp_array, in routine and
        map them to the corresponding position in stack stack_var.
        Position in stack is stored in int_var.
        Returns a dict mapping all instances of temp_array to corresponding stack position.

        Instances are matched by name case-insensitively (Fortran names are
        case-insensitive). The horizontal index is kept as the first stack dimension,
        and all other dimensions are flattened (column-major) into the second stack
        dimension. This only works for accesses that cover a contiguous span of the
        flattened dimensions.

        Parameters
        ----------
        temp_array : :any:`Variable`
            Array to be mapped into stack array
        int_var : :any:`Variable`
            Integer variable corresponding to the position in of the array in the stack
        routine : :any:`Subroutine`
            The subroutine object to transform
        stack_var : :any:`Variable`
            The stack array variable

        Returns
        -------
        temp_map : :any:`dict`
            dict mapping variable instances to positions in the stack array

        Raises
        ------
        RuntimeError
            If an instance accesses a non-contiguous part of the flattened dimensions
        """

        #List instances of temp_array (case-insensitive, as in Fortran)
        temp_name = temp_array.name.lower()
        temp_arrays = [v for v in FindVariables().visit(routine.body) if v.name.lower() == temp_name]

        temp_map = {}
        stack_dimensions = [None, None]

        #Loop over instances of temp_array
        for t in temp_arrays:

            offset = one
            stack_size = one

            if t.dimensions:
                #If t has dimensions, we must compute the offsets in the stack
                #taking each dimension into account

                #First dimension is just horizontal
                stack_dimensions[0] = t.dimensions[0]

                #Check if lead dimension is contiguous
                contiguous = (isinstance(t.dimensions[0], RangeIndex) and
                             (t.dimensions[0] == self._get_horizontal_range(routine) or
                             (t.dimensions[0].lower is None and t.dimensions[0].upper is None)))

                s_offset = one
                for d, s in zip(t.dimensions[1:], t.shape[1:]):

                    #Check if there are range indices in shape to account for
                    if isinstance(s, RangeIndex):
                        s_lower = s.lower
                        s_upper = s.upper
                        s_extent = Sum((s_upper, Product((-1, s_lower)), one))
                    else:
                        s_lower = one
                        s_upper = s
                        s_extent = s

                    if isinstance(d, RangeIndex):

                        #If dimension is a rangeindex, compute the indices
                        #Stop if there is any non contiguous access to the array
                        if not contiguous:
                            raise RuntimeError(f'Discontiguous access of array {t}')

                        if d.lower is None:
                            d_lower = s_lower
                        else:
                            d_lower = d.lower

                        if d.upper is None:
                            d_upper = s_upper
                        else:
                            d_upper = d.upper

                        #Store if this dimension was contiguous
                        contiguous = (d_upper == s_upper) and (d_lower == s_lower)

                        #Multiply stack_size by current dimension
                        stack_size = Product((stack_size, Sum((d_upper, Product((-1, d_lower)), one))))

                    else:

                        #Only need a single index to compute offset
                        d_lower = d

                        #A single index into a dimension with more than one element makes any
                        #range in a later dimension strided, i.e. not contiguous in the stack
                        contiguous = contiguous and simplify(s_extent) == one


                    #Compute dimension and shape offsets
                    d_offset =  Sum((d_lower, Product((-1, s_lower))))

                    offset = Sum((offset, Product((d_offset, s_offset))))

                    s_offset = Product((s_offset, s_extent))


            else:
                #If t does not have dimensions,
                #we can just access (1:horizontal.size, 1:stack_size)

                stack_dimensions[0] = self._get_horizontal_range(routine)

                for s in t.shape[1:]:
                    if isinstance(s, RangeIndex):
                        s_extent = Sum((s.upper, Product((-1, s.lower)), one))
                    else:
                        s_extent = s

                    stack_size = Product((stack_size, s_extent))

            offset = simplify(offset)
            stack_size = simplify(stack_size)

            #Add offset to int_var
            if isinstance(offset, Sum):
                lower = Sum((int_var,) + offset.children)
            else:
                lower = Sum((int_var, offset))

            if stack_size == one:
                #If a single element is accessed, we only need a number
                stack_dimensions[1] = lower

            else:
                #Else we'll  have to construct a range index
                offset = simplify(Sum((offset, stack_size, Product((-1,one)))))
                if isinstance(offset, Sum):
                    upper = Sum((int_var,) + offset.children)
                else:
                    upper = Sum((int_var, offset))
                stack_dimensions[1] = RangeIndex((lower, upper))

            #Finally add to the mapping
            temp_map[t] = stack_var.clone(dimensions=as_tuple(stack_dimensions))

        return temp_map


    def _determine_stack_size(self, routine, successors, local_stack_dict=None, item=None):
        """
        Utility routine to determine the stack size required for the given :data:`routine`,
        including calls to subroutines

        For each successor call, the successor's stack size expressions are translated
        to the caller's names via the call's argument mapping. Any ``MAX(...)`` expressions
        are unwound into their parameters, and the local stack size is added to each
        candidate. Several candidates are combined with ``MAX``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine object for which to determine the stack size
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        local_stack_dict : :any:`dict`, optional
            dict mapping type and kind to the corresponding number of elements used
        item : :any:`Item`
            Scheduler work item corresponding to routine.

        Returns
        -------
        stack_dict : :any:`dict`
            dict with required stack size mapped to type and kind
        """

        # Collect variable kind imports from successors that have been processed
        if item:
            item.trafo_data[self._key]['kind_imports'].update(
                {k: v
                 for s in successors
                 if isinstance(s, ProcedureItem) and self._key in s.trafo_data
                 for k, v in s.trafo_data[self._key]['kind_imports'].items()
                }
            )

        # Note: we are not using a CaseInsensitiveDict here to be able to search directly with
        # Variable instances in the dict. The StrCompareMixin takes care of case-insensitive
        # comparisons in that case
        successor_map = {
            successor.ir.name.lower(): successor
            for successor in successors if isinstance(successor, ProcedureItem)
        }

        # Collect stack sizes for successors
        # Note that we need to translate the names of variables used in the expressions to the
        # local names according to the call signature
        stack_dict = {}
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.name in successor_map and self._key in successor_map[call.name].trafo_data:
                successor_stack_dict = successor_map[call.name].trafo_data[self._key]['stack_dict']

                # Replace any occurence of routine arguments in the stack size expression
                arg_map = dict(call.arg_iter())
                for dtype in successor_stack_dict:
                    for kind in successor_stack_dict[dtype]:
                        successor_stack_size = SubstituteExpressions(arg_map).visit(successor_stack_dict[dtype][kind])

                        kind_sizes = stack_dict.setdefault(dtype, {}).setdefault(kind, [])
                        if successor_stack_size not in kind_sizes:
                            kind_sizes.append(successor_stack_size)


        if not stack_dict:
            # Return only the local stack size if there are no callees
            return local_stack_dict or {}

        # Unwind "max" expressions from successors so that the local stack size can be
        # injected into each candidate individually (duplicates are dropped)
        for kind_dict in stack_dict.values():
            for kind, stack_sizes in kind_dict.items():
                unwound = []
                for stack_size in stack_sizes:
                    if isinstance(stack_size, InlineCall) and stack_size.function == 'MAX':
                        candidates = stack_size.parameters
                    else:
                        candidates = (stack_size,)
                    for candidate in candidates:
                        if candidate not in unwound:
                            unwound.append(candidate)
                kind_dict[kind] = unwound

        #Simplify the local stack sizes and add them to the stack_dict
        if local_stack_dict:
            for dtype in local_stack_dict:
                for kind in local_stack_dict[dtype]:
                    local_stack_dict[dtype][kind] = DetachScopesMapper()(simplify(local_stack_dict[dtype][kind]))

                    if dtype in stack_dict:
                        if kind in stack_dict[dtype]:
                            stack_dict[dtype][kind] = [simplify(Sum((local_stack_dict[dtype][kind], s)))
                                                       for s in stack_dict[dtype][kind]]
                        else:
                            stack_dict[dtype][kind] = [local_stack_dict[dtype][kind]]
                    else:
                        stack_dict[dtype] = {kind: [local_stack_dict[dtype][kind]]}

        #If several expressions, return MAX, else just add the expression
        for kind_dict in stack_dict.values():
            for (kind, stacks) in kind_dict.items():
                if len(stacks) == 1:
                    kind_dict[kind] = stacks[0]
                else:
                    kind_dict[kind] = InlineCall(function = Variable(name = 'MAX'), parameters = as_tuple(stacks))

        return stack_dict


    def _get_stack_var(self, routine, dtype, kind):
        """
        Get a stack variable with a name determined by
        the type_name_dict and _get_kind_name().
        intent is determined by whether the routine is a kernel or driver.
        The returned shape is a deferred-shape placeholder; callers clone the
        type with the actual shape.
        """

        stack_name = self.type_name_dict[dtype][self.role] + '_' + self._get_kind_name(kind) + '_' + self.stack_name
        stack_name = stack_name.replace('__', '_')

        stack_intent = 'INOUT' if self.role == 'kernel' else None

        stack_type = SymbolAttributes(dtype = dtype,
                                      kind = kind,
                                      intent = stack_intent,
                                      shape = (RangeIndex((None, None)),))

        return Array(name=stack_name, type=stack_type, scope=routine)


    def _get_horizontal_variable(self, routine):
        """
        Get a scalar int variable corresponding to horizontal dimension with routine as scope
        """
        return Variable(name=self.horizontal.size, scope=routine, type=self.int_type)

    def _get_horizontal_range(self, routine):
        """
        Get a RangeIndex from one to horizontal dimension
        """
        return RangeIndex((one, self._get_horizontal_variable(routine)))
loki-ecmwf-0.3.6/loki/transformations/temporaries/stack_allocator.py0000664000175000017500000015525715167130205026220 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import  defaultdict

from loki.analyse import dataflow_analysis_attached
from loki.batch.item import ProcedureItem
from loki.batch.transformation import Transformation
from loki.expression.symbols import (
    Array, Scalar, Variable, Literal, Product, Sum, InlineCall,
    IntLiteral, RangeIndex, DeferredTypeSymbol
)
from loki.expression.symbolic import is_dimension_constant, simplify
from loki.expression.mappers import DetachScopesMapper
from loki.ir.expr_visitors import FindVariables, SubstituteExpressions
from loki.ir.nodes import (
        Assignment, CallStatement, Pragma, Allocation, Deallocation, VariableDeclaration, Import
)
from loki.ir.find import FindNodes
from loki.ir.transformer import Transformer
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import BasicType, SymbolAttributes
from loki.transformations.utilities import recursive_expression_map_update, single_variable_declaration


__all__ = ['FtrPtrStackTransformation', 'DirectIdxStackTransformation']

class BaseStackTransformation(Transformation):
    """
    Base Transformation to inject a stack that allocates large scratch spaces per block
    and per datatype on the driver and maps temporary arrays in kernels to this scratch space.

    Subclasses implement :meth:`apply_pool_allocator_to_temporaries` to map the
    temporaries of a kernel onto the stack. This base class determines the required
    stack sizes, declares and allocates the stacks in drivers, and passes them down
    the call tree as additional kernel arguments.

    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object to define the blocking dimension.
    horizontal : :any:`Dimension`
        :any:`Dimension` object to define the horizontal dimension.
    stack_name : str, optional
        Name of the stack (default: 'STACK')
    local_int_var_name_pattern : str, optional
        Local integer variable names pattern
        (default: 'JD_{name}')
    int_kind : str, optional
        Integer kind (default: 'JWIM')
    driver_horizontal : optional
        Stored as ``driver_horizontal``. It is not used by this base class itself;
        presumably it is consumed by subclasses.
    """

    # Key under which per-item data is stored in ``item.trafo_data``
    _key = 'PoolAllocatorBaseTransformation'

    # Successors must be processed first, because their ``stack_dict`` is read
    # from ``trafo_data`` when the caller is transformed
    reverse_traversal = True

    # Name prefix of the stack variables per data type and routine role
    type_name_dict = {
        BasicType.REAL: {'kernel': 'P', 'driver': 'Z'},
        BasicType.LOGICAL: {'kernel': 'LD', 'driver': 'LL'},
        BasicType.INTEGER: {'kernel': 'K', 'driver': 'I'}
    }

    def __init__(self, block_dim, horizontal,
                 stack_name='STACK', local_int_var_name_pattern='JD_{name}',
                 int_kind='JWIM', driver_horizontal=None, **kwargs):

        super().__init__(**kwargs)
        self.block_dim = block_dim
        self.horizontal = horizontal
        self.stack_name = stack_name
        self.local_int_var_name_pattern = local_int_var_name_pattern
        self.int_kind = int_kind
        self.driver_horizontal = driver_horizontal


    def _get_int_type(self, intent=None):
        """
        Return the :any:`SymbolAttributes` of an integer of kind :attr:`int_kind`
        with the given :data:`intent`.
        """
        return SymbolAttributes(
            dtype=BasicType.INTEGER, kind=DeferredTypeSymbol(self.int_kind),
            intent=intent
        )
    int_type = property(_get_int_type)

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the stack transformation to :data:`routine`.

        For kernels, temporaries are mapped to the stack and the stacks are added
        as dummy arguments. For drivers, the stacks are declared, allocated and
        passed to the called successors.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to transform
        role : str
            ``'kernel'`` or ``'driver'`` (passed via ``kwargs``, required)
        item : :any:`Item`, optional
            Scheduler work item corresponding to :data:`routine`
        sub_sgraph : optional
            Scheduler sub-graph used to look up the successors of :data:`item`
        """

        role = kwargs['role']
        self.role = role
        item = kwargs.get('item', None)

        if item:
            item.trafo_data[self._key] = {'kind_imports': {}}

        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()

        # TODO: probably shouldn't happen here ...
        # Stack arguments are inserted positionally in insert_stack_in_calls,
        # so keyword arguments of resolved calls are converted first
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.routine is not BasicType.DEFERRED:
                call.convert_kwargs_to_args()

        if role == 'kernel':
            stack_dict = self.apply_pool_allocator_to_temporaries(routine, item=item)
            if item:
                stack_dict = self._determine_stack_size(routine, successors, stack_dict, item=item)
                item.trafo_data[self._key]['stack_dict'] = stack_dict

            self.create_stacks_kernel(routine, stack_dict, successors)

        if role == 'driver':
            stack_dict = self._determine_stack_size(routine, successors, item=item)
            if item:
                # import variable type specifiers used in stack allocations
                self.import_allocation_types(routine, item)
            self.create_stacks_driver(routine, stack_dict, successors)

    @classmethod
    def import_allocation_types(cls, routine, item):
        """
        Import the kind symbols recorded in ``item.trafo_data[cls._key]['kind_imports']``
        into :data:`routine`.

        ``kind_imports`` maps each symbol to the module it is imported from. Symbols
        are appended to an existing import of that module if :data:`routine` already
        has one. Otherwise a new import is prepended to the spec, skipping symbols
        already declared or imported in :data:`routine`.
        """
        new_imports = defaultdict(tuple)
        for s, m in item.trafo_data[cls._key]['kind_imports'].items():
            new_imports[m] += as_tuple(s)
        import_map = {i.module.lower(): i for i in routine.imports}
        for mod, symbs in new_imports.items():
            # de-duplicate while preserving order
            symbs = tuple(dict.fromkeys(symbs))
            if mod in import_map:
                import_map[mod]._update(symbols=as_tuple(dict.fromkeys(import_map[mod].symbols + symbs)))
            else:
                _symbs = [s for s in symbs if not (s.name.lower() in routine.variable_map or
                                                   s.name.lower() in routine.imported_symbol_map)]
                if _symbs:
                    imp = Import(module=mod, symbols=as_tuple(_symbs))
                    routine.spec.prepend(imp)

    @staticmethod
    def _insert_stack_at_loki_pragma(routine, insert):
        """
        Replace the first ``!$loki stack-insert`` pragma in the body of :data:`routine`
        with the nodes in :data:`insert`.

        Returns
        -------
        bool
            ``True`` if such a pragma was found and replaced, ``False`` otherwise
        """
        for pragma in FindNodes(Pragma).visit(routine.body):
            if pragma.keyword == 'loki' and 'stack-insert' in pragma.content:
                routine.body = Transformer({pragma: insert}).visit(routine.body)
                return True
        return False


    def _get_stack_int_name(self, prefix, dtype, kind, suffix):
        """
        Construct the name string for stack used and size integers.
        Replace double underscore with single if kind is None
        """
        return (f'{prefix}_{self.type_name_dict[dtype][self.role]}_'
                f'{self._get_kind_name(kind)}_{suffix}'.replace('__', '_'))


    def _get_stack_var(self, routine, dtype, kind):
        """
        Get a stack variable with a name determined by
        the type_name_dict and _get_kind_name().

        The intent is determined by whether the routine is a kernel (``INOUT``)
        or a driver (no intent). The variable is created with a one-dimensional
        deferred shape ``(:)``; both driver and kernel code re-shape it afterwards.
        """

        stack_name = self.type_name_dict[dtype][self.role] + '_' + self._get_kind_name(kind) + '_' + self.stack_name
        stack_name = stack_name.replace('__', '_')

        stack_intent = 'INOUT' if self.role == 'kernel' else None

        # Note the trailing comma: the shape must be a tuple of dimensions
        stack_type = SymbolAttributes(dtype=dtype,
                                      kind=kind,
                                      intent=stack_intent,
                                      shape=(RangeIndex((None, None)),))

        return Array(name=stack_name, type=stack_type, scope=routine)


    def _get_horizontal_variable(self, routine):
        """
        Get a scalar int variable corresponding to horizontal dimension with routine as scope.

        If a resolved call in :data:`routine` passes an actual argument to a dummy
        argument of that name, the passed expression is returned instead.
        """
        arg_map = {}
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.routine is not BasicType.DEFERRED:
                arg_map.update(dict(call.arg_iter()))

        var = Variable(name=self.horizontal.size, scope=routine, type=self.int_type)
        if var in arg_map:
            return arg_map[var]
        return var

    def _get_horizontal_range(self, routine):
        """
        Get a RangeIndex from one to horizontal dimension
        """
        return RangeIndex((IntLiteral(1), self._get_horizontal_variable(routine)))

    def _get_int_var(self, name, scope, type=None): # pylint: disable=redefined-builtin
        """
        Create an integer :any:`Scalar` named :data:`name` in :data:`scope`.
        Uses :attr:`int_type` if no :data:`type` is given.
        """
        if type is None:
            type = self.int_type
        return Scalar(name=name, scope=scope, type=type)

    def _get_kind_name(self, kind):
        """
        Return a name-safe string for :data:`kind`.

        An inline call such as ``SELECTED_REAL_KIND(13, 300)`` becomes
        ``SELECTED_REAL_KIND_13_300``. ``None`` becomes an empty string.
        """
        if isinstance(kind, InlineCall):
            kind_name = kind.name
            for p in kind.parameters:
                kind_name += '_' + str(p)
            return kind_name

        return str(kind) if kind is not None else ''

    def _sort_arrays_by_type(self, arrays):
        """
        Go through list of arrays and map each array
        to its type and kind in the dict type_dict

        Parameters
        ----------
        arrays : List of array objects

        Returns
        -------
        dict
            Nested dict ``{dtype: {kind: [arrays]}}`` preserving the input order
        """

        type_dict = {}
        for a in arrays:
            type_dict.setdefault(a.type.dtype, {})
            type_dict[a.type.dtype].setdefault(a.type.kind, []).append(a)

        return type_dict

    def insert_stack_in_calls(self, routine, stack_arg_dict, successors):
        """
        Insert stack arguments into calls to successor routines.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine in which to transform call statements
        stack_arg_dict : dict
            dict that maps dtype and kind to the sets of stack size variables
            and their corresponding stack array variables
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        """
        successor_map = {
            successor.local_name: successor
            for successor in successors if isinstance(successor, ProcedureItem)
        }
        call_map = {}

        # loop over calls and check if they call a successor routine and if the
        # transformation data is available
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.name in successor_map and self._key in successor_map[call.name].trafo_data:
                successor_stack_dict = successor_map[call.name].trafo_data[self._key]['stack_dict']

                call_stack_args = []

                # loop over dtypes and kinds in successor arguments stacks
                # and construct list of stack arguments
                for dtype in successor_stack_dict:
                    for kind in successor_stack_dict[dtype]:
                        call_stack_args += list(stack_arg_dict[dtype][kind])

                # get position of optional arguments so we can place the stacks in front
                arg_pos = [
                    pos for pos, arg in enumerate(call.routine.arguments) if arg.type.optional
                ]

                arguments = call.arguments
                if arg_pos:
                    # stack arguments have already been added to the routine call signature
                    # so we have to subtract the number of stack arguments from the optional position
                    arg_pos = min(arg_pos) - len(call_stack_args)
                    arguments = arguments[:arg_pos] + as_tuple(call_stack_args) + arguments[arg_pos:]
                else:
                    arguments += as_tuple(call_stack_args)

                call_map[call] = call.clone(arguments=arguments)

        if call_map:
            routine.body = Transformer(call_map).visit(routine.body)


    def create_stacks_driver(self, routine, stack_dict, successors):
        """
        Create stack variables in the driver routine,
        add pragma directives to create the stacks on the device,
        and add the stack_variables to kernel call arguments.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The driver subroutine to get the stack_variables
        stack_dict : dict
            dict that maps dtype and kind to an expression for the required stack size
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        """

        # block variables
        kgpblock = self._get_int_var(name=self.block_dim.size, scope=routine)
        jgpblock = self._get_int_var(name=self.block_dim.index, scope=routine)

        stack_vars = []
        stack_arg_dict = {}
        assignments = []
        deallocs = []
        pragma_vars = []
        for dtype in stack_dict:
            for kind in stack_dict[dtype]:
                # start integer names in the driver with 'J'
                stack_size_name = self._get_stack_int_name('J', dtype, kind, 'STACK_SIZE')
                stack_size_var = self._get_int_var(name=stack_size_name, scope=routine)

                stack_used_name = self._get_stack_int_name('J', dtype, kind, 'STACK_USED')
                stack_used_var = self._get_int_var(name=stack_used_name, scope=routine)

                # create the stack variable and its type with the correct shape:
                # one column of the required size per block
                stack_var = self._get_stack_var(routine, dtype, kind)

                stack_type = stack_var.type.clone(shape=(RangeIndex((None,None)), RangeIndex((None,None))),
                                                  allocatable=True)
                stack_var = stack_var.clone(type=stack_type)

                stack_alloc = Allocation(variables=(stack_var.clone(dimensions=(stack_dict[dtype][kind], kgpblock)),))
                stack_dealloc = Deallocation(variables=(stack_var.clone(dimensions=None),))

                # add the variables to the stack_arg_dict with dimensions (:,j_block)
                stack_arg_dict.setdefault(dtype, {})
                stack_arg_dict[dtype][kind] = (stack_size_var,
                                               stack_var.clone(dimensions=(RangeIndex((None,None)), jgpblock,)),
                                               stack_used_var)
                stack_var = stack_var.clone(dimensions=stack_type.shape)

                stack_used_var_init = Assignment(lhs=stack_used_var, rhs=IntLiteral(1))
                # create stack_vars pair and assignment of the size variable
                stack_vars += [stack_size_var, stack_var, stack_used_var]
                assignments += [Assignment(lhs=stack_size_var,
                                           rhs=stack_dict[dtype][kind]), stack_alloc, stack_used_var_init]
                deallocs += [stack_dealloc]
                pragma_vars.append(stack_var.name)

        # add to routine
        routine.variables = routine.variables + as_tuple(stack_vars)
        nodes_to_add = assignments

        if pragma_vars:
            pragma_string = ', '.join(pragma_vars)

            pragma_data_start = Pragma(keyword='loki', content=f'unstructured-data create({pragma_string})')
            pragma_data_end = Pragma(keyword='loki', content=f'exit unstructured-data delete({pragma_string})')
            nodes_to_add += [pragma_data_start]

            routine.body.append(pragma_data_end)

        # prefer an explicit '!$loki stack-insert' location, else the start of the body
        if not self._insert_stack_at_loki_pragma(routine, nodes_to_add):
            routine.body.prepend(nodes_to_add)

        if deallocs:
            routine.body.append(deallocs)

        # insert variables in successor calls
        self.insert_stack_in_calls(routine, stack_arg_dict, successors)


    def create_stacks_kernel(self, routine, stack_dict, successors):
        """
        Create stack variables in kernel routine,
        add pragma directives to create the stacks on the device,
        and add the stack_variables to kernel call arguments.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The kernel subroutine to get the stack_variables
        stack_dict : dict
            dict that maps dtype and kind to an expression for the required stack size
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        """

        stack_vars = []
        stack_args = []
        stack_arg_dict = {}
        pragma_vars = []
        assignments = []
        for dtype in stack_dict:
            for kind in stack_dict[dtype]:

                # start arguments integer names in kernels with 'K'
                stack_size_name = self._get_stack_int_name('K', dtype, kind, 'STACK_SIZE')
                stack_size_var = self._get_int_var(name=stack_size_name, scope=routine,
                                                   type=self._get_int_type(intent='IN'))

                # local variables start with 'J'; the incoming counter is copied
                # into a local one at the start of the kernel
                stack_used_arg_name = self._get_stack_int_name('JD', dtype, kind, 'STACK_USED')
                stack_used_arg = self._get_int_var(name=stack_used_arg_name, scope=routine,
                                                   type=self._get_int_type(intent='INOUT'))
                stack_used_name = self._get_stack_int_name('J', dtype, kind, 'STACK_USED')
                stack_used_var = self._get_int_var(name=stack_used_name, scope=routine)
                assignments += [Assignment(lhs=stack_used_var, rhs=stack_used_arg)]

                # create the stack variable and its type with the correct shape
                shape = (stack_size_var,)
                stack_var = self._get_stack_var(routine, dtype, kind)
                stack_type = stack_var.type.clone(shape=as_tuple(shape), target=True)
                stack_var = stack_var.clone(type=stack_type)

                # pass on the stack variable from stack_used + 1 to stack_size
                # pass stack_size - stack_used to stack size in called kernel
                arg_dims = (RangeIndex((None, None)),)
                stack_arg_dict.setdefault(dtype, {})
                stack_arg_dict[dtype][kind] = (stack_size_var, stack_var.clone(dimensions=arg_dims), stack_used_var)

                # create stack_vars pair
                stack_args += [stack_size_var,
                               stack_var.clone(dimensions=stack_type.shape, type=stack_var.type.clone(contiguous=True)),
                               stack_used_arg]
                stack_vars += [stack_used_var]
                pragma_vars.append(stack_var.name)

        pragma_string = ', '.join(pragma_vars)
        if pragma_vars:
            present_pragma = Pragma(keyword='loki', content=f'device-present vars({pragma_string})')
            pragma_data_end = Pragma(keyword='loki', content='end device-present')
            routine.body.prepend(present_pragma)
            routine.body.append(pragma_data_end)
        routine.body.prepend(as_tuple(assignments))
        routine.variables += as_tuple(stack_vars)

        # keep optional arguments last; a workaround for the fact that keyword arguments are not supported
        # in device code
        arg_pos = [pos for pos, arg in enumerate(routine.arguments) if arg.type.optional]
        if arg_pos:
            routine.arguments = routine.arguments[:arg_pos[0]] + as_tuple(stack_args) + routine.arguments[arg_pos[0]:]
        else:
            routine.arguments += as_tuple(stack_args)

        self.insert_stack_in_calls(routine, stack_arg_dict, successors)


    def apply_pool_allocator_to_temporaries(self, routine, item=None): # pylint: disable=unused-argument
        """
        Base method to apply raw stack allocator to local temporary arrays

        This method when implemented should append
        the relevant arguments to the routine's dummy argument list and
        should create the assignment for the local copy of the stack type.
        Further, the cumulative size of all temporary arrays
        should be determined and returned.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine object to apply transformation to

        Returns
        -------
        stack_dict : :any:`dict`
            dict with required stack size mapped to type and kind
        """
        return {}

    def _filter_temporary_arrays(self, routine):
        """
        Find all local array variables in routine and filter out arguments,
        arrays that are never defined (written to) in the body,
        and arrays whose shape is fully known at compile time.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine object to get arrays from

        Returns
        -------
        list
            The remaining temporary arrays, in declaration order
        """

        # Find all temporary arrays
        arguments = routine.arguments
        temporary_arrays = [
            var for var in routine.variables
            if isinstance(var, Array) and var not in arguments
        ]

        # Filter out unused vars
        with dataflow_analysis_attached(routine):
            temporary_arrays = [
                var for var in temporary_arrays
                if var.name.lower() in routine.body.defines_symbols
            ]

        # Filter out variables whose size is known at compile-time
        temporary_arrays = [
            var for var in temporary_arrays
            if not all(is_dimension_constant(d) for d in var.shape)
        ]

        return temporary_arrays


    def _determine_stack_size(self, routine, successors, local_stack_dict=None, item=None):
        """
        Utility routine to determine the stack size required for the given :data:`routine`,
        including calls to subroutines

        The size per dtype and kind is the local size plus the maximum over all
        successor sizes, expressed as ``MAX(local + s1, local + s2, ...)`` when
        there is more than one candidate.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine object for which to determine the stack size
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        local_stack_dict : :any:`dict`, optional
            dict mapping type and kind to the corresponding number of elements used
        item : :any:`Item`
            Scheduler work item corresponding to routine.

        Returns
        -------
        stack_dict : :any:`dict`
            dict with required stack size mapped to type and kind
        """

        # Collect variable kind imports from successors that have been
        # processed by this transformation
        if item:
            item.trafo_data[self._key]['kind_imports'].update(
                {k: v
                 for s in successors
                 if isinstance(s, ProcedureItem) and self._key in s.trafo_data
                 for k, v in s.trafo_data[self._key]['kind_imports'].items()
                }
            )

        # Note: we are not using a CaseInsensitiveDict here to be able to search directly with
        # Variable instances in the dict. The StrCompareMixin takes care of case-insensitive
        # comparisons in that case
        successor_map = {
            successor.ir.name.lower(): successor
            for successor in successors if isinstance(successor, ProcedureItem)
        }

        # Collect stack sizes for successors
        # Note that we need to translate the names of variables used in the expressions to the
        # local names according to the call signature
        stack_dict = {}
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.name in successor_map and self._key in successor_map[call.name].trafo_data:
                successor_stack_dict = successor_map[call.name].trafo_data[self._key]['stack_dict']

                # Replace any occurence of routine arguments in the stack size expression
                arg_map = dict(call.arg_iter())
                for dtype in successor_stack_dict:
                    for kind in successor_stack_dict[dtype]:
                        successor_stack_size = SubstituteExpressions(arg_map).visit(successor_stack_dict[dtype][kind])
                        expr_map = {
                            expr: DetachScopesMapper()(arg_map[expr])
                            for expr in FindVariables().visit(successor_stack_size)
                            if expr in arg_map
                        }
                        if expr_map:
                            expr_map = recursive_expression_map_update(expr_map)
                            successor_stack_size = SubstituteExpressions(expr_map).visit(successor_stack_size)
                        stack_dict.setdefault(dtype, {})
                        stack_dict[dtype].setdefault(kind, []).append(successor_stack_size)


        if not stack_dict:
            # Return only the local stack size if there are no callees
            return local_stack_dict or {}

        # Unwind "max" expressions from successors and inject the local stack size into the expressions
        for kind_dict in stack_dict.values():
            for kind, stack_sizes in kind_dict.items():
                new_list = []
                for stack_size in stack_sizes:
                    if (isinstance(stack_size, InlineCall) and stack_size.function == 'MAX'):
                        new_list += list(stack_size.parameters)
                    else:
                        new_list += [stack_size]
                # Write the flattened list back (only values change, no keys are added)
                kind_dict[kind] = new_list

        # simplify the local stack sizes and add them to the stack_dict
        if local_stack_dict:
            for dtype in local_stack_dict:
                for kind in local_stack_dict[dtype]:
                    local_stack_dict[dtype][kind] = DetachScopesMapper()(simplify(local_stack_dict[dtype][kind]))

                    if dtype in stack_dict:
                        if kind in stack_dict[dtype]:
                            stack_dict[dtype][kind] = [simplify(Sum((local_stack_dict[dtype][kind], s)))
                                                       for s in stack_dict[dtype][kind]]
                        else:
                            stack_dict[dtype][kind] = [local_stack_dict[dtype][kind]]
                    else:
                        stack_dict[dtype] = {kind: [local_stack_dict[dtype][kind]]}

        # if several expressions, return MAX, else just add the expression
        for (dtype, kind_dict) in stack_dict.items():
            for (kind, stacks) in kind_dict.items():
                if len(stacks) == 1:
                    kind_dict[kind] = stacks[0]
                else:
                    kind_dict[kind] = InlineCall(function=Variable(name='MAX'), parameters=as_tuple(stacks))

        return stack_dict

class FtrPtrStackTransformation(BaseStackTransformation):
    """
    Transformation to inject a stack that allocates large scratch spaces per block
    and per datatype on the driver and maps temporary arrays in kernels to this scratch space.

    Starting from:

    .. code-block:: fortran

        SUBROUTINE driver (nlon, klev, nb, ydml_phy_mf)

          USE kernel_mod, ONLY: kernel

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev
          INTEGER, INTENT(IN) :: nb

          INTEGER :: jstart
          INTEGER :: jend

          INTEGER :: b

          REAL(KIND=jprb), DIMENSION(nlon, klev) :: zzz

          jstart = 1
          jend = nlon

          DO b=1,nb
            CALL kernel(nlon, klev, jstart, jend, zzz)
          END DO

        END SUBROUTINE driver

        SUBROUTINE kernel (nlon, klev, jstart, jend, pzz)

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev

          INTEGER, INTENT(IN) :: jstart
          INTEGER, INTENT(IN) :: jend

          REAL, INTENT(IN), DIMENSION(nlon, klev) :: pzz

          REAL, DIMENSION(nlon, klev) :: zzx
          REAL(KIND=SELECTED_REAL_KIND(13, 300)), DIMENSION(nlon, klev) :: zzy
          LOGICAL, DIMENSION(nlon, klev) :: zzl

          INTEGER :: testint
          INTEGER :: jl, jlev

          zzl = .false.
          DO jl=1,nlon
            DO jlev=1,klev
              zzx(jl, jlev) = pzz(jl, jlev)
              zzy(jl, jlev) = pzz(jl, jlev)
            END DO
          END DO

        END SUBROUTINE kernel

    This transformation generates:

    .. code-block:: fortran

        SUBROUTINE driver (nlon, klev, nb)

          USE kernel_mod, ONLY: kernel

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev
          INTEGER(KIND=JWIM) :: nb

          INTEGER :: jstart
          INTEGER :: jend

          INTEGER(KIND=JWIM) :: b

          REAL(KIND=jprb), DIMENSION(nlon, klev) :: zzz
          INTEGER(KIND=JWIM) :: J_Z_STACK_SIZE
          REAL, ALLOCATABLE :: Z_STACK(:, :)
          INTEGER(KIND=JWIM) :: J_Z_STACK_USED
          INTEGER(KIND=JWIM) :: J_Z_SELECTED_REAL_KIND_13_300_STACK_SIZE
          REAL(KIND=SELECTED_REAL_KIND(13, 300)), ALLOCATABLE :: Z_SELECTED_REAL_KIND_13_300_STACK(:, :)
          INTEGER(KIND=JWIM) :: J_Z_SELECTED_REAL_KIND_13_300_STACK_USED
          INTEGER(KIND=JWIM) :: J_LL_STACK_SIZE
          LOGICAL, ALLOCATABLE :: LL_STACK(:, :)
          INTEGER(KIND=JWIM) :: J_LL_STACK_USED
          J_Z_STACK_SIZE = klev*nlon
          ALLOCATE (Z_STACK(klev*nlon, nb))
          J_Z_STACK_USED = 1
          J_Z_SELECTED_REAL_KIND_13_300_STACK_SIZE = klev*nlon
          ALLOCATE (Z_SELECTED_REAL_KIND_13_300_STACK(klev*nlon, nb))
          J_Z_SELECTED_REAL_KIND_13_300_STACK_USED = 1
          J_LL_STACK_SIZE = klev*nlon
          ALLOCATE (LL_STACK(klev*nlon, nb))
          J_LL_STACK_USED = 1
        !$loki unstructured-data create( z_stack, z_selected_real_kind_13_300_stack, ll_stack )

          jstart = 1
          jend = nlon

          DO b=1,nb
            CALL kernel(nlon, klev, jstart, jend, zzz, J_Z_STACK_SIZE, Z_STACK(:, b), J_Z_STACK_USED,  &
            & J_Z_SELECTED_REAL_KIND_13_300_STACK_SIZE, Z_SELECTED_REAL_KIND_13_300_STACK(:, b),  &
            & J_Z_SELECTED_REAL_KIND_13_300_STACK_USED, J_LL_STACK_SIZE, LL_STACK(:, b), J_LL_STACK_USED)
          END DO

        !$loki exit unstructured-data delete( z_stack, z_selected_real_kind_13_300_stack, ll_stack )
          DEALLOCATE (Z_STACK)
          DEALLOCATE (Z_SELECTED_REAL_KIND_13_300_STACK)
          DEALLOCATE (LL_STACK)
        END SUBROUTINE driver

        SUBROUTINE kernel (nlon, klev, jstart, jend, pzz, K_P_STACK_SIZE, P_STACK, JD_P_STACK_USED,  &
        & K_P_SELECTED_REAL_KIND_13_300_STACK_SIZE, P_SELECTED_REAL_KIND_13_300_STACK, & 
        & JD_P_SELECTED_REAL_KIND_13_300_STACK_USED,  &
        & K_LD_STACK_SIZE, LD_STACK, JD_LD_STACK_USED)

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev

          INTEGER, INTENT(IN) :: jstart
          INTEGER, INTENT(IN) :: jend

          REAL, INTENT(IN), DIMENSION(nlon, klev) :: pzz

          REAL, POINTER, CONTIGUOUS, DIMENSION(:, :) :: zzx
          REAL(KIND=SELECTED_REAL_KIND(13, 300)), POINTER, CONTIGUOUS, DIMENSION(:, :) :: zzy
          LOGICAL, POINTER, CONTIGUOUS, DIMENSION(:, :) :: zzl

          INTEGER :: testint
          INTEGER :: jl, jlev
          INTEGER(KIND=JWIM) :: JD_incr
          INTEGER(KIND=JWIM) :: JD_incr_SELECTED_REAL_KIND_13_300
          INTEGER(KIND=JWIM) :: JD_incr
          INTEGER(KIND=JWIM) :: J_P_STACK_USED
          INTEGER(KIND=JWIM) :: J_P_SELECTED_REAL_KIND_13_300_STACK_USED
          INTEGER(KIND=JWIM) :: J_LD_STACK_USED
          INTEGER(KIND=JWIM), INTENT(IN) :: K_P_STACK_SIZE
          REAL, TARGET, CONTIGUOUS, INTENT(INOUT) :: P_STACK(K_P_STACK_SIZE)
          INTEGER(KIND=JWIM), INTENT(INOUT) :: JD_P_STACK_USED
          INTEGER(KIND=JWIM), INTENT(IN) :: K_P_SELECTED_REAL_KIND_13_300_STACK_SIZE
          REAL(KIND=SELECTED_REAL_KIND(13, 300)), TARGET, CONTIGUOUS, INTENT(INOUT) ::  &
          & P_SELECTED_REAL_KIND_13_300_STACK(K_P_SELECTED_REAL_KIND_13_300_STACK_SIZE)
          INTEGER(KIND=JWIM), INTENT(INOUT) :: JD_P_SELECTED_REAL_KIND_13_300_STACK_USED
          INTEGER(KIND=JWIM), INTENT(IN) :: K_LD_STACK_SIZE
          LOGICAL, TARGET, CONTIGUOUS, INTENT(INOUT) :: LD_STACK(K_LD_STACK_SIZE)
          INTEGER(KIND=JWIM), INTENT(INOUT) :: JD_LD_STACK_USED
          J_P_STACK_USED = JD_P_STACK_USED
          J_P_SELECTED_REAL_KIND_13_300_STACK_USED = JD_P_SELECTED_REAL_KIND_13_300_STACK_USED
          J_LD_STACK_USED = JD_LD_STACK_USED
        !$loki device-present vars( p_stack, p_selected_real_kind_13_300_stack, ld_stack )
          JD_incr = J_P_STACK_USED
          zzx(1:nlon, 1:klev) => P_STACK(JD_incr:JD_incr + nlon*klev)
          J_P_STACK_USED = JD_incr + klev*nlon
          JD_incr_SELECTED_REAL_KIND_13_300 = J_P_SELECTED_REAL_KIND_13_300_STACK_USED
          zzy(1:nlon, 1:klev) =>  &
          & P_SELECTED_REAL_KIND_13_300_STACK(JD_incr_SELECTED_REAL_KIND_13_300: &
              & JD_incr_SELECTED_REAL_KIND_13_300 + nlon*klev)
          J_P_SELECTED_REAL_KIND_13_300_STACK_USED = JD_incr_SELECTED_REAL_KIND_13_300 + klev*nlon
          JD_incr = J_LD_STACK_USED
          zzl(1:nlon, 1:klev) => LD_STACK(JD_incr:JD_incr + nlon*klev)
          J_LD_STACK_USED = JD_incr + klev*nlon

          zzl = .false.
          DO jl=1,nlon
            DO jlev=1,klev
              zzx(jl, jlev) = pzz(jl, jlev)
              zzy(jl, jlev) = pzz(jl, jlev)
            END DO
          END DO

        !$loki end device-present
        END SUBROUTINE kernel

    Parameters  
    ----------  
    block_dim : :any:`Dimension`
        :any:`Dimension` object to define the blocking dimension.
    horizontal : :any:`Dimension`
        :any:`Dimension` object to define the horizontal dimension.
    stack_name : str, optional
        Name of the stack (default: 'STACK')
    local_int_var_name_pattern : str, optional    
        Local integer variable names pattern
        (default: 'JD_{name}')
    int_kind : str, optional
        Integer kind (default: 'JWIM')
    """

    def adapt_temp_declarations(self, routine, temporary_arrays):
        """
        Turn the declarations of the given temporary arrays into deferred-shape
        ``POINTER, CONTIGUOUS`` declarations.

        Each temporary is first moved into a declaration statement of its own, its
        symbol type gains the ``pointer`` and ``contiguous`` attributes, and every
        explicit dimension (on the declared symbol and, if present, in the
        declaration's ``DIMENSION`` attribute) is replaced by ``:``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine whose specification part is modified in place
        temporary_arrays : list of :any:`Array`
            Temporary arrays whose declarations are adapted
        """
        # one declaration statement per temporary, so each can be rewritten on its own
        single_variable_declaration(routine, variables=[arr.name for arr in temporary_arrays])

        for arr in temporary_arrays:
            routine.symbol_attrs[arr.name] = arr.type.clone(pointer=True, contiguous=True)

        deferred = RangeIndex((None, None))
        decl_map = {}
        for decl in FindNodes(VariableDeclaration).visit(routine.spec):
            symbol = decl.symbols[0]
            if symbol not in temporary_arrays:
                continue

            new_symbol = symbol.clone(dimensions=as_tuple((deferred,) * len(symbol.dimensions)))
            if decl.dimensions is None:
                decl_map[decl] = decl.clone(symbols=(new_symbol,))
            else:
                decl_map[decl] = decl.clone(dimensions=(deferred,) * len(decl.dimensions),
                                            symbols=(new_symbol,))

        routine.spec = Transformer(decl_map).visit(routine.spec)

    def apply_pool_allocator_to_temporaries(self, routine, item=None):
        """
        Apply raw stack allocator to local temporary arrays

        The declarations of all local temporary arrays are turned into deferred-shape
        ``POINTER, CONTIGUOUS`` declarations. For every temporary, a Fortran pointer
        assignment onto a slice of the stack of the matching type and kind is
        prepended to the routine body. After the arrays of each type and kind, the
        corresponding stack-used counter is advanced past them. The integer offset
        variables used for this (named via :attr:`local_int_var_name_pattern` from
        ``incr`` and the kind name) are added to the routine's variables.

        The cumulative size of all temporary arrays is determined and returned.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine object to apply transformation to
        item : :any:`Item`, optional
            If given, the modules from which kind parameters and array dimensions
            are imported are recorded in its ``trafo_data``

        Returns
        -------
        stack_dict : :any:`dict`
            dict with required stack size mapped to type and kind
        """

        # get all temporary arrays and sort them according to dtype and kind
        temporary_arrays = self._filter_temporary_arrays(routine)

        self.adapt_temp_declarations(routine, temporary_arrays)
        temporary_array_dict = self._sort_arrays_by_type(temporary_arrays)

        integers = []
        allocations = []

        stack_dict = {}

        for (dtype, kind_dict) in temporary_array_dict.items():

            if dtype not in stack_dict:
                stack_dict[dtype] = {}

            for (kind, arrays) in kind_dict.items():

                stack_used_name = self._get_stack_int_name('J', dtype, kind, 'STACK_USED')
                stack_used_var = self._get_int_var(name=stack_used_name, scope=routine)

                if kind not in stack_dict[dtype]:
                    stack_dict[dtype][kind] = Literal(0)

                # store type information of temporary allocation
                if item:
                    if kind in routine.imported_symbols:
                        item.trafo_data[self._key]['kind_imports'][kind] = routine.import_map[kind.name].module.lower()
                    for array in arrays:
                        dims = [d for d in array.shape if d in routine.imported_symbols]
                        for d in dims:
                            item.trafo_data[self._key]['kind_imports'][d] = routine.import_map[d.name].module.lower()

                # get the stack variable
                stack_var = self._get_stack_var(routine, dtype, kind)
                # the first array starts where the stack is currently used up to
                old_int_var = stack_used_var
                old_array_size = ()

                int_var_kind_name = self._get_kind_name(kind)
                int_var_name = 'incr'
                if int_var_kind_name:
                    int_var_name += f'_{int_var_kind_name}'
                int_var = self._get_int_var(name=self.local_int_var_name_pattern.format(name=int_var_name),
                                            scope=routine)
                # The offset variable name depends only on the kind name, so e.g. default-kind
                # REAL and LOGICAL stacks share it; make sure it is listed only once
                if int_var not in integers:
                    integers.append(int_var)

                # loop over arrays
                for array in arrays:

                    # compute array size
                    array_size = IntLiteral(1)
                    for d in array.shape:
                        if isinstance(d, RangeIndex):
                            d_extent = Sum((d.upper, Product((-1,d.lower)), IntLiteral(1)))
                        else:
                            d_extent = d
                        array_size = simplify(Product((array_size, d_extent)))

                    # add to stack dict and list of allocations
                    stack_dict[dtype][kind] = simplify(Sum((stack_dict[dtype][kind], array_size)))
                    # this array starts right after the previous one
                    allocations += [Assignment(lhs=int_var, rhs=Sum((old_int_var,) + old_array_size))]

                    # store the old int variable to calculate offset for next array
                    old_int_var = int_var
                    if isinstance(array_size, Sum):
                        old_array_size = array_size.children
                    else:
                        old_array_size = (array_size,)

                    ptr_assignment = self._get_ptr_assignment(array, int_var, stack_var)
                    allocations += [ptr_assignment]

                # advance the stack-used counter past the last array of this type and kind
                stack_used = simplify(Sum((int_var, array_size)))
                allocations += [Assignment(lhs=stack_used_var, rhs=stack_used)]

        # add variables to routines and allocations to body
        routine.variables = as_tuple(v for v in routine.variables if v not in temporary_arrays) + as_tuple(integers)
        routine.body.prepend(allocations)

        return stack_dict

    def _get_ptr_assignment(self, array, int_var, stack_var):
        """
        Build the pointer assignment that maps a temporary array onto its stack slice.

        The result has the form
        ``array(l1:u1, ..., lk:uk) => stack_var(int_var:int_var + n1*...*nk - 1)``,
        where ``ni`` is the extent of the i-th declared dimension. Non-range
        dimensions ``n`` become ``1:n``. The stack slice therefore holds exactly as
        many elements as the array, starting at the position held by ``int_var``.

        Parameters
        ----------
        array : :any:`Array`
            Temporary array to be mapped onto the stack
        int_var : :any:`Scalar`
            Integer variable holding the stack position of the array's first element
        stack_var : :any:`Array`
            Stack array of the matching type and kind

        Returns
        -------
        :any:`Assignment`
            The pointer assignment (``ptr=True``)
        """
        arr_dim = ()
        extents = ()
        for dim in array.shape:
            if isinstance(dim, RangeIndex):
                arr_dim += (dim,)
                extents += (Sum((dim.upper, IntLiteral(1), Product((-1, dim.lower)))),)
            else:
                arr_dim += (RangeIndex((IntLiteral(1), dim)),)
                extents += (dim,)

        if extents:
            # Last element of the slice: int_var + size - 1. Using int_var + size would
            # make the slice one element too long and, for the array at the end of the
            # stack, reference an element past the end of the stack.
            stack_upper = Sum((int_var, Product(extents), Product((-1, IntLiteral(1)))))
        else:
            stack_upper = int_var

        ptr_assignment = Assignment(lhs=array.clone(dimensions=arr_dim),
                                    rhs=stack_var.clone(dimensions=(RangeIndex((int_var, stack_upper)),)),
                                    ptr=True)
        return ptr_assignment

class DirectIdxStackTransformation(BaseStackTransformation):
    """
    Transformation to inject a stack that allocates large scratch spaces per block
    and per datatype on the driver and maps temporary arrays in kernels to this scratch space.

    Every use of a temporary array in a kernel is rewritten as a direct index
    (or index range) into the stack of the matching type and kind.

    Starting from:

    .. code-block:: fortran

        SUBROUTINE driver (nlon, klev, nb)

          USE kernel_mod, ONLY: kernel

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev
          INTEGER, INTENT(IN) :: nb

          INTEGER :: jstart
          INTEGER :: jend

          INTEGER :: b

          REAL(KIND=jprb), DIMENSION(nlon, klev) :: zzz

          jstart = 1
          jend = nlon

          DO b=1,nb
            CALL kernel(nlon, klev, jstart, jend, zzz)
          END DO

        END SUBROUTINE driver

        SUBROUTINE kernel (nlon, klev, jstart, jend, pzz)

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev

          INTEGER, INTENT(IN) :: jstart
          INTEGER, INTENT(IN) :: jend

          REAL, INTENT(IN), DIMENSION(nlon, klev) :: pzz

          REAL, DIMENSION(nlon, klev) :: zzx
          REAL(KIND=SELECTED_REAL_KIND(13, 300)), DIMENSION(nlon, klev) :: zzy
          LOGICAL, DIMENSION(nlon, klev) :: zzl

          INTEGER :: testint
          INTEGER :: jl, jlev

          zzl = .false.
          DO jl=1,nlon
            DO jlev=1,klev
              zzx(jl, jlev) = pzz(jl, jlev)
              zzy(jl, jlev) = pzz(jl, jlev)
            END DO
          END DO

        END SUBROUTINE kernel

    This transformation generates:

    .. code-block:: fortran

        SUBROUTINE driver (nlon, klev, nb)

          USE kernel_mod, ONLY: kernel

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev
          INTEGER(KIND=JWIM) :: nb

          INTEGER :: jstart
          INTEGER :: jend

          INTEGER(KIND=JWIM) :: b

          REAL(KIND=jprb), DIMENSION(nlon, klev) :: zzz
          INTEGER(KIND=JWIM) :: J_Z_STACK_SIZE
          REAL, ALLOCATABLE :: Z_STACK(:, :)
          INTEGER(KIND=JWIM) :: J_Z_STACK_USED
          INTEGER(KIND=JWIM) :: J_Z_SELECTED_REAL_KIND_13_300_STACK_SIZE
          REAL(KIND=SELECTED_REAL_KIND(13, 300)), ALLOCATABLE :: Z_SELECTED_REAL_KIND_13_300_STACK(:, :)
          INTEGER(KIND=JWIM) :: J_Z_SELECTED_REAL_KIND_13_300_STACK_USED
          INTEGER(KIND=JWIM) :: J_LL_STACK_SIZE
          LOGICAL, ALLOCATABLE :: LL_STACK(:, :)
          INTEGER(KIND=JWIM) :: J_LL_STACK_USED
          J_Z_STACK_SIZE = klev*nlon
          ALLOCATE (Z_STACK(klev*nlon, nb))
          J_Z_STACK_USED = 1
          J_Z_SELECTED_REAL_KIND_13_300_STACK_SIZE = klev*nlon
          ALLOCATE (Z_SELECTED_REAL_KIND_13_300_STACK(klev*nlon, nb))
          J_Z_SELECTED_REAL_KIND_13_300_STACK_USED = 1
          J_LL_STACK_SIZE = klev*nlon
          ALLOCATE (LL_STACK(klev*nlon, nb))
          J_LL_STACK_USED = 1
        !$loki unstructured-data create( z_stack, z_selected_real_kind_13_300_stack, ll_stack )

          jstart = 1
          jend = nlon

          DO b=1,nb
            CALL kernel(nlon, klev, jstart, jend, zzz, J_Z_STACK_SIZE, Z_STACK(:, b), J_Z_STACK_USED,  &
            & J_Z_SELECTED_REAL_KIND_13_300_STACK_SIZE, Z_SELECTED_REAL_KIND_13_300_STACK(:, b),  &
            & J_Z_SELECTED_REAL_KIND_13_300_STACK_USED, J_LL_STACK_SIZE, LL_STACK(:, b), J_LL_STACK_USED)
          END DO

        !$loki exit unstructured-data delete( z_stack, z_selected_real_kind_13_300_stack, ll_stack )
          DEALLOCATE (Z_STACK)
          DEALLOCATE (Z_SELECTED_REAL_KIND_13_300_STACK)
          DEALLOCATE (LL_STACK)
        END SUBROUTINE driver

        SUBROUTINE kernel (nlon, klev, jstart, jend, pzz, K_P_STACK_SIZE, P_STACK, JD_P_STACK_USED,  &
        & K_P_SELECTED_REAL_KIND_13_300_STACK_SIZE, P_SELECTED_REAL_KIND_13_300_STACK,  &
        & JD_P_SELECTED_REAL_KIND_13_300_STACK_USED,  &
        & K_LD_STACK_SIZE, LD_STACK, JD_LD_STACK_USED)

          IMPLICIT NONE

          INTEGER, INTENT(IN) :: nlon
          INTEGER, INTENT(IN) :: klev

          INTEGER, INTENT(IN) :: jstart
          INTEGER, INTENT(IN) :: jend

          REAL, INTENT(IN), DIMENSION(nlon, klev) :: pzz


          INTEGER :: testint
          INTEGER :: jl, jlev
          INTEGER(KIND=JWIM) :: JD_zzx
          INTEGER(KIND=JWIM) :: JD_zzy
          INTEGER(KIND=JWIM) :: JD_zzl
          INTEGER(KIND=JWIM) :: J_P_STACK_USED
          INTEGER(KIND=JWIM) :: J_P_SELECTED_REAL_KIND_13_300_STACK_USED
          INTEGER(KIND=JWIM) :: J_LD_STACK_USED
          INTEGER(KIND=JWIM), INTENT(IN) :: K_P_STACK_SIZE
          REAL, TARGET, CONTIGUOUS, INTENT(INOUT) :: P_STACK(K_P_STACK_SIZE)
          INTEGER(KIND=JWIM), INTENT(INOUT) :: JD_P_STACK_USED
          INTEGER(KIND=JWIM), INTENT(IN) :: K_P_SELECTED_REAL_KIND_13_300_STACK_SIZE
          REAL(KIND=SELECTED_REAL_KIND(13, 300)), TARGET, CONTIGUOUS, INTENT(INOUT) ::  &
          & P_SELECTED_REAL_KIND_13_300_STACK(K_P_SELECTED_REAL_KIND_13_300_STACK_SIZE)
          INTEGER(KIND=JWIM), INTENT(INOUT) :: JD_P_SELECTED_REAL_KIND_13_300_STACK_USED
          INTEGER(KIND=JWIM), INTENT(IN) :: K_LD_STACK_SIZE
          LOGICAL, TARGET, CONTIGUOUS, INTENT(INOUT) :: LD_STACK(K_LD_STACK_SIZE)
          INTEGER(KIND=JWIM), INTENT(INOUT) :: JD_LD_STACK_USED
          J_P_STACK_USED = JD_P_STACK_USED
          J_P_SELECTED_REAL_KIND_13_300_STACK_USED = JD_P_SELECTED_REAL_KIND_13_300_STACK_USED
          J_LD_STACK_USED = JD_LD_STACK_USED
        !$loki device-present vars( p_stack, p_selected_real_kind_13_300_stack, ld_stack )
          JD_zzx = J_P_STACK_USED
          J_P_STACK_USED = JD_zzx + klev*nlon
          JD_zzy = J_P_SELECTED_REAL_KIND_13_300_STACK_USED
          J_P_SELECTED_REAL_KIND_13_300_STACK_USED = JD_zzy + klev*nlon
          JD_zzl = J_LD_STACK_USED
          J_LD_STACK_USED = JD_zzl + klev*nlon

          LD_STACK(JD_zzl + 1:JD_zzl + klev*nlon) = .false.
          DO jl=1,nlon
            DO jlev=1,klev
              P_STACK(JD_zzx + jl - nlon + jlev*nlon) = pzz(jl, jlev)
              P_SELECTED_REAL_KIND_13_300_STACK(JD_zzy + jl - nlon + jlev*nlon) = pzz(jl, jlev)
            END DO
          END DO

        !$loki end device-present
        END SUBROUTINE kernel



    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object to define the blocking dimension.
    horizontal : :any:`Dimension`
        :any:`Dimension` object to define the horizontal dimension.
    stack_name : str, optional
        Name of the stack (default: 'STACK')
    local_int_var_name_pattern : str, optional
        Local integer variable names pattern
        (default: 'JD_{name}')
    int_kind : str, optional
        Integer kind (default: 'JWIM')
    """

    def apply_pool_allocator_to_temporaries(self, routine, item=None):
        """
        Apply raw stack allocator with direct indexing to local temporary arrays

        For every local temporary array, an integer variable (named via
        :attr:`local_int_var_name_pattern` from the array name) is created. It holds
        the array's position in the stack of the matching type and kind, and the
        stack-used counter is advanced past the array. All uses of the temporaries
        in the routine body are then replaced by direct accesses to the stack. The
        temporaries are removed from the routine's variables.

        The cumulative size of all temporary arrays is determined and returned.

        Parameters
        ----------
        routine : :any:`Subroutine`
            Subroutine object to apply transformation to
        item : :any:`Item`, optional
            If given, the modules from which kind parameters and array dimensions
            are imported are recorded in its ``trafo_data``

        Returns
        -------
        stack_dict : :any:`dict`
            dict with required stack size mapped to type and kind
        """

        # get all temporary arrays and sort them according to dtype and kind
        temporary_arrays = self._filter_temporary_arrays(routine)
        temporary_array_dict = self._sort_arrays_by_type(temporary_arrays)

        integers = []
        allocations = []

        stack_dict = {}

        # temporary array name -> (array, stack variable, stack position variable)
        temp_array_map = CaseInsensitiveDict()

        for (dtype, kind_dict) in temporary_array_dict.items():

            if dtype not in stack_dict:
                stack_dict[dtype] = {}

            for (kind, arrays) in kind_dict.items():

                stack_used_name = self._get_stack_int_name('J', dtype, kind, 'STACK_USED')
                stack_used_var = self._get_int_var(name=stack_used_name, scope=routine)

                if kind not in stack_dict[dtype]:
                    stack_dict[dtype][kind] = Literal(0)

                # store type information of temporary allocation
                if item:
                    if kind in routine.imported_symbols:
                        item.trafo_data[self._key]['kind_imports'][kind] = routine.import_map[kind.name].module.lower()
                    for array in arrays:
                        dims = [d for d in array.shape if d in routine.imported_symbols]
                        for d in dims:
                            item.trafo_data[self._key]['kind_imports'][d] = routine.import_map[d.name].module.lower()

                # get the stack variable
                stack_var = self._get_stack_var(routine, dtype, kind)
                # the first array starts where the stack is currently used up to
                old_int_var = stack_used_var
                old_array_size = ()

                # loop over arrays
                for array in arrays:

                    int_var_name = self.local_int_var_name_pattern.format(name=array.name)
                    int_var = self._get_int_var(name=int_var_name, scope=routine)
                    integers += [int_var]

                    # compute array size
                    array_size = IntLiteral(1)
                    for d in array.shape:
                        if isinstance(d, RangeIndex):
                            d_extent = Sum((d.upper, Product((-1,d.lower)), IntLiteral(1)))
                        else:
                            d_extent = d
                        array_size = simplify(Product((array_size, d_extent)))

                    # add to stack dict and list of allocations
                    stack_dict[dtype][kind] = simplify(Sum((stack_dict[dtype][kind], array_size)))
                    # this array starts right after the previous one
                    allocations += [Assignment(lhs=int_var, rhs=Sum((old_int_var,) + old_array_size))]

                    # store the old int variable to calculate offset for next array
                    old_int_var = int_var
                    if isinstance(array_size, Sum):
                        old_array_size = array_size.children
                    else:
                        old_array_size = (array_size,)

                    # save for later usage
                    temp_array_map[array.name] = (array, stack_var, int_var)

                # advance the stack-used counter past the last array of this type and kind
                stack_used = simplify(Sum((int_var, array_size)))
                allocations += [Assignment(lhs=stack_used_var, rhs=stack_used)]

        var_map = self._map_temporary_array(temp_array_map, routine)
        if var_map:
            var_map = recursive_expression_map_update(var_map)
            routine.body = SubstituteExpressions(var_map).visit(routine.body)

        # add variables to routines and allocations to body
        routine.variables = as_tuple(v for v in routine.variables if v not in temporary_arrays) + as_tuple(integers)
        routine.body.prepend(allocations)

        return stack_dict

    def _map_temporary_array(self, temp_array_map, routine):
        """
        Find all instances of temporary arrays and map them to the corresponding
        stack variable and position in that stack.

        The stack position of each temporary is held in the integer variable stored
        for it in ``temp_array_map``. Single-element accesses are mapped to a single
        stack index. Array sections and whole-array references are mapped to a
        range of stack indices.

        Parameters
        ----------
        temp_array_map : :any:`CaseInsensitiveDict`
            Maps the name of each temporary array to an ``(array, stack_var, int_var)`` tuple
        routine : :any:`Subroutine`
            Subroutine whose body is searched for uses of the temporaries

        Returns
        -------
        dict
            Maps each use of a temporary array found in the body to the equivalent
            access of its stack variable
        """
        from loki.logging import warning  # pylint: disable=import-outside-toplevel

        def _shape_bounds(s):
            # lower bound, upper bound and extent of one declared dimension
            if isinstance(s, RangeIndex):
                return s.lower, s.upper, Sum((s.upper, Product((-1, s.lower)), IntLiteral(1)))
            return IntLiteral(1), s, s

        def _shift(int_var, expr):
            # int_var + expr, flattening expr into the sum if it is a Sum itself.
            # The conditional is parenthesised so int_var is kept when expr is not a Sum.
            return Sum((int_var,) + (expr.children if isinstance(expr, Sum) else (expr,)))

        # list instances of temp_array
        temp_arrays = [v for v in FindVariables().visit(routine.body) if v.name.lower() in temp_array_map.keys()]
        temp_map = {}

        for t in temp_arrays:

            _, stack_var, int_var = temp_array_map[t.name]

            # 1-based linear position (within the temporary) of the first accessed element
            offset = IntLiteral(1)
            # number of accessed elements
            stack_size = IntLiteral(1)

            if t.dimensions:
                # if t has dimensions, we must compute the offsets in the stack
                # taking each dimension into account

                # check if lead dimension is contiguous
                contiguous = (isinstance(t.dimensions[0], RangeIndex) and
                              (t.dimensions[0] == self._get_horizontal_range(routine) or
                               (t.dimensions[0].lower is None and t.dimensions[0].upper is None)))

                s_offset = IntLiteral(1)
                for d, s in zip(t.dimensions, t.shape):

                    s_lower, s_upper, s_extent = _shape_bounds(s)

                    if isinstance(d, RangeIndex):

                        # a range following a partially accessed dimension does not
                        # form one contiguous slice of the stack
                        if not contiguous:
                            warning(f'Discontiguous access of array {t} within {routine}')

                        # missing range bounds default to the declared bounds
                        d_lower = s_lower if d.lower is None else d.lower
                        d_upper = s_upper if d.upper is None else d.upper

                        # store if this dimension was contiguous
                        contiguous = (d_upper == s_upper) and (d_lower == s_lower)

                        # multiply stack_size by current dimension
                        stack_size = Product((stack_size, Sum((d_upper, Product((-1, d_lower)), IntLiteral(1)))))

                    else:
                        # only need a single index to compute offset
                        d_lower = d

                    # compute dimension and shape offsets
                    d_offset = Sum((d_lower, Product((-1, s_lower))))
                    offset = Sum((offset, Product((d_offset, s_offset))))
                    s_offset = Product((s_offset, s_extent))

            else:
                # whole-array reference: all elements of the temporary are accessed
                for s in t.shape:
                    stack_size = Product((stack_size, _shape_bounds(s)[2]))

            offset = simplify(offset)
            stack_size = simplify(stack_size)

            # NOTE(review): element (1, ..., 1) of a temporary maps to stack index
            # int_var + 1; confirm this matches the initial value the driver gives
            # the J_*_STACK_USED counters.
            lower = _shift(int_var, offset)

            if stack_size == IntLiteral(1):
                # if a single element is accessed, we only need a number
                index = lower
            else:
                # else we have to construct a range index ending at the last accessed element
                last = simplify(Sum((offset, stack_size, Product((-1, IntLiteral(1))))))
                index = RangeIndex((lower, _shift(int_var, last)))

            # finally add to the mapping
            temp_map[t] = stack_var.clone(dimensions=(index,))

        return temp_map
loki-ecmwf-0.3.6/loki/transformations/temporaries/pool_allocator.py0000664000175000017500000013205415167130205026052 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re
from collections import  defaultdict

from loki.batch import Transformation
from loki.expression import (
    IntLiteral, LogicLiteral, Variable, Array, Sum, Literal,
    Product, InlineCall, Comparison, RangeIndex, Cast,
    ProcedureSymbol, simplify, is_dimension_constant,
    DetachScopesMapper
)
from loki.ir import (
    FindNodes, FindVariables, FindInlineCalls, Transformer, Intrinsic,
    Assignment, Conditional, CallStatement, Import, Allocation,
    Deallocation, Loop, Pragma, Interface, get_pragma_parameters,
    SubstituteExpressions
)
from loki.logging import warning, debug
from loki.tools import as_tuple, OrderedSet
from loki.types import SymbolAttributes, BasicType, DerivedType

from loki.transformations.utilities import recursive_expression_map_update


# Names exported by ``from <this module> import *``
__all__ = ['TemporariesPoolAllocatorTransformation', 'EcstackPoolAllocatorTransformation']


class TemporariesPoolAllocatorTransformation(Transformation):
    """
    Transformation to inject a pool allocator that allocates a large scratch space per block
    on the driver and maps temporary arrays in kernels to this scratch space

    The stack is provided via two integer variables, ``_L`` and ``_U``, which are
    used as a stack pointer and stack end pointer, respectively.
    Naming is flexible and can be changed via options to the transformation.

    The transformation needs to be applied in reverse order, which will do the following for each **kernel**:

    * Add an argument/arguments to the kernel call signature to pass the stack integer(s)
        * either only the stack pointer is passed or the stack end pointer additionally if bound checking is active
    * Create a local copy of the stack derived type inside the kernel
    * Determine the combined size of all local arrays that are to be allocated by the pool allocator,
      taking into account calls to nested kernels. This is reported in :any:`Item`'s ``trafo_data``.
    * Inject Cray pointer assignments and stack pointer increments for all temporaries
    * Pass the local copy/copies of the stack integer(s) as argument to any nested kernel calls

    In a **driver** routine, the transformation will:

    * Determine the required scratch space from ``trafo_data``
    * Allocate the scratch space to that size
    * Insert data transfers (for OpenACC offloading)
    * Insert data sharing clauses into OpenMP or OpenACC pragmas
    * Assign stack base pointer and end pointer for each block (identified via :data:`block_dim`)
    * Pass the stack argument(s) to kernel calls


    With ``cray_ptr_loc_rhs=False`` the following stack/pool allocator will be generated:

    .. code-block:: fortran

        SUBROUTINE DRIVER (...)
          ...
          INTEGER(KIND=8) :: ISTSZ
          REAL(KIND=REAL64), ALLOCATABLE :: ZSTACK(:, :)
          INTEGER(KIND=8) :: YLSTACK_L
          INTEGER(KIND=8) :: YLSTACK_U
          ISTSZ = ISHFT(7 + C_SIZEOF(REAL(1, kind=jprb))***, -3) + ...
          ALLOCATE (ZSTACK(ISTSZ, nb))
          DO b=1,nb
            YLSTACK_L = LOC(ZSTACK(1, b))
            YLSTACK_U = YLSTACK_L + ISTSZ*C_SIZEOF(REAL(1, kind=REAL64))
            CALL KERNEL(..., YDSTACK_L=YLSTACK_L, YDSTACK_U=YLSTACK_U)
          END DO
          DEALLOCATE (ZSTACK)
        END SUBROUTINE DRIVER

        SUBROUTINE KERNEL(...)
          ...
          INTEGER(KIND=8) :: YLSTACK_L
          INTEGER(KIND=8) :: YLSTACK_U
          INTEGER(KIND=8), INTENT(INOUT) :: YDSTACK_L
          INTEGER(KIND=8), INTENT(INOUT) :: YDSTACK_U
          POINTER(IP_tmp1, tmp1)
          POINTER(IP_tmp2, tmp2)
          ...
          YLSTACK_L = YDSTACK_L
          YLSTACK_U = YDSTACK_U
          IP_tmp1 = YLSTACK_L
          YLSTACK_L = YLSTACK_L + ISHFT(ISHFT(**C_SIZEOF(REAL(1, kind=JPRB)) + 7, -3), 3)
          IF (YLSTACK_L > YLSTACK_U) STOP
          IP_tmp2 = YLSTACK_L
          YLSTACK_L = YLSTACK_L + ISHFT(ISHFT(...*C_SIZEOF(REAL(1, kind=JPRB)) + 7, -3), 3)
          IF (YLSTACK_L > YLSTACK_U) STOP
        END SUBROUTINE KERNEL

    With ``cray_ptr_loc_rhs=True`` the following stack/pool allocator will be generated:

    .. code-block:: fortran

        SUBROUTINE driver (NLON, NZ, NB, field1, field2)
          ...
          INTEGER(KIND=8) :: ISTSZ
          REAL(KIND=REAL64), ALLOCATABLE :: ZSTACK(:, :)
          INTEGER(KIND=8) :: YLSTACK_L
          INTEGER(KIND=8) :: YLSTACK_U
          ISTSZ = ISTSZ = ISHFT(7 + C_SIZEOF(REAL(1, kind=jprb))***, -3) + ...
          ALLOCATE (ZSTACK(ISTSZ, nb))
          DO b=1,nb
            YLSTACK_L = 1
            YLSTACK_U = YLSTACK_L + ISTSZ
            CALL KERNEL(..., YDSTACK_L=YLSTACK_L, YDSTACK_U=YLSTACK_U, ZSTACK=ZSTACK(:, b))
          END DO
          DEALLOCATE (ZSTACK)
        END SUBROUTINE driver

        SUBROUTINE KERNEL(...)
          ...
          INTEGER(KIND=8) :: YLSTACK_L
          INTEGER(KIND=8) :: YLSTACK_U
          INTEGER(KIND=8), INTENT(INOUT) :: YDSTACK_L
          INTEGER(KIND=8), INTENT(INOUT) :: YDSTACK_U
          REAL(KIND=REAL64), CONTIGUOUS, INTENT(INOUT) :: ZSTACK(:)
          POINTER(IP_tmp1, tmp1)
          POINTER(IP_tmp2, tmp2)
          ...
          YLSTACK_L = YDSTACK_L
          YLSTACK_U = YDSTACK_U
          IP_tmp1 = LOC(ZSTACK(YLSTACK_L))
          YLSTACK_L = YLSTACK_L + ISHFT(**C_SIZEOF(REAL(1, kind=JPRB)) + 7, -3)
          IF (YLSTACK_L > YLSTACK_U) STOP
          IP_tmp2 = LOC(ZSTACK(YLSTACK_L))
          YLSTACK_L = YLSTACK_L + ISHFT(...*C_SIZEOF(REAL(1, kind=JPRB)) + 7, -3)
          IF (YLSTACK_L > YLSTACK_U) STOP
        END SUBROUTINE KERNEL


    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object to define the blocking dimension
        to use for hoisted column arrays if hoisting is enabled.
    stack_ptr_name : str, optional
        Name of the stack pointer variable to be appended to the generic
        stack name (default: ``'L'``) resulting in e.g., ``'_L'``
    stack_end_name : str, optional
        Name of the stack end pointer variable to be appendend to the generic
        stack name (default: ``'U'``) resulting in e.g., ``'_L'``
    stack_size_name : str, optional
        Name of the variable that holds the size of the scratch space in the
        driver (default: ``'ISTSZ'``)
    stack_storage_name : str, optional
        Name of the scratch space variable that is allocated in the
        driver (default: ``'ZSTACK'``)
    stack_argument_name : str, optional
        Name of the stack argument that is added to kernels (default: ``'YDSTACK'``)
    stack_local_var_name : str, optional
        Name of the local copy of the stack argument (default: ``'YLSTACK'``)
    local_ptr_var_name_pattern : str, optional
        Python format string pattern for the name of the Cray pointer variable
        for each temporary (default: ``'IP_{name}'``)
    stack_int_type_kind: :any:`Literal` or :any:`Variable`
        Integer type kind used for the stack pointer variable(s) (default: ``'8'``
        resulting in ``'INTEGER(KIND=8)'``)
    directive : str, optional
        Can be ``'openmp'`` or ``'openacc'``. If given, insert data sharing clauses for
        the stack derived type, and insert data transfer statements (for OpenACC only).
    check_bounds : bool, optional
        Insert bounds-checks in the kernel to make sure the allocated stack size is not
        exceeded (default: `True`)
    cray_ptr_loc_rhs : bool, optional
        Whether to only pass the stack variable as integer to the kernel(s) or
        whether to pass the whole stack array to the driver and the calls to ``LOC()``
        within the kernel(s) itself (default: `False`)
    stack_size_var_kind: :any:`Literal` or :any:`Variable`
        Defaults to ``'stack_int_type_kind'``, however, can be overriden if necessary.
    """

    _key = 'TemporariesPoolAllocatorTransformation'

    # Traverse call tree in reverse when using Scheduler
    reverse_traversal = True

    process_ignored_items = True

    def __init__(
            self, block_dim, horizontal=None, stack_ptr_name='L', stack_end_name='U', stack_size_name='ISTSZ',
            stack_storage_name='ZSTACK', stack_argument_name='YDSTACK', stack_local_var_name='YLSTACK',
            local_ptr_var_name_pattern='IP_{name}', stack_int_type_kind=IntLiteral(8), directive=None,
            check_bounds=True, cray_ptr_loc_rhs=False, stack_size_var_kind=None
    ):
        self.block_dim = block_dim
        self.horizontal = horizontal
        self.stack_ptr_name = stack_ptr_name
        self.stack_end_name = stack_end_name
        self.stack_size_name = stack_size_name
        self.stack_storage_name = stack_storage_name
        self.stack_argument_name = stack_argument_name
        self.stack_local_var_name = stack_local_var_name
        self.local_ptr_var_name_pattern = local_ptr_var_name_pattern
        self.stack_int_type_kind = stack_int_type_kind
        self.directive = directive
        self.check_bounds = check_bounds
        self.cray_ptr_loc_rhs = cray_ptr_loc_rhs
        self.stack_size_var_kind = stack_size_var_kind or self.stack_int_type_kind

        if self.stack_ptr_name == self.stack_end_name:
            raise ValueError(f'"stack_ptr_name": "{self.stack_ptr_name}" and '
                f'"stack_end_name": "{self.stack_end_name}" must be different!')

    def transform_subroutine(self, routine, **kwargs):

        role = kwargs['role']
        item = kwargs.get('item', None)
        ignore = item.ignore if item else ()
        targets = as_tuple(kwargs.get('targets', None))

        if item:
            # Initialize set to store kind imports
            item.trafo_data[self._key] = {'kind_imports': {}}

        # add iso_c_binding import if necessary
        self.import_c_sizeof(routine)
        # add iso_fortran_env import if necessary
        self.import_real64(routine)

        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()

        if role == 'kernel':
            stack_size = self.apply_pool_allocator_to_temporaries(routine, item=item)
            if item:
                stack_size = self._determine_stack_size(routine, successors, stack_size, item=item)
                item.trafo_data[self._key]['stack_size'] = stack_size

        elif role == 'driver':
            self.add_driver_imports(routine)
            stack_size = self._determine_stack_size(routine, successors, item=item)
            if item:
                # import variable type specifiers used in stack allocations
                self.import_allocation_types(routine, item)
            self.create_pool_allocator(routine, stack_size)

        self.inject_pool_allocator_into_calls(routine, targets, ignore, driver=role=='driver')

    def add_driver_imports(self, routine):
        pass

    @staticmethod
    def import_c_sizeof(routine):
        """
        Import the c_sizeof symbol if necesssary.
        """

        # add qualified iso_c_binding import
        if not 'C_SIZEOF' in routine.imported_symbols:
            imp = Import(
                module='ISO_C_BINDING', symbols=as_tuple(ProcedureSymbol('C_SIZEOF', scope=routine)),
                nature='intrinsic'
            )
            routine.spec.prepend(imp)

    @staticmethod
    def import_real64(routine):
        """
        Import the real64 symbol if necesssary.
        """

        # add qualified iso_fortran_env import
        if not 'REAL64' in routine.imported_symbols:
            imp = Import(
                module='ISO_FORTRAN_ENV', symbols=as_tuple(ProcedureSymbol('REAL64', scope=routine)),
                nature='intrinsic'
            )
            routine.spec.prepend(imp)

    def import_allocation_types(self, routine, item):
        """
        Import all the variable types used in allocations.
        """

        new_imports = defaultdict(OrderedSet)
        for s, m in item.trafo_data[self._key]['kind_imports'].items():
            new_imports[m] |= OrderedSet(as_tuple(s))

        import_map = {i.module.lower(): i for i in routine.imports}
        for mod, symbs in new_imports.items():
            if mod in import_map:
                import_map[mod]._update(symbols=as_tuple(OrderedSet(import_map[mod].symbols + as_tuple(symbs))))
            else:
                _symbs = [s for s in symbs if not (s.name.lower() in routine.variable_map or
                                                   s.name.lower() in routine.imported_symbol_map)]
                if _symbs:
                    imp = Import(module=mod, symbols=as_tuple(_symbs))
                    routine.spec.prepend(imp)

    def _get_local_stack_var(self, routine):
        """
        Utility routine to get the local stack variable

        The variable is created and added to :data:`routine` if it doesn't exist, yet.
        """
        if f'{self.stack_local_var_name}_{self.stack_ptr_name}' in routine.variables:
            return routine.variable_map[f'{self.stack_local_var_name}_{self.stack_ptr_name}']

        stack_type = SymbolAttributes(dtype=BasicType.INTEGER, kind=self.stack_int_type_kind)
        stack_var = Variable(name=f'{self.stack_local_var_name}_{self.stack_ptr_name}', type=stack_type, scope=routine)
        routine.variables += (stack_var,)
        return stack_var

    def _get_local_stack_var_end(self, routine):
        """
        Utility routine to get the local stack variable end

        The variable is created and added to :data:`routine` if it doesn't exist, yet.
        """
        if f'{self.stack_local_var_name}_{self.stack_end_name}' in routine.variables:
            return routine.variable_map[f'{self.stack_local_var_name}_{self.stack_end_name}']

        stack_type = SymbolAttributes(dtype=BasicType.INTEGER, kind=self.stack_int_type_kind)
        var_name = f'{self.stack_local_var_name}_{self.stack_end_name}'
        stack_var_end = Variable(name=var_name, type=stack_type, scope=routine)
        routine.variables += (stack_var_end,)
        return stack_var_end

    def _get_stack_arg(self, routine):
        """
        Utility routine to get the stack argument

        The argument is created and added to the dummy argument list of :data:`routine`
        if it doesn't exist, yet.
        """
        if f'{self.stack_argument_name}_{self.stack_ptr_name}' in routine.arguments:
            return routine.variable_map[f'{self.stack_argument_name}_{self.stack_ptr_name}']

        stack_type = SymbolAttributes(dtype=BasicType.INTEGER, intent='inout', kind=self.stack_int_type_kind)
        var_name = f'{self.stack_argument_name}_{self.stack_ptr_name}'
        stack_arg = Variable(name=var_name, type=stack_type, scope=routine)
        routine.arguments += (stack_arg,)

        return stack_arg

    def _get_stack_arg_end(self, routine):
        """
        Utility routine to get the stack argument end

        The argument is created and added to the dummy argument list of :data:`routine`
        if it doesn't exist, yet.
        """
        if f'{self.stack_argument_name}_{self.stack_end_name}' in routine.arguments:
            return routine.variable_map[f'{self.stack_argument_name}_{self.stack_end_name}']

        stack_type = SymbolAttributes(dtype=BasicType.INTEGER, intent='inout', kind=self.stack_int_type_kind)
        var_name = f'{self.stack_argument_name}_{self.stack_end_name}'
        stack_arg_end = Variable(name=var_name, type=stack_type, scope=routine)
        routine.arguments += (stack_arg_end,)

        return stack_arg_end

    def _get_stack_ptr(self, routine):
        """
        Utility routine to get the stack pointer variable
        """
        return Variable(
                name=f'{self.stack_local_var_name}_{self.stack_ptr_name}',
                scope=routine
                )

    def _get_stack_end(self, routine):
        """
        Utility routine to get the stack end pointer variable
        """
        return Variable(
            name=f'{self.stack_local_var_name}_{self.stack_end_name}',
            scope=routine
        )

    def _get_stack_alloc(self, routine, stack_storage, stack_size_var, block_size): # pylint: disable=unused-argument
        stack_alloc = Allocation(variables=(stack_storage.clone(dimensions=(  # pylint: disable=no-member
            stack_size_var, block_size)),))
        return stack_alloc

    def _get_stack_dealloc(self, routine, stack_storage, stack_size_var, block_size): # pylint: disable=unused-argument
        return Deallocation(variables=(stack_storage.clone(dimensions=None),))

    def _get_pragma_start(self, routine, stack_storage): # pylint: disable=unused-argument
        pragma = Pragma(
            keyword='loki',
            content=f'structured-data create({stack_storage.name})' # pylint: disable=no-member
        )
        return pragma

    def _get_pragma_end(self, routine, stack_storage): # pylint: disable=unused-argument
        return Pragma(keyword='loki', content='end structured-data')

    def _get_stack_type(self, routine):
        stack_type = SymbolAttributes(
            dtype=BasicType.REAL,
            kind=Variable(name='REAL64', scope=routine),
            shape=(RangeIndex((None, None)), RangeIndex((None, None))),
            allocatable=True,
        )
        return stack_type

    def _get_stack_storage_and_size_var(self, routine, stack_size):
        """
        Utility routine to obtain storage array and size variable for the pool allocator

        If array or size variable already exist, matching the provided names :attr:`stack_size_name`
        and :attr:`stack_storage_name`, they are used directly. Note that this does not validate
        that :data:`stack_size` matches the allocated array size.

        If array or size variable do not exist, yet, they are created as required and initialized or
        allocated accordingly.
        """
        variable_map = routine.variable_map  # Local copy to look-up variables by name

        # Nodes to prepend/append to the routine's body
        body_prepend = []
        body_append = []

        variables_append = []  # New variables to declare in the routine

        if self.stack_size_name in variable_map:
            # Use an existing stack size declaration
            stack_size_var = routine.variable_map[self.stack_size_name]

        else:
            # Create a variable for the stack size and assign the size
            stack_size_var_type = SymbolAttributes(BasicType.INTEGER, kind=self.stack_size_var_kind)
            stack_size_var = Variable(name=self.stack_size_name, type=stack_size_var_type)

            # Retrieve kind parameter of stack storage
            _kind = routine.symbol_map.get('REAL64', None) or Variable(name='REAL64')

            # Convert stack_size from bytes to integer
            stack_type_bytes = Cast(name='REAL', expression=Literal(1), kind=_kind)
            stack_type_bytes = InlineCall(Variable(name='C_SIZEOF'),
                                          parameters=as_tuple(stack_type_bytes))
            stack_size_assign = Assignment(lhs=stack_size_var, rhs=stack_size)
            body_prepend += [stack_size_assign]
            variables_append += [stack_size_var]

        if self.stack_storage_name in variable_map:
            # Use an existing stack storage array
            stack_storage = routine.variable_map[self.stack_storage_name]
        else:
            # Create a variable for the stack storage array and create corresponding
            # allocation/deallocation statements
            stack_type = self._get_stack_type(routine)
            stack_storage = Variable(
                name=self.stack_storage_name, type=stack_type,
                dimensions=stack_type.shape, scope=routine
            )
            variables_append += [stack_storage]

            block_size = routine.resolve_typebound_var(self.block_dim.size, routine.symbol_map)
            stack_alloc = self._get_stack_alloc(routine, stack_storage, stack_size_var, block_size)
            stack_dealloc = self._get_stack_dealloc(routine, stack_storage, stack_size_var, block_size)

            body_prepend += [stack_alloc]
            pragma_data_start = self._get_pragma_start(routine, stack_storage)
            body_prepend += [pragma_data_start]
            pragma_data_end = self._get_pragma_end(routine, stack_storage)
            body_append += [pragma_data_end]
            if stack_dealloc is not None:
                body_append += [stack_dealloc]

        # Inject new variables and body nodes
        if variables_append:
            routine.variables += as_tuple(variables_append)
        if body_prepend:
            if not self._insert_stack_at_loki_pragma(routine, body_prepend):
                routine.body.prepend(body_prepend)
        if body_append:
            routine.body.append(body_append)

        return stack_storage, stack_size_var

    @staticmethod
    def _insert_stack_at_loki_pragma(routine, insert):
        pragma_map = {}
        for pragma in FindNodes(Pragma).visit(routine.body):
            if pragma.keyword == 'loki' and 'stack-insert' in pragma.content:
                pragma_map[pragma] = insert
        if pragma_map:
            routine.body = Transformer(pragma_map).visit(routine.body)
            return True
        return False

    def _determine_stack_size(self, routine, successors, local_stack_size=None, item=None):
        """
        Utility routine to determine the stack size required for the given :data:`routine`,
        including calls to subroutines

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine object for which to determine the stack size
        successors : list of :any:`Item`
            The items corresponding to successor routines called from :data:`routine`
        local_stack_size : :any:`Expression`, optional
            The stack size required for temporaries in :data:`routine`
        item : :any:`Item`
            Scheduler work item corresponding to routine.

        Returns
        -------
        :any:`Expression` :
            The expression representing the required stack size.
        """

        # Collect variable kind imports from successors
        if item:
            item.trafo_data[self._key]['kind_imports'].update({
                k: v
                for s in successors
                for k, v in s.trafo_data.get(self._key, {}).get('kind_imports', {}).items()
            })

        # Note: we are not using a CaseInsensitiveDict here to be able to search directly with
        # Variable instances in the dict. The StrCompareMixin takes care of case-insensitive
        # comparisons in that case
        successor_map = {
            successor.local_name.lower(): successor
            for successor in successors
        }

        # Collect stack sizes for successors
        # Note that we need to translate the names of variables used in the expressions to the
        # local names according to the call signature
        stack_sizes = []
        for call in FindNodes(CallStatement).visit(routine.body):
            if call.name in successor_map and self._key in successor_map[call.name].trafo_data:
                successor_stack_size = successor_map[call.name].trafo_data[self._key]['stack_size']
                # Replace any occurence of routine arguments in the stack size expression
                arg_map = dict(call.arg_iter())
                expr_map = {
                    expr: DetachScopesMapper()(arg_map[expr]) for expr in FindVariables().visit(successor_stack_size)
                    if expr in arg_map
                }
                if expr_map:
                    expr_map = recursive_expression_map_update(expr_map)
                    successor_stack_size = SubstituteExpressions(expr_map).visit(successor_stack_size)
                stack_sizes += [successor_stack_size]

        # Unwind "max" expressions from successors and inject the local stack size into the expressions
        stack_sizes = [
            d for s in stack_sizes
            for d in (s.parameters if isinstance(s, InlineCall) and s.function == 'MAX' else [s])
        ]
        if local_stack_size:
            local_stack_size = DetachScopesMapper()(simplify(local_stack_size))
            stack_sizes = [simplify(Sum((local_stack_size, s))) for s in stack_sizes]

        if not stack_sizes:
            # Return only the local stack size if there are no callees
            return local_stack_size or Literal(0)

        if len(stack_sizes) == 1:
            # For a single successor, it is sufficient to add the local stack size to the expression
            return stack_sizes[0]

        # Re-build the max expressions, taking into account the local stack size and calls to successors
        stack_size = InlineCall(function=Variable(name='MAX'), parameters=as_tuple(stack_sizes), kw_parameters=())
        return stack_size

    def _get_c_sizeof_arg(self, arr):
        """
        Return an inline declaration of an intrinsic type, to be used as an argument to
        `C_SIZEOF`.
        """

        if arr.type.dtype == BasicType.REAL:
            param = Cast(name='REAL', expression=IntLiteral(1))
        elif arr.type.dtype == BasicType.INTEGER:
            param = Cast(name='INT', expression=IntLiteral(1))
        elif arr.type.dtype == BasicType.CHARACTER:
            param = Cast(name='CHAR', expression=IntLiteral(1))
        elif arr.type.dtype == BasicType.LOGICAL:
            param = Cast(name='LOGICAL', expression=LogicLiteral('.TRUE.'))
        elif arr.type.dtype == BasicType.COMPLEX:
            param = Cast(name='CMPLX', expression=(IntLiteral(1), IntLiteral(1)))

        param.kind = getattr(arr.type, 'kind', None) # pylint: disable=possibly-used-before-assignment

        return param

    def _create_stack_allocation(self, stack_ptr, stack_end, ptr_var, arr, stack_size, stack_storage=None):
        """
        Utility routine to "allocate" a temporary array on the pool allocator's "stack"

        This creates the pointer assignment, stack pointer increment and adds a stack size check.

        Parameters
        ----------
        stack_ptr : :any:`Variable`
            The stack pointer variable
        stack_end : :any:`Variable`
            The pointer variable that points to the end of the stack, used to verify stack size
        ptr_var : :any:`Variable`
            The pointer variable to use for the temporary array
        arr : :any:`Variable`
            The temporary array to allocate on the pool allocator's "stack"
        stack_size : :any:`Variable`
            The size in bytes of the pool allocator's "stack"

        Returns
        -------
        list
            The IR nodes to add for the stack allocation: an :any:`Assignment` for the pointer
            association, an :any:`Assignment` for the stack pointer increment, and a
            :any:`Conditional` that verifies that the stack is big enough
        """

        if self.cray_ptr_loc_rhs:
            ptr_assignment = Assignment(lhs=ptr_var, rhs=InlineCall(
                        function=Variable(name='LOC'),
                        parameters=(
                            stack_storage.clone(
                                dimensions=(stack_ptr.clone(),)
                            ),
                        ),
                        kw_parameters=None
                    )
                )
        else:
            ptr_assignment = Assignment(lhs=ptr_var, rhs=stack_ptr)

        # Build expression for array size in bytes
        dims = ()
        for d in arr.dimensions:
            if isinstance(d, RangeIndex):
                dims += (Sum((d.upper, Product((-1, d.lower)), 1)),)
            else:
                dims += (d,)
        dim = Product(dims)
        arr_type_bytes = InlineCall(Variable(name='C_SIZEOF'),
                                            parameters=as_tuple(self._get_c_sizeof_arg(arr)))
        arr_size = Product((dim, arr_type_bytes))

        # Allocation is expressed in terms of REAL64, i.e., 8 byte values
        # We obtain the allocation size by dividing the required size by 8 and rounding up,
        # i.e., (size + 7) // 8, with the division implemented as bit shifts
        ishift_func = InlineCall(function=Variable(name='ISHFT'))
        arr_size = ishift_func.clone(parameters=(Sum((arr_size, 7)), -3))

        # Increment stack size
        stack_size = simplify(Sum((stack_size, arr_size)))

        if self.cray_ptr_loc_rhs:
            ptr_increment = Assignment(lhs=stack_ptr, rhs=Sum((stack_ptr, arr_size)))
        else:
            ptr_increment = Assignment(lhs=stack_ptr, rhs=Sum((stack_ptr, ishift_func.clone(parameters=(arr_size, 3)))))
        if self.check_bounds:
            stack_size_check = Conditional(
                condition=Comparison(stack_ptr, '>', stack_end), inline=True,
                body=(Intrinsic('STOP'),), else_body=None
            )
            return ([ptr_assignment, ptr_increment, stack_size_check], stack_size)
        return ([ptr_assignment, ptr_increment], stack_size)

    def apply_pool_allocator_to_temporaries(self, routine, item=None):
        """
        Apply pool allocator to local temporary arrays

        This appends the relevant argument to the routine's dummy argument list and
        creates the assignment for the local copy of the stack type.
        For all local arrays, a Cray pointer is instantiated and the temporaries
        are mapped via Cray pointers to the pool-allocated memory region.

        The cumulative size of all temporary arrays is determined and returned.
        """

        # Find all temporary arrays
        arguments = routine.arguments
        temporary_arrays = [
            var for var in routine.variables
            if isinstance(var, Array) and var not in arguments
        ]

        # Filter out derived-type objects. Partly because the possibility of derived-type
        # nesting increases the complexity of determing allocation size, and partly because `C_SIZEOF`
        # doesn't account for the size of allocatable/pointer members of derived-types.
        if any(isinstance(var.type.dtype, DerivedType) for var in temporary_arrays):
            warning(f'[Loki::PoolAllocator] Derived-type vars in {routine} not supported in pool allocator')
        temporary_arrays = [
            var for var in temporary_arrays if not isinstance(var.type.dtype, DerivedType)
        ]

        # Filter out unused vars
        #  this used to rely on dataflow_analysis only putting temporaries that are defined/written to
        #  however, there exist some cases where temporaries are only read and it is not the
        #  responsibility of this transformation to decide whether that is reasonable.
        #  The following just removes unused temporaries so that they are not put on the stack.
        used_vars = {v.name.lower() for v in FindVariables().visit(routine.body)}
        temporary_arrays = [
                var for var in temporary_arrays
                if var.name.lower() in used_vars
        ]

        # Filter out variables whose size is known at compile-time
        temporary_arrays = [
            var for var in temporary_arrays
            if not all(is_dimension_constant(d) for d in var.shape)
        ]

        # Filter out pointers
        temporary_arrays = [
            var for var in temporary_arrays
            if not var.type.pointer or var.type.allocatable
        ]

        # Create stack argument and local stack var
        stack_var = self._get_local_stack_var(routine)
        stack_var_end = self._get_local_stack_var_end(routine) if self.check_bounds else None
        stack_arg = self._get_stack_arg(routine)
        stack_arg_end = self._get_stack_arg_end(routine) if self.check_bounds else None

        stack_storage = None
        if self.cray_ptr_loc_rhs:
            stack_type = SymbolAttributes(
                    dtype=BasicType.REAL,
                    kind=Variable(name='REAL64', scope=routine),
                    shape=(RangeIndex((None, None)),), intent='inout', contiguous=True,
            )
            stack_storage = Variable(
                    name=self.stack_storage_name, type=stack_type,
                    dimensions=stack_type.shape, scope=routine,
            )
            arg_pos = [routine.arguments.index(arg) for arg in routine.arguments if arg.type.optional]
            if arg_pos:
                routine.arguments = routine.arguments[:arg_pos[0]] + (stack_storage,) + routine.arguments[arg_pos[0]:]
            else:
                routine.arguments += (stack_storage,)

        allocations = [Assignment(lhs=stack_var, rhs=stack_arg)]
        if self.check_bounds:
            allocations.append(Assignment(lhs=stack_var_end, rhs=stack_arg_end))

        # Determine size of temporary arrays
        stack_size = Literal(0)

        # Create Cray pointer declarations and "stack allocations"
        declarations = []
        stack_ptr = self._get_stack_ptr(routine)
        stack_end = self._get_stack_end(routine)
        for arr in temporary_arrays:
            ptr_var = Variable(name=self.local_ptr_var_name_pattern.format(name=arr.name), scope=routine)
            declarations += [Intrinsic(f'POINTER({ptr_var.name}, {arr.name})')]  # pylint: disable=no-member
            allocation, stack_size = self._create_stack_allocation(stack_ptr, stack_end, ptr_var, arr,
                    stack_size, stack_storage)
            allocations += allocation

            # Store type and size information of temporary allocation
            if item:
                if (kind := arr.type.kind):
                    if kind in routine.imported_symbols:
                        item.trafo_data[self._key]['kind_imports'][kind] = routine.import_map[kind.name].module.lower()
                dims = [d for d in arr.shape if d in routine.imported_symbols]
                for d in dims:
                    item.trafo_data[self._key]['kind_imports'][d] = routine.import_map[d.name].module.lower()

        routine.spec.append(declarations)
        routine.body.prepend(allocations)

        return stack_size

    def create_pool_allocator(self, routine, stack_size):
        """
        Create a pool allocator in the driver.

        The local stack variable (and, with ``check_bounds``, the local stack
        end variable) is added to the ``private`` clause of every
        ``!$loki loop gang ...`` and ``!$omp parallel ...`` pragma in
        ``routine``. A stack pointer assignment (plus, with ``check_bounds``,
        a stack end assignment) is inserted into each loop that either
        iterates over, or assigns, the block index.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The driver routine; modified in place.
        stack_size :
            The stack size, passed on to ``_get_stack_storage_and_size_var``.
        """
        # Create and allocate the stack
        stack_storage, stack_size_var = self._get_stack_storage_and_size_var(routine, stack_size)
        stack_var = self._get_local_stack_var(routine)
        stack_var_end = self._get_local_stack_var_end(routine) if self.check_bounds else None
        stack_ptr = self._get_stack_ptr(routine)
        stack_end = self._get_stack_end(routine)

        # Matches an existing ``private(`` clause. Whitespace before the parenthesis is
        # allowed, since otherwise ``private (...)`` would silently be left unchanged and
        # the stack variables would not be privatised.
        private_clause_re = r'\bprivate\s*\('

        pragma_map = {}
        pragmas = [p for p in FindNodes(Pragma).visit(routine.body) if p.keyword.lower() == 'loki']
        for pragma in pragmas:
            if pragma.content.lower().startswith('loop gang'):
                parameters = get_pragma_parameters(pragma, starts_with='loop gang', only_loki_pragmas=False)
                if 'private' in [p.lower() for p in parameters]:
                    # Prepend the stack variable(s) to the existing private clause
                    var_end_str = f' {stack_var_end.name},' if self.check_bounds else ''
                    content = re.sub(private_clause_re, f'private({stack_var.name},{var_end_str} ',
                            pragma.content.lower())
                else:
                    # No private clause yet: append one
                    var_end_str = f', {stack_var_end.name}' if self.check_bounds else ''
                    content = pragma.content + f' private({stack_var.name}{var_end_str})'
                pragma_map[pragma] = pragma.clone(content=content)
        # Code that already carries OpenMP pragmas (e.g. ecwam transformed for 'idem-stack')
        # relies on the following. Once we (decide to) implement a 'reverse PragmaModel'
        # transformation that converts e.g. OpenMP pragmas back to generic Loki pragmas,
        # we no longer need to handle OpenMP pragmas here.
        omp_pragmas = [p for p in FindNodes(Pragma).visit(routine.body) if p.keyword.lower() == 'omp']
        for pragma in omp_pragmas:
            if pragma.content.lower().startswith('parallel'):
                parameters = get_pragma_parameters(pragma, starts_with='parallel', only_loki_pragmas=False)
                if 'private' in [p.lower() for p in parameters]:
                    var_end_str = f' {stack_var_end.name},' if self.check_bounds else ''
                    content = re.sub(private_clause_re, f'private({stack_var.name},{var_end_str}',
                            pragma.content.lower())
                else:
                    var_end_str = f', {stack_var_end.name}' if self.check_bounds else ''
                    content = pragma.content + f' private({stack_var.name}{var_end_str})'
                pragma_map[pragma] = pragma.clone(content=content)

        if pragma_map:
            routine.body = Transformer(pragma_map).visit(routine.body)

        # Find first block loop and assign local stack pointers there
        loop_map = {}
        for loop in FindNodes(Loop).visit(routine.body):
            assignments = FindNodes(Assignment).visit(loop.body)
            if loop.variable != self.block_dim.index:
                # Check if block variable is assigned in loop body
                for assignment in assignments:
                    if assignment.lhs == self.block_dim.index:
                        assert assignment in loop.body
                        # Need to insert the pointer assignment after block dimension is set
                        assign_pos = loop.body.index(assignment)
                        break
                else:
                    warning(
                        f'{self.__class__.__name__}: '
                        f'Could not find a block dimension for loop with variable {loop.variable} and '
                        f'bounds {loop.bounds} in {routine.name}; no stack pointer assignment inserted!'
                    )
                    continue
            else:
                # block variable is the loop variable: pointer assignment can happen
                # at the beginning of the loop body
                assign_pos = -1

            # Check for existing pointer assignment
            if any(a.lhs == f'{self.stack_local_var_name}_{self.stack_ptr_name}' for a in assignments):
                debug(
                    f'{self.__class__.__name__}: '
                    f'Stack (pointer) already exists within/for loop with variable {loop.variable} and '
                    f'bounds {loop.bounds} in {routine.name}; thus no stack pointer assignment inserted!'
                )
                # NOTE(review): ``break`` stops processing of all remaining loops, not only
                # this one - confirm this is intended rather than ``continue``.
                break
            if self.cray_ptr_loc_rhs:
                # Stack pointer is an index into the stack storage array
                ptr_assignment = Assignment(lhs=stack_ptr, rhs=IntLiteral(1))
            else:
                # Stack pointer is the address of the stack storage column for this block
                ptr_assignment = Assignment(
                    lhs=stack_ptr, rhs=InlineCall(
                        function=Variable(name='LOC'),
                        parameters=(
                            stack_storage.clone(
                                dimensions=(Literal(1), Variable(name=self.block_dim.index, scope=routine))
                            ),
                        ),
                        kw_parameters=None
                    )
                )

            # Retrieve kind parameter of stack storage
            _kind = routine.imported_symbol_map.get('REAL64')

            # Stack increment
            if self.cray_ptr_loc_rhs:
                stack_incr = Assignment(
                    lhs=stack_end, rhs=Sum((stack_ptr, stack_size_var))
                )
            else:
                # Address arithmetic: scale the element count by the byte size of one REAL
                _real_size_bytes = Cast(name='REAL', expression=Literal(1), kind=_kind)
                _real_size_bytes = InlineCall(Variable(name='C_SIZEOF'),
                                              parameters=as_tuple(_real_size_bytes))
                stack_incr = Assignment(
                    lhs=stack_end, rhs=Sum((stack_ptr, Product((stack_size_var, _real_size_bytes))))
                )
            new_assignments = (ptr_assignment,)
            if self.check_bounds:
                new_assignments += (stack_incr,)
            loop_map[loop] = loop.clone(
                body=loop.body[:assign_pos + 1] + new_assignments + loop.body[assign_pos + 1:]
            )

        if loop_map:
            routine.body = Transformer(loop_map).visit(routine.body)

    def inject_pool_allocator_into_calls(self, routine, targets, ignore, driver=False):
        """
        Pass the pool allocator (stack) arguments on to the calls made in ``routine``.

        The local stack variable is passed as keyword argument
        ``<stack_argument_name>_<stack_ptr_name>``. With ``check_bounds`` the
        local stack end is additionally passed as
        ``<stack_argument_name>_<stack_end_name>``, and with
        ``cray_ptr_loc_rhs`` the stack storage array is passed as well. In the
        driver, the second dimension of that array is replaced by the block index.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine whose call statements and inline calls are updated in place.
        targets : list of str
            Names of the callees that receive the stack arguments.
        ignore : list of str
            Further callees considered for :any:`CallStatement` nodes; compared
            against the lower-cased name of the called routine.
        driver : bool, optional
            Whether ``routine`` is the driver.
        """
        # self._get_stack_arg must not be used here: it would inject a
        # declaration, which is not possible in the driver!
        stack_arg_name = f'{self.stack_argument_name}_{self.stack_ptr_name}'
        extra_kwargs = [(stack_arg_name, self._get_local_stack_var(routine))]

        if self.check_bounds:
            end_arg_name = f'{self.stack_argument_name}_{self.stack_end_name}'
            extra_kwargs.append((end_arg_name, self._get_local_stack_var_end(routine)))

        if self.cray_ptr_loc_rhs:
            storage = routine.variable_map[self.stack_storage_name]
            storage_dims = None
            if driver:
                storage_dims = list(storage.dimensions)
                storage_dims[1] = routine.variable_map[self.block_dim.index]
            extra_kwargs.append((storage.name, storage.clone(dimensions=as_tuple(storage_dims))))

        extra_kwargs = tuple(extra_kwargs)

        stmt_map = {}
        for call in FindNodes(CallStatement).visit(routine.body):
            is_candidate = call.name in targets or call.routine.name.lower() in ignore
            if not is_candidate:
                continue

            # For calls declared via an explicit interface, the ProcedureSymbol of the call
            # refers to the interface block rather than to the subroutine itself, so the
            # interface block has to receive the stack argument as well
            interface_symbols = [s for intf in FindNodes(Interface).visit(routine.spec) for s in intf.symbols]
            if call.name in interface_symbols:
                self._get_stack_arg(call.routine)

            if call.routine != BasicType.DEFERRED and stack_arg_name in call.routine.arguments:
                stmt_map[call] = call.clone(kwarguments=call.kwarguments + extra_kwargs)

        if stmt_map:
            routine.body = Transformer(stmt_map).visit(routine.body)

        # Same again for inline (function) calls, matched case-insensitively against targets
        inline_map = {
            call: call.clone(kw_parameters=as_tuple(call.kw_parameters) + extra_kwargs)
            for call in FindInlineCalls().visit(routine.body)
            if call.name.lower() in [t.lower() for t in targets]
        }

        if inline_map:
            routine.body = SubstituteExpressions(inline_map).visit(routine.body)


class EcstackPoolAllocatorTransformation(TemporariesPoolAllocatorTransformation):
    """
    Variant of :any:`TemporariesPoolAllocatorTransformation` that does not
    insert offload pragmas for the stack. Instead, it obtains a pointer to an
    offloaded chunk of memory from an externally defined ECSTACK module.

    ECSTACK is expected to provide at least the following interface:

    .. code-block:: fortran

        MODULE ECSTACK_MOD

        IMPLICIT NONE

        TYPE TECSTACK
          ...
        CONTAINS
          PROCEDURE :: GET_STACK_PTR
        END TYPE TECSTACK

        PRIVATE

        TYPE(TECSTACK) :: ECSTACK
        PUBLIC :: TECSTACK, ECSTACK

        CONTAINS

        SUBROUTINE GET_STACK_PTR(SELF, PTR, KSIZE, NGPBLKS)
           CLASS(TECSTACK) :: SELF
           REAL(KIND=JPRD), POINTER, CONTIGUOUS, INTENT(INOUT) :: PTR(:, :)
           INTEGER(KIND=JPIM), INTENT(IN) :: KSIZE
           INTEGER(KIND=JPIM), INTENT(IN) :: NGPBLKS

           ...
        END SUBROUTINE GET_STACK_PTR

        END MODULE ECSTACK_MOD
    """

    def add_driver_imports(self, routine):
        """Make ``ECSTACK`` available in the driver ``routine``."""
        self.import_ecstack(routine)

    @staticmethod
    def import_ecstack(routine):
        """
        Prepend an import of ``ECSTACK`` from ``ECSTACK_MOD`` to the spec of
        ``routine``, unless ``ECSTACK`` is already among its imported symbols.
        """
        if 'ECSTACK' in routine.imported_symbols:
            return
        ecstack_import = Import(module='ECSTACK_MOD', symbols=as_tuple(Variable(name='ECSTACK')))
        routine.spec.prepend(ecstack_import)

    def _get_stack_alloc(self, routine, stack_storage, stack_size_var, block_size):
        """
        Return a call to the ``GET_STACK_PTR`` procedure bound to the imported
        ``ECSTACK`` object, passing the (undimensioned) ``stack_storage``,
        ``stack_size_var`` and ``block_size``.
        """
        ecstack = routine.imported_symbol_map['ecstack']
        get_ptr_proc = ProcedureSymbol(name="GET_STACK_PTR", parent=ecstack, scope=routine)
        call_args = (stack_storage.clone(dimensions=None), stack_size_var, block_size)
        return CallStatement(name=get_ptr_proc, arguments=call_args)

    def _get_stack_dealloc(self, routine, stack_storage, stack_size_var, block_size):
        """No deallocation is generated for ECSTACK-provided memory; returns ``None``."""
        return None

    def _get_pragma_start(self, routine, stack_storage):
        """Return the Loki pragma opening a ``structured-data present(...)`` region for ``stack_storage``."""
        data_content = f'structured-data present({stack_storage.name})' # pylint: disable=no-member
        return Pragma(keyword='loki', content=data_content)

    def _get_pragma_end(self, routine, stack_storage):
        """Return the Loki pragma closing the ``structured-data`` region."""
        return Pragma(keyword='loki', content='end structured-data')

    def _get_stack_type(self, routine):
        """
        Return the type of the stack storage: a rank-2, contiguous
        ``REAL(REAL64)`` pointer with deferred shape.
        """
        deferred_dims = tuple(RangeIndex((None, None)) for _ in range(2))
        return SymbolAttributes(
            dtype=BasicType.REAL,
            kind=Variable(name='REAL64', scope=routine),
            shape=deferred_dims,
            pointer=True,
            contiguous=True,
        )
loki-ecmwf-0.3.6/loki/transformations/pragma_model.py0000664000175000017500000003756015167130205023144 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import inspect

from loki.batch import Transformation, ProcedureItem, ModuleItem
from loki.ir import (
    FindNodes, Pragma, Transformer, get_pragma_command_and_parameters
)

__all__ = ['PragmaModelTransformation']


class GenericPragmaMapper:
    """
    A generic pragma mapper class.

    Pragmas in the form

    ``!$loki command-optionally-with-hyphen [param] [param_with_val(val)]``

    get a handler method that looks like

    .. code-block::

       def pmap_command_optionally_with_hyphen(self, pragma, parameters, **kwargs):
           pass

    The handler name is derived from the lower-cased pragma command, with
    hyphens replaced by underscores and the prefix ``pmap_`` prepended. The
    handler receives the pragma and its parsed parameters. It is responsible
    for returning either None or the updated pragma.
    """
    # pylint: disable=unused-argument
    def __init__(self):
        handlers = {}
        prefix = "pmap_"
        for (name, meth) in inspect.getmembers(self, predicate=inspect.ismethod):
            if not name.startswith(prefix):
                continue
            argspec = inspect.getfullargspec(meth)
            # ``argspec.args`` includes the bound ``self``. Handlers are invoked as
            # ``meth(pragma, parameters, **kwargs)``, so they need two more positional
            # parameters (or ``*args``) to be callable.
            if len(argspec.args) < 3 and argspec.varargs is None:
                raise RuntimeError("Visit method signature must be "
                                   "pmap_foo(self, pragma, parameters, [**kwargs])")
            handlers[name[len(prefix):]] = meth
        self._handlers = handlers

    def lookup_method(self, starts_with):
        """
        Return the handler registered for the command ``starts_with``
        (lower-cased before lookup), or ``None`` if there is none.
        """
        return self._handlers.get(starts_with.lower())

    @classmethod
    def default_retval(cls):
        """
        Default return value for handler methods.

        Returns
        -------
        None
        """
        return None

    def pmap(self, pragma, **kwargs):
        """
        Map ``pragma`` using the handler that matches its command.

        Returns
        -------
        The handler's result, or :meth:`default_retval` if the pragma has no
        command or no handler exists for it.
        """
        starts_with, parameters = get_pragma_command_and_parameters(pragma)
        if starts_with is None:
            # Nothing to dispatch on (avoids calling ``.lower()`` on ``None``)
            return self.default_retval()
        meth = self.lookup_method(starts_with.lower().replace('-', '_'))
        if meth is None:
            return self.default_retval()
        return meth(pragma, parameters, **kwargs)


class OpenACCPragmaMapper(GenericPragmaMapper):
    """
    Mapper that translates Loki generic pragmas into OpenACC directives.

    Each handler returns a new ``acc`` :any:`Pragma`, or :meth:`default_retval`
    when the Loki pragma carries none of the parameters the handler translates.
    """
    # pylint: disable=unused-argument

    @staticmethod
    def _render_clauses(parameters, clause_names):
        """
        Render ``' clause(value)'`` for every ``(parameter, clause)`` pair in
        ``clause_names`` whose parameter has a truthy value, in the given order.
        Returns an empty string if no such parameter is present.
        """
        rendered = ''
        for param_name, clause in clause_names:
            value = parameters.get(param_name)
            if value:
                rendered += f' {clause}({value})'
        return rendered

    def pmap_create(self, pragma, parameters, **kwargs):
        """``create device(...)`` -> ``declare create(...)``."""
        device_vars = parameters.get('device')
        if not device_vars:
            return self.default_retval()
        return Pragma(keyword='acc', content=f'declare create({device_vars})')

    def pmap_update(self, pragma, parameters, **kwargs):
        """``update device(...) host(...)`` -> ``update device(...) self(...)``."""
        clauses = self._render_clauses(parameters, (('device', 'device'), ('host', 'self')))
        if not clauses:
            return self.default_retval()
        return Pragma(keyword='acc', content=f'update{clauses}')

    def pmap_unstructured_data(self, pragma, parameters, **kwargs):
        """``unstructured-data in(...) create(...) attach(...)`` -> ``enter data ...``."""
        clauses = self._render_clauses(
            parameters, (('in', 'copyin'), ('create', 'create'), ('attach', 'attach'))
        )
        if not clauses:
            return self.default_retval()
        return Pragma(keyword='acc', content=f'enter data{clauses}')

    def pmap_exit_unstructured_data(self, pragma, parameters, **kwargs):
        """``exit-unstructured-data out(...) delete(...) detach(...) [finalize]`` -> ``exit data ...``."""
        clauses = self._render_clauses(
            parameters, (('out', 'copyout'), ('delete', 'delete'), ('detach', 'detach'))
        )
        if not clauses:
            return self.default_retval()
        # Rather than simply decrementing the dynamic reference counter,
        # finalize forces it to zero. This isn't needed for OpenMP, where
        # target exit data map(delete:<>) statement already sets the
        # dynamic reference counter to 0
        finalize = ' finalize' if 'finalize' in parameters else ''
        return Pragma(keyword='acc', content=f'exit data{clauses}{finalize}')

    def pmap_structured_data(self, pragma, parameters, **kwargs):
        """``structured-data ...`` -> ``data copyin/copy/copyout/create/default/present/async``."""
        clauses = self._render_clauses(parameters, (
            ('in', 'copyin'),
            ('inout', 'copy'),
            ('out', 'copyout'),
            ('create', 'create'),
            ('default', 'default'),
            ('present', 'present'),
            ('async', 'async'),
        ))
        if not clauses:
            return self.default_retval()
        return Pragma(keyword='acc', content=f'data{clauses}')

    def pmap_end_structured_data(self, pragma, parameters, **kwargs):
        """``end structured-data`` -> ``end data``."""
        return Pragma(keyword='acc', content='end data')

    def pmap_routine(self, pragma, parameters, **kwargs):
        """``routine seq`` / ``routine vector`` (``seq`` takes precedence)."""
        for level in ('seq', 'vector'):
            if level in parameters:
                return Pragma(keyword='acc', content=f'routine {level}')
        return self.default_retval()

    def pmap_loop(self, pragma, parameters, **kwargs):
        """``loop seq``, ``loop vector ...`` or ``loop gang ...`` (checked in this order)."""
        if 'seq' in parameters:
            return Pragma(keyword='acc', content='loop seq')
        if 'vector' in parameters:
            clauses = self._render_clauses(parameters, (
                ('private', 'private'),
                ('firstprivate', 'firstprivate'),
                ('reduction', 'reduction'),
            ))
            return Pragma(keyword='acc', content=f'loop vector{clauses}')
        if 'gang' in parameters:
            clauses = self._render_clauses(parameters, (
                ('private', 'private'),
                ('firstprivate', 'firstprivate'),
                ('vlength', 'vector_length'),
                ('async', 'async'),
            ))
            return Pragma(keyword='acc', content=f'parallel loop gang{clauses}')
        return self.default_retval()

    def pmap_end_loop(self, pragma, parameters, **kwargs):
        """``end loop gang`` -> ``end parallel loop``."""
        if 'gang' not in parameters:
            return self.default_retval()
        return Pragma(keyword='acc', content='end parallel loop')

    def pmap_device_present(self, pragma, parameters, **kwargs):
        """``device-present vars(...) [async(...)]`` -> ``data present(...) [async(...)]``."""
        present_vars = parameters.get('vars')
        if not present_vars:
            return self.default_retval()
        async_clause = self._render_clauses(parameters, (('async', 'async'),))
        return Pragma(keyword='acc', content=f'data present({present_vars})' + async_clause)

    def pmap_end_device_present(self, pragma, parameters, **kwargs):
        """``end device-present`` -> ``end data``."""
        return Pragma(keyword='acc', content='end data')

    def pmap_device_ptr(self, pragma, parameters, **kwargs):
        """``device-ptr vars(...) [async(...)]`` -> ``data deviceptr(...) [async(...)]``."""
        ptr_vars = parameters.get('vars')
        if not ptr_vars:
            return self.default_retval()
        async_clause = self._render_clauses(parameters, (('async', 'async'),))
        return Pragma(keyword='acc', content=f'data deviceptr({ptr_vars})' + async_clause)

    def pmap_end_device_ptr(self, pragma, parameters, **kwargs):
        """``end device-ptr`` -> ``end data``."""
        return Pragma(keyword='acc', content='end data')


class OpenMPOffloadPragmaMapper(GenericPragmaMapper):
    """
    Loki generic pragmas to OpenMP offload/GPU mapper.

    Each handler returns a new ``omp`` :any:`Pragma`, or :meth:`default_retval`
    if the Loki pragma carries nothing the handler can translate.

    TODO: this is not yet complete!
    """
    # pylint: disable=unused-argument
    def pmap_create(self, pragma, parameters, **kwargs):
        """``create device(...)`` -> ``declare target(...)``."""
        if param_device := parameters.get('device'):
            return Pragma(keyword='omp', content=f'declare target({param_device})')
        return self.default_retval()

    def pmap_update(self, pragma, parameters, **kwargs):
        """``update device(...) host(...)`` -> ``target update to(...) from(...)``."""
        content = ''
        if param_device := parameters.get('device'):
            content += f' to({param_device})'
        if param_host := parameters.get('host'):
            content += f' from({param_host})'
        if content:
            return Pragma(keyword='omp', content=f'target update{content}')
        return self.default_retval()

    def pmap_unstructured_data(self, pragma, parameters, **kwargs):
        """``unstructured-data in(...) create(...)`` -> ``target enter data map(...)``."""
        content = ''
        if param_in := parameters.get('in'):
            content += f' map(to: {param_in})'
        if param_create := parameters.get('create'):
            content += f' map(alloc: {param_create})'
        if content:
            return Pragma(keyword='omp', content=f'target enter data{content}')
        return self.default_retval()

    def pmap_exit_unstructured_data(self, pragma, parameters, **kwargs):
        """``exit-unstructured-data out(...) delete(...)`` -> ``target exit data map(...)``."""
        content = ''
        if params_out := parameters.get('out'):
            content += f' map(from: {params_out})'
        if params_delete := parameters.get('delete'):
            content += f' map(delete: {params_delete})'
        if content:
            return Pragma(keyword='omp', content=f'target exit data{content}')
        return self.default_retval()

    def pmap_structured_data(self, pragma, parameters, **kwargs):
        """``structured-data ...`` -> ``target data map(...)``."""
        # Both 'in'/'copyin' and 'present' map to 'map(to: ...)' and are merged into one
        # clause. Empty values are skipped (as for every other clause) instead of
        # producing an invalid ``map(to: )`` or ``map(to: , x)``.
        to_vars = [str(p) for p in (parameters.get('in'), parameters.get('present')) if p]
        content = f' map(to: {", ".join(to_vars)})' if to_vars else ''
        if params_inout := parameters.get('inout'):
            content += f' map(tofrom: {params_inout})'
        if params_out := parameters.get('out'):
            content += f' map(from: {params_out})'
        if params_create := parameters.get('create'):
            content += f' map(alloc: {params_create})'
        if content:
            return Pragma(keyword='omp', content=f'target data{content}')
        return self.default_retval()

    def pmap_end_structured_data(self, pragma, parameters, **kwargs):
        """``end structured-data`` -> ``end target data``."""
        return Pragma(keyword='omp', content='end target data')

    def pmap_routine(self, pragma, parameters, **kwargs):
        """``routine seq`` -> ``declare target``."""
        if 'seq' in parameters:
            return Pragma(keyword='omp', content='declare target')
        return self.default_retval()

    def pmap_loop(self, pragma, parameters, **kwargs):
        """``loop vector`` -> ``parallel do``; ``loop gang`` -> ``target teams distribute``."""
        if 'vector' in parameters:
            # TODO: private and reduction clause?
            content = 'parallel do'
            return Pragma(keyword='omp', content=content)
        if 'gang' in parameters:
            # TODO: private clause?
            vlength_param = parameters.get('vlength')
            vlength = f' thread_limit({vlength_param})' if vlength_param else ''
            content = f'target teams distribute{vlength}'
            return Pragma(keyword='omp', content=content)
        return self.default_retval()

    def pmap_end_loop(self, pragma, parameters, **kwargs):
        """Close the construct opened by :meth:`pmap_loop`."""
        if 'vector' in parameters:
            return Pragma(keyword='omp', content='end parallel do')
        if 'gang' in parameters:
            return Pragma(keyword='omp', content='end target teams distribute')
        return self.default_retval()

    def pmap_omp_update_global_vars(self, pragma, parameters, **kwargs):
        """``omp-update-global-vars in(...)`` -> ``target enter data map(to: ...)``."""
        # this shouldn't be necessary but is currently necessary because of a bug in OpenMP
        if params_in := parameters.get('in'):
            return Pragma(keyword='omp', content=f'target enter data map(to: {params_in})')
        return self.default_retval()



class OpenMPThreadingPragmaMapper(GenericPragmaMapper):
    """
    Loki generic pragmas to OpenMP CPU mapper.

    TODO: this is obviously incomplete!
    """
    # pylint: disable=unused-argument
    def pmap_loop(self, pragma, parameters, **kwargs):
        """
        ``loop gang [default(...)] [private(...)] [firstprivate(...)]`` ->
        ``parallel do`` with the same clauses, in that order.
        """
        if 'gang' not in parameters:
            return self.default_retval()
        # Each clause carries its own leading space, so no separator follows
        # 'parallel do' (previously this produced double/trailing spaces).
        clauses = ''
        for clause in ('default', 'private', 'firstprivate'):
            if value := parameters.get(clause):
                clauses += f' {clause}({value})'
        return Pragma(keyword='omp', content=f'parallel do{clauses}')


class PragmaModelTransformation(Transformation):
    """
    Transformation to map Loki generic pragmas to a specific
    pragma model using a child class of :any:`GenericPragmaMapper`.

    For the mapping between Loki directives and programming model-specific annotations,
    see :ref:`programming_models:Loki directives`.

    Parameters
    ----------
    directive : False, str
        The directive(s) to be used, used to determine which
        child class of :any:`GenericPragmaMapper` is used. One of
        ``'openacc'``, ``'omp-gpu'``, ``'openmp'``, or ``False`` to
        suppress the directive translation entirely.
    keep_loki_pragmas: bool
        Keep (``True``) or remove (``False``) generic Loki pragmas that are
        not mapped. Note that ``directive=False`` combined with
        ``keep_loki_pragmas=False`` removes all Loki pragmas.
    """
    item_filter = (ProcedureItem, ModuleItem)

    def __init__(self, directive=False, keep_loki_pragmas=True):
        assert directive in [False, 'openacc', 'omp-gpu', 'openmp'], (
            f'Invalid directive {directive!r}; expected one of False, "openacc", "omp-gpu", "openmp"'
        )
        self.directive = directive
        self.keep_loki_pragmas = keep_loki_pragmas
        pmapper_cls_map = {
            'openacc': OpenACCPragmaMapper,
            'omp-gpu': OpenMPOffloadPragmaMapper,
            'openmp': OpenMPThreadingPragmaMapper,
        }
        # Without a model-specific mapper, a handler-less GenericPragmaMapper is used
        # only to remove Loki pragmas; if they are kept, there is nothing to do at all.
        pmapper_cls = pmapper_cls_map.get(self.directive, None if self.keep_loki_pragmas else GenericPragmaMapper)
        self.pmapper = pmapper_cls() if pmapper_cls else None

    def _create_pragma_map(self, pragmas):
        """
        Build the replacement map ``{pragma: new_pragma_or_None}`` for ``pragmas``.

        Unmapped pragmas are left out of the map if ``keep_loki_pragmas`` is
        set, otherwise they are mapped to ``None`` so that they are removed.
        """
        pragma_map = {}
        for pragma in pragmas:
            new_pragma = self.pmapper.pmap(pragma)
            # either keep loki pragmas that do not have a mapping
            if self.keep_loki_pragmas:
                if new_pragma is not None:
                    pragma_map[pragma] = new_pragma
            # or remove, since pmap returns None
            else:
                pragma_map[pragma] = new_pragma
        return pragma_map

    def transform_module(self, module, **kwargs):
        """Map the Loki pragmas in the spec of ``module``."""
        if self.pmapper is None:
            return
        loki_pragmas = [pragma for pragma in FindNodes(Pragma).visit(module.spec) if pragma.keyword.lower() == 'loki']
        pragma_map = self._create_pragma_map(loki_pragmas)

        # Avoid rebuilding the IR when nothing changes
        if pragma_map:
            module.spec = Transformer(pragma_map).visit(module.spec)

    def transform_subroutine(self, routine, **kwargs):
        """Map the Loki pragmas in the spec and body of ``routine``."""
        if self.pmapper is None:
            return
        loki_pragmas = [pragma for pragma in FindNodes(Pragma).visit(routine.ir) if pragma.keyword.lower() == 'loki']
        pragma_map = self._create_pragma_map(loki_pragmas)

        # Avoid rebuilding the IR when nothing changes
        if pragma_map:
            routine.spec = Transformer(pragma_map).visit(routine.spec)
            routine.body = Transformer(pragma_map).visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/remove_code.py0000664000175000017500000005526315167130205023004 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utilities to automatically remove code elements or
section and to perform Dead Code Elimination.
"""

import operator as op

from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation
from loki.expression import simplify, symbols as sym, symbolic_op
from loki.ir import nodes as ir, Transformer, FindNodes, FindVariables
from loki.ir.pragma_utils import (
    is_loki_pragma, pragma_regions_attached, get_pragma_parameters
)
from loki.program_unit import ProgramUnit
from loki.tools import flatten, as_tuple, OrderedSet
from loki.types import BasicType


__all__ = [
    'RemoveCodeTransformation',
    'do_remove_dead_code', 'RemoveDeadCodeTransformer',
    'do_remove_marked_regions', 'RemoveRegionTransformer',
    'do_remove_calls', 'RemoveCallsTransformer', 'do_remove_unused_vars'
]


class RemoveCodeTransformation(Transformation):
    """
    A :any:`Transformation` that provides named call and import
    removal, code removal of pragma-marked regions and Dead Code
    Elimination for batch processing via the :any:`Scheduler`.

    The transformation will apply the following methods in order:

    * :any:`do_remove_calls`
    * :any:`do_remove_marked_regions`
    * :any:`do_remove_dead_code`
    * :any:`do_remove_unused_vars`

    Independently of ``kernel_only``, if ``remove_unused_args`` is set and
    an ``item`` is provided, unused arguments are removed from calls to
    successor routines (using the information those successors recorded in
    their ``trafo_data``). For routines with role ``'kernel'`` (unless
    disabled via the item's ``remove_unused_args`` config entry), the
    routine's own unused dummy arguments are then removed and recorded in
    ``item.trafo_data`` for its callers.

    Parameters
    ----------
    remove_marked_regions : boolean
        Flag to trigger the use of :meth:`remove_marked_regions`;
        default: ``True``
    mark_with_comment : boolean
        Flag to trigger the insertion of a marker comment when
        removing a region; default: ``True``.
    replacement_call : optional, str
        Name of the "abort" subroutine to call if a replacement call
        is to be inserted in :meth:`do_remove_marked_regions`.
    replacement_msg : optional, str
        Optional error message that will be passed as argument to
        the replacement call in :meth:`do_remove_marked_regions`.
    replacement_module : optional, str
        Optional name of the module from which to import the
        replacement subroutine in :meth:`do_remove_marked_regions`.
    remove_dead_code : boolean
        Flag to trigger the use of :meth:`remove_dead_code`;
        default: ``False``
    use_simplify : boolean
        Use :any:`simplify` when branch pruning during
        :meth:`remove_dead_code`; default: ``True``
    call_names : list of str
        List of subroutine names against which to match
        :any:`CallStatement` nodes during :meth:`remove_calls`.
    intrinsic_names : list of str
        List of module names against which to match :any:`Intrinsic`
        nodes during :meth:`remove_calls`.
    remove_imports : boolean
        Flag indicating whether to remove symbols from :any:`Import`
        objects during :meth:`remove_calls`; default: ``True``
    kernel_only : boolean
        Only apply the configured removal to subroutines marked as
        "kernel"; default: ``False``
    remove_unused_args : boolean
        Remove unused dummy arguments from kernel routines and the
        matching arguments from calls to them; default: ``False``
    remove_unused_vars : boolean
        Remove unused variables/locals from routines; default: ``False``
    remove_only_arrays : boolean
        Whether to only remove unused arrays from routines
        or all variables/locals; default: ``True``
    """

    _key = 'RemoveCodeTransformation'

    # Recurse to subroutines in ``contains`` clause
    recurse_to_internal_procedures = True
    # Traverse in reverse order so that, presumably, successors (callees) are
    # processed before their callers and their unused-args data is available
    reverse_traversal = True

    def __init__(
            self, remove_marked_regions=True, mark_with_comment=True,
            replacement_call=None, replacement_msg=None, replacement_module=None,
            remove_dead_code=False, use_simplify=True,
            call_names=None, intrinsic_names=None,
            remove_imports=True, kernel_only=False,
            remove_unused_args=False, remove_unused_vars=False,
            remove_only_arrays=True
    ):
        self.remove_marked_regions = remove_marked_regions
        self.mark_with_comment = mark_with_comment
        self.replacement_call = replacement_call
        self.replacement_msg = replacement_msg
        self.replacement_module = replacement_module

        self.remove_dead_code = remove_dead_code
        self.use_simplify = use_simplify

        self.call_names = as_tuple(call_names)
        self.intrinsic_names = as_tuple(intrinsic_names)
        self.remove_imports = remove_imports

        self.kernel_only = kernel_only
        self.remove_unused_args = remove_unused_args

        self.remove_unused_vars = remove_unused_vars
        self.remove_only_arrays = remove_only_arrays

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the configured removal steps to ``routine``.

        Scheduler-provided keyword arguments (``role``, ``item``,
        ``sub_sgraph``) are optional; when they are missing, the steps
        that depend on them are skipped instead of raising ``KeyError``.
        """
        role = kwargs.get('role')

        if role == 'kernel' or not self.kernel_only:
            # Apply named node removal to strip specific calls
            if self.call_names or self.intrinsic_names:
                do_remove_calls(
                    routine, call_names=self.call_names,
                    intrinsic_names=self.intrinsic_names,
                    remove_imports=self.remove_imports
                )

            # Apply marked region removal
            if self.remove_marked_regions:
                do_remove_marked_regions(
                    routine, mark_with_comment=self.mark_with_comment,
                    replacement_call=self.replacement_call,
                    replacement_msg=self.replacement_msg,
                    replacement_module=self.replacement_module
                )

            # Apply Dead Code Elimination
            if self.remove_dead_code:
                do_remove_dead_code(routine, use_simplify=self.use_simplify)

            if self.remove_unused_vars:
                do_remove_unused_vars(routine, remove_only_arrays=self.remove_only_arrays)

        item = kwargs.get('item')
        if self.remove_unused_args and item:
            # Collect the unused args recorded by successors and prune them from calls
            sub_sgraph = kwargs.get('sub_sgraph')
            successors = sub_sgraph.successors(item=item) if sub_sgraph is not None else ()
            unused_args_map = {
                successor.ir: successor.trafo_data.get(self._key, {}).get('unused_args', {})
                for successor in successors
            }
            do_remove_unused_call_args(routine, unused_args_map)

            if item.config.get('remove_unused_args', True) and role == 'kernel':
                # Find and remove this routine's own unused dummy args ...
                unused_args, _ = find_unused_dummy_args_and_vars(routine)
                do_remove_unused_dummy_args(routine, unused_args)
                # ... and store them so that callers can prune their call sites
                item.trafo_data[self._key] = {'unused_args': unused_args}


def do_remove_unused_dummy_args(routine, unused_args):
    """
    Drop unused dummy arguments from the variables of ``routine``.

    Every entry of ``routine.variables`` whose lower-cased name is found
    in ``unused_args`` is removed; all other variables are kept in order.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine whose unused dummy arguments will be removed.
    unused_args : dict
       A dict mapping each unused dummy argument symbol to its position in
       the routine's argument list. This must be retrieved using the
       :any:`find_unused_dummy_args_and_vars` utility.
    """
    kept = []
    for variable in routine.variables:
        if variable.name.lower() in unused_args:
            continue
        kept.append(variable)
    routine.variables = kept

def do_remove_unused_vars(routine, unused_vars=None, remove_only_arrays=True):
    """
    Utility routine to remove unused variables (or only local arrays) from a given routine.

    Parameters
    ----------
    routine : :any:`Subroutine`
        A :any:`Subroutine` whose unused variables will be removed.
    unused_vars : list, optional
        A list of unused variables (symbols or plain names). If not given,
        it is computed with the :any:`find_unused_dummy_args_and_vars`
        utility.
    remove_only_arrays : bool, optional
        Whether to only remove arrays or all variables/temporaries
        that are unused within the routine; default: ``True``.
    """
    if unused_vars is None:
        _, unused_vars = find_unused_dummy_args_and_vars(routine)
    if remove_only_arrays:
        unused_vars = [var for var in unused_vars if isinstance(var, sym.Array)]

    # Fortran names are case-insensitive: match on lower-cased names via a set,
    # accepting both symbols (``.name``) and plain strings
    unused_names = {str(getattr(var, 'name', var)).lower() for var in unused_vars}
    routine.variables = [var for var in routine.variables
                         if var.name.lower() not in unused_names]

def do_remove_unused_call_args(routine, unused_args_map):
    """
    Utility routine to remove unused arguments from all the
    :any:`CallStatement`s in a given routine.

    Positional arguments are removed by their position in the call (not by
    value), so the same expression passed to both a used and an unused
    dummy argument is only removed from the unused position. Keyword
    arguments are removed when their (lower-cased) keyword names an unused
    dummy argument.

    Parameters
    ----------
    routine : :any:`Subroutine`
        A :any:`Subroutine` whose call statements will be updated.
    unused_args_map : dict
       A dict mapping the :any:`Subroutine` corresponding to the :any:`CallStatement`,
       accessed via the :any:`CallStatement`.routine property, to its unused arguments.
       The unused arguments must be retrieved using the
       :any:`find_unused_dummy_args_and_vars` utility.
    """

    for call in FindNodes(ir.CallStatement).visit(routine.body):
        if call.routine is BasicType.DEFERRED:
            continue
        unused_args = unused_args_map.get(call.routine, None)
        if not unused_args:
            continue

        # Positions of the unused dummy arguments in the callee's signature
        unused_positions = set(unused_args.values())
        new_args = tuple(
            arg for idx, arg in enumerate(call.arguments) if idx not in unused_positions
        )
        new_kwargs = tuple(
            (kw, arg) for kw, arg in call.kwarguments if kw.lower() not in unused_args
        )

        call._update(arguments=new_args, kwarguments=new_kwargs)


def find_unused_dummy_args_and_vars(routine):
    """
    Utility routine to find all unused dummy arguments and unused local
    variables in a :any:`Subroutine`.

    A symbol counts as used if dataflow analysis reports it as used or
    defined in ``routine.body``, if it appears in the shape of such an
    array, or if it is the declared symbol looked up via the first
    component (``name_parts[0]``) of such a symbol's name.

    Parameters
    ----------
    routine : :any:`Subroutine`
        A :any:`Subroutine` to search for unused dummy arguments and variables.

    Returns
    -------
    unused_args : dict
       A dict mapping each unused dummy argument symbol (cloned without
       dimensions) to its position in the routine's argument list.
    unused_vars : list
       The unused local (non-argument) variables, cloned without dimensions.
    """

    variable_map = routine.symbol_map
    with dataflow_analysis_attached(routine):
        used_or_defined_symbols = routine.body.uses_symbols | routine.body.defines_symbols

        # Symbols used to define the shapes of used/defined arrays also count as used
        used_or_defined_array_shapes = [s.shape for s in used_or_defined_symbols if isinstance(s, sym.Array)]
        used_or_defined_symbols |= FindVariables(unique=True).visit(used_or_defined_array_shapes)

        # Add the declared symbol for the first name component of each symbol,
        # presumably so that e.g. a derived-type member access marks its parent as used
        used_or_defined_symbols |= OrderedSet(variable_map.get(v.name_parts[0], v).clone(dimensions=None)
                                              for v in used_or_defined_symbols)

        unused_args = {a.clone(dimensions=None): c for c, a in enumerate(routine.arguments)
                       if a.name.lower() not in used_or_defined_symbols}

        # Locals are all declared variables that are not dummy arguments
        routine_arg_names = {arg.name.lower() for arg in routine.arguments}
        local_vars = [var for var in routine.variables if var.name.lower() not in routine_arg_names]
        unused_vars = [var.clone(dimensions=None) for var in local_vars
                       if var.name.lower() not in used_or_defined_symbols]

    return unused_args, unused_vars


def do_remove_dead_code(routine, use_simplify=True):
    """
    Perform Dead Code Elimination on the given :any:`Subroutine` object.

    Provably unreachable branches of :any:`Conditional` and
    :any:`MultiConditional` nodes in ``routine.body`` are pruned via
    :any:`RemoveDeadCodeTransformer`.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine to which to apply dead code elimination.
    use_simplify : boolean
        Use :any:`simplify` when evaluating expressions for branch pruning;
        default: ``True``.
    """

    transformer = RemoveDeadCodeTransformer(use_simplify=use_simplify)
    routine.body = transformer.visit(routine.body)


class RemoveDeadCodeTransformer(Transformer):
    """
    :any:`Transformer` that prunes code paths which can be proven unreachable.

    Branches of :any:`Conditional` nodes whose (optionally simplified)
    condition compares equal to ``'True'`` or ``'False'`` are replaced by
    the surviving branch. :any:`MultiConditional` nodes collapse to the
    first case body whose value symbolically equals the select expression.

    Parameters
    ----------
    use_simplify : boolean
        Use :any:`simplify` when evaluating expressions for branch pruning.
    """

    def __init__(self, use_simplify=True, **kwargs):
        super().__init__(**kwargs)
        self.use_simplify = use_simplify

    def visit_Conditional(self, o, **kwargs):
        cond = self.visit(o.condition, **kwargs)
        then_branch = as_tuple(flatten(as_tuple(self.visit(o.body, **kwargs))))
        else_branch = as_tuple(flatten(as_tuple(self.visit(o.else_body, **kwargs))))

        if self.use_simplify:
            cond = simplify(cond)

        # Statically decidable condition: keep only the reachable branch
        if cond == 'True':
            return then_branch
        if cond == 'False':
            return else_branch

        # Only keep the ELSE IF form if the else branch still starts with a conditional
        has_elseif = o.has_elseif and else_branch and isinstance(else_branch[0], ir.Conditional)
        return self._rebuild(o, (cond, then_branch, else_branch), has_elseif=has_elseif)

    def visit_MultiConditional(self, o, **kwargs):
        selector = self.visit(o.expr, **kwargs)
        if self.use_simplify:
            selector = simplify(selector)

        values = self.visit(o.values, **kwargs)
        bodies = self.visit(o.bodies, **kwargs)
        else_body = self.visit(o.else_body, **kwargs)

        # A case whose value provably equals the selector is the only one taken
        for case_values, case_body in zip(values, bodies):
            if any(symbolic_op(selector, op.eq, v) for v in case_values):
                return case_body

        if selector == 'False':
            # Simplify to default if always false
            return else_body

        return self._rebuild(o, (selector, values, bodies, else_body), name=o.name)


def do_remove_marked_regions(
        routine, mark_with_comment=True, replacement_call=None,
        replacement_msg=None, replacement_module=None
):
    """
    Utility routine to remove code regions marked with
    ``!$loki remove`` pragmas from a subroutine's body.

    Optionally, any removed region might be marked with a
    comment and/or a simple single-argument "abort" call. For this,
    a subroutine name and message can be specified and an optional
    check for an import can also be defined to ensure the interface
    for the abort procedure is available. To bypass the replacement
    call insertion for individual pragma regions use
    ``!$loki remove no-replacement-call``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine from which to remove marked regions.
    mark_with_comment : boolean
        Flag to trigger the insertion of a marker comment when
        removing a region; default: ``True``.
    replacement_call : optional, str
        Name of the "abort" subroutine to call if a replacement call
        is to be inserted.
    replacement_msg : optional, str
        Optional error message that will be passed as a single
        argument to the replacement call.
    replacement_module : optional, str
        Optional name of the module from which to import the
        replacement subroutine. This will only be inserted if a
        replacement was performed and will not replace existing imports
        of the same module or symbols. Module names are matched
        case-insensitively, and an existing import of the whole module
        (no symbol list) is left untouched.
    """

    transformer = RemoveRegionTransformer(
        mark_with_comment=mark_with_comment,
        replacement_call=replacement_call,
        replacement_msg=replacement_msg,
    )

    with pragma_regions_attached(routine, keyword='loki'):
        routine.body = transformer.visit(routine.body, scope=routine)

    if transformer.replacement_done and replacement_module:
        # Get newly injected procedure symbol for the replacement call
        callsym = sym.ProcedureSymbol(replacement_call, scope=routine)

        # Fortran module names are case-insensitive
        import_map = {str(i.module).lower(): i for i in routine.imports}
        imprt = import_map.get(str(replacement_module).lower())
        if imprt is None:
            # Inject import of replacement module if it does not exist
            routine.spec.prepend(ir.Import(
                module=f'{replacement_module}', symbols=(callsym,), c_import=False)
            )
        elif imprt.symbols and not any(s == replacement_call for s in imprt.symbols):
            # Only extend an explicit symbol (ONLY) list; an import without a
            # symbol list already provides the whole module, and adding a
            # symbol would restrict it to that single symbol
            imprt._update(symbols=imprt.symbols + (callsym,))


class RemoveRegionTransformer(Transformer):
    """
    A :any:`Transformer` that removes code regions marked with
    ``!$loki remove`` pragmas.

    This :any:`Transformer` only removes :any:`PragmaRegion` nodes,
    and thus requires the IR tree to have pragma regions attached, for
    example via :meth:`pragma_regions_attached`. It expects a ``scope``
    keyword argument to be passed to :meth:`visit` when replacement
    calls are requested.

    When removing a marked code region the transformer may leave a
    comment or a replacement call to trigger "abort" errors in the
    source to mark the previous location.

    Parameters
    ----------
    mark_with_comment : boolean
        Flag to trigger the insertion of a marker comment when
        removing a region; default: ``True``.
    replacement_call : optional, str
        Name of the "abort" subroutine to call if a replacement call
        is to be inserted.
    replacement_msg : optional, str
        Optional error message that will be passed as a single
        argument to the replacement call. It is formatted with the name
        of the enclosing program unit (``{}`` placeholder). If ``None``,
        the replacement call is inserted without arguments.
    """

    def __init__(
            self, mark_with_comment=True, replacement_call=None, replacement_msg=None, **kwargs
    ):
        super().__init__(**kwargs)

        self.mark_with_comment = mark_with_comment

        # Replace section with call to trigger abort messages!
        self.replacement_call = replacement_call
        self.replacement_msg = replacement_msg
        # Set once a replacement call has been inserted (triggers import injection)
        self.replacement_done = False

    def visit_PragmaRegion(self, o, **kwargs):
        """ Remove :any:`PragmaRegion` nodes with ``!$loki remove`` pragmas """

        if is_loki_pragma(o.pragma, starts_with='remove'):
            # Skip the replacement call if the bypass clause is present
            bypass = 'no-replacement-call' in get_pragma_parameters(o.pragma, starts_with='remove')

            # Leave a comment to mark the removed region in source
            replacement = []
            if self.mark_with_comment:
                replacement.append(ir.Comment(text='! [Loki] Removed content of pragma-marked region!'))

            if self.replacement_call and not bypass:
                # Get the outer scope, to avoid picking associates
                routine = kwargs['scope']
                while not isinstance(routine, ProgramUnit):
                    routine = routine.parent

                # The message is optional: without one, call the routine without arguments
                if self.replacement_msg is None:
                    arguments = ()
                else:
                    arguments = (sym.Literal(str(self.replacement_msg.format(routine.name))),)

                # If requested add a call to a simple subroutine with an error message arg
                replacement.append(ir.CallStatement(
                    name=sym.ProcedureSymbol(self.replacement_call, scope=routine),
                    arguments=arguments
                ))
                # Set a flag to trigger import injections
                self.replacement_done = True

            return as_tuple(replacement)

        # Recurse into the pragma region and rebuild
        rebuilt = tuple(self.visit(i, **kwargs) for i in o.children)
        return self._rebuild(o, rebuilt)


def do_remove_calls(
        routine, call_names=None, intrinsic_names=None, remove_imports=True
):
    """
    Strip :any:`CallStatement` nodes that call specific named subroutines
    (and, optionally, matching :any:`Intrinsic` nodes and import symbols)
    from a :any:`Subroutine`.

    The removal is delegated to :any:`RemoveCallsTransformer`, which is
    applied to the routine's spec first and then to its body.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine to modify in place.
    call_names : list of str
        List of subroutine names against which to match
        :any:`CallStatement` nodes.
    intrinsic_names : list of str
        List of module names against which to match :any:`Intrinsic`
        nodes.
    remove_imports : boolean
        Flag indicating whether to remove the respective procedure
        symbols from :any:`Import` objects; default: ``True``.
    """
    remover = RemoveCallsTransformer(
        call_names=call_names,
        intrinsic_names=intrinsic_names,
        remove_imports=remove_imports,
    )
    for section in ('spec', 'body'):
        setattr(routine, section, remover.visit(getattr(routine, section)))


class RemoveCallsTransformer(Transformer):
    """
    :any:`Transformer` that deletes :any:`CallStatement` nodes calling any
    of a given set of subroutine names.

    Inline conditionals of the form ``if (flag) call named_procedure()``
    whose body becomes empty are removed as well.

    :any:`Intrinsic` nodes are removed when their text contains (case-
    insensitively) any of the given intrinsic name strings, which allows
    stripping statements such as ``write (*,*) "..."``.

    When enabled, the names of removed procedures are also dropped from
    the symbol lists of :any:`Import` nodes; an import left with no
    symbols is removed entirely.

    Parameters
    ----------
    call_names : list of str
        List of subroutine names against which to match
        :any:`CallStatement` nodes.
    intrinsic_names : list of str
        List of module names against which to match :any:`Intrinsic`
        nodes.
    remove_imports : boolean
        Flag indicating whether to remove the respective procedure
        symbols from :any:`Import` objects; default: ``True``.
    """

    def __init__(
            self, call_names=None, intrinsic_names=None,
            remove_imports=True, **kwargs
    ):
        super().__init__(**kwargs)

        self.call_names = as_tuple(call_names)
        self.intrinsic_names = as_tuple(intrinsic_names)
        self.remove_imports = remove_imports

    def _visit_children_and_rebuild(self, o, **kwargs):
        """ Visit every child of ``o`` and rebuild ``o`` from the results """
        return self._rebuild(o, tuple(self.visit(child, **kwargs) for child in o.children))

    def visit_CallStatement(self, o, **kwargs):
        """ Match and remove :any:`CallStatement` nodes against name patterns """
        if o.name not in self.call_names:
            return self._visit_children_and_rebuild(o, **kwargs)
        return None

    def visit_Conditional(self, o, **kwargs):
        """ Remove inline-conditionals after recursing into their body """
        condition, body, else_body = (self.visit(child, **kwargs) for child in o.children)

        # An inline conditional whose body has been emptied is dropped
        emptied_inline = o.inline and len(body) == 0
        if emptied_inline:
            return None
        return self._rebuild(o, (condition, body, else_body))

    def visit_Intrinsic(self, o, **kwargs):
        """ Match and remove :any:`Intrinsic` nodes against name patterns """
        if self.intrinsic_names:
            text = o.text.lower()
            if any(str(name).lower() in text for name in self.intrinsic_names):
                return None
        return self._visit_children_and_rebuild(o, **kwargs)

    def visit_Import(self, o, **kwargs):
        """ Remove the symbol of any named calls from Import nodes """
        matched = any(s in self.call_names for s in o.symbols)
        if matched and self.remove_imports:
            remaining = tuple(s for s in o.symbols if s not in self.call_names)
            if not remaining:
                return None
            return o.clone(symbols=remaining)
        return self._visit_children_and_rebuild(o, **kwargs)
loki-ecmwf-0.3.6/loki/transformations/utilities.py0000664000175000017500000007743615167130205022536 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utility routines to deal with general language conversion.
"""

import platform
from collections import defaultdict
from pymbolic.primitives import Expression
from loki.expression import (
    symbols as sym, SubstituteExpressionsMapper, ExpressionRetriever,
    TypedSymbol, MetaSymbol
)
from loki.ir import (
    nodes as ir, Import, TypeDef, VariableDeclaration, is_loki_pragma,
    StatementFunction, Transformer, FindNodes, FindVariables,
    FindInlineCalls, FindLiterals, SubstituteExpressions,
    ExpressionFinder
)
from loki.module import Module
from loki.subroutine import Subroutine
from loki.tools import CaseInsensitiveDict, as_tuple, OrderedSet
from loki.types import SymbolAttributes, BasicType, DerivedType, ProcedureType
from loki.config import config_override
from loki.logging import warning
from loki.ir.visitor import Visitor


# Names exported via ``from ... import *``. Helpers defined in this module
# such as ``used_names_from_symbol``, ``eliminate_unused_imports`` and
# ``find_and_eliminate_unused_imports`` are not listed here.
__all__ = [
    'convert_to_lower_case', 'replace_intrinsics', 'rename_variables',
    'sanitise_imports', 'replace_selected_kind',
    'single_variable_declaration', 'recursive_expression_map_update',
    'get_integer_variable', 'get_loop_bounds', 'is_pragma_driver_loop', 'find_driver_loops',
    'get_local_arrays', 'check_routine_sequential', 'substitute_variables_for_definitions'
]


def single_variable_declaration(routine, variables=None, group_by_shape=False):
    """
    Split multi-symbol variable declarations in ``routine.spec``.

    * With ``variables=None`` and ``group_by_shape=False`` every declaration
      is split so that each symbol gets its own declaration, preserving order.
    * With a tuple of ``variables`` (and ``group_by_shape=False``) only the
      listed symbols are moved into single declarations; the remaining
      symbols stay together in a clone of the original declaration, which
      is placed first.
    * With ``group_by_shape=True`` symbols are grouped by their ``shape``
      (``None`` for non-arrays), yielding one declaration per shape. If
      ``variables`` is also given, a second pass then gives those variables
      single declarations.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The subroutine in which to modify the variable declarations
    variables: tuple
        Variables to grant unique/single declaration for
    group_by_shape: bool
        Whether to strictly make unique variable declarations or to only disassemble non-arrays and arrays and among
        arrays, arrays with differing shapes.
    """
    replacements = {}
    for decl in FindNodes(VariableDeclaration).visit(routine.spec):
        if len(decl.symbols) <= 1:
            continue

        if group_by_shape:
            by_shape = defaultdict(list)
            for symbol in decl.symbols:
                by_shape[getattr(symbol, 'shape', None)].append(symbol)
            replacements[decl] = tuple(decl.clone(symbols=as_tuple(group)) for group in by_shape.values())
            continue

        if variables is None:
            selected = list(decl.symbols)
            kept = ()
        else:
            selected = [s for s in decl.symbols if s.name in variables]
            kept = tuple(s for s in decl.symbols if s.name not in variables)

        if not selected:
            continue

        singles = tuple(decl.clone(symbols=(s,)) for s in selected)
        replacements[decl] = ((decl.clone(symbols=kept),) + singles) if kept else singles

    routine.spec = Transformer(replacements).visit(routine.spec)

    # The shape-grouping pass ignores ``variables``; honour them in a second pass
    if variables and group_by_shape:
        single_variable_declaration(routine=routine, variables=variables, group_by_shape=False)


def convert_to_lower_case(routine):
    """
    Convert all variables and symbols in a subroutine to lower case.

    Scalars, arrays and deferred-type symbols that are not flagged
    ``case_sensitive`` are renamed first. Afterwards inline calls are
    renamed, so that their arguments already carry the updated variable
    names; calls whose function is ``case_sensitive`` keep their name.
    Statement function names are lower-cased as well.

    Note, this is intended for conversion to case-sensitive languages.

    TODO: Should be extended to `Module` objects.
    """
    lowerable = (sym.Scalar, sym.Array, sym.DeferredTypeSymbol)

    # Force all variables in a subroutine body to lower-caps
    var_map = {}
    for var in FindVariables(unique=False).visit(routine.ir):
        if isinstance(var, lowerable) and not var.name.islower() and not var.case_sensitive:
            var_map[var] = var.clone(name=var.name.lower())

    # Capture nesting by applying the map to itself before applying it to the routine
    var_map = recursive_expression_map_update(var_map, case_sensitive=True)
    routine.body = SubstituteExpressions(var_map).visit(routine.body)
    routine.spec = SubstituteExpressions(var_map).visit(routine.spec)

    # Only now downcase inline calls, so that the variable updates above are
    # already reflected in their arguments
    call_map = {}
    for call in FindInlineCalls().visit(routine.ir):
        if call.name.islower():
            continue
        new_name = call.name if call.function.case_sensitive else call.name.lower()
        call_map[call] = call.clone(function=call.function.clone(name=new_name))
    for stmt_func in FindNodes(StatementFunction).visit(routine.spec):
        call_map[stmt_func.variable] = stmt_func.variable.clone(name=stmt_func.variable.name.lower())

    call_map = recursive_expression_map_update(call_map, case_sensitive=True)
    routine.spec = SubstituteExpressions(call_map).visit(routine.spec)
    routine.body = SubstituteExpressions(call_map).visit(routine.body)


def replace_intrinsics(routine, function_map=None, symbol_map=None, case_sensitive=False):
    """
    Replace known intrinsic functions and symbols.

    Every inline call whose name appears in ``symbol_map`` is replaced as a
    whole by a variable of the mapped name; every inline call whose name
    appears in ``function_map`` gets its function symbol replaced by a
    procedure symbol of the mapped name.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine object in which to replace intrinsic calls
    function_map : dict[str, str]
        Mapping from function names (:any:`InlineCall` names) to
        their replacement
    symbol_map : dict[str, str]
        Mapping from intrinsic symbol names to their replacement
    case_sensitive : bool
        Match case for name lookups in :data:`function_map` and :data:`symbol_map`
    """
    symbols = symbol_map or {}
    functions = function_map or {}
    if not case_sensitive:
        symbols, functions = CaseInsensitiveDict(symbols), CaseInsensitiveDict(functions)

    replacements = {}
    for call in FindInlineCalls(unique=False).visit(routine.ir):
        name = call.name
        if name in symbols:
            replacements[call] = sym.Variable(name=symbols[name], scope=routine)
        if name in functions:
            replacements[call.function] = sym.ProcedureSymbol(name=functions[name], scope=routine)

    routine.spec = SubstituteExpressions(replacements).visit(routine.spec)
    routine.body = SubstituteExpressions(replacements).visit(routine.body)

def rename_variables(routine, symbol_map=None):
    """
    Rename symbols/variables including (routine) arguments.

    Name lookups are case-insensitive. Entries that map to ``None`` are
    ignored, i.e. the corresponding symbol keeps its name.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine object in which to rename variables.
    symbol_map : dict[str, str], optional
        Mapping from symbol/variable names to their replacement. If not
        given, nothing is renamed.
    """
    # Build the case-insensitive map from ``{}`` when no map is given
    # (``CaseInsensitiveDict(None)`` would fail)
    symbol_map = CaseInsensitiveDict(symbol_map or {})

    # rename arguments if necessary (skipping ``None`` entries, like variables below)
    arguments = ()
    renamed_arguments = ()
    for arg in routine.arguments:
        new_name = symbol_map[arg.name] if arg.name in symbol_map else None
        if new_name is not None:
            arguments += (arg.clone(name=new_name),)
            renamed_arguments += (arg,)
        else:
            arguments += (arg,)
    routine.arguments = arguments

    # remove declarations of the old (renamed) arguments
    var_decls = FindNodes(VariableDeclaration).visit(routine.spec)
    var_decl_map = {}
    for var_decl in var_decls:
        new_symbols = ()
        for symbol in var_decl.symbols:
            if symbol not in renamed_arguments:
                new_symbols += (symbol,)
        if new_symbols:
            var_decl_map[var_decl] = var_decl.clone(symbols=new_symbols)
        else:
            var_decl_map[var_decl] = None
    routine.spec = Transformer(var_decl_map).visit(routine.spec)

    # rename variable declarations and usages
    var_map = {}
    for var in FindVariables(unique=False).visit(routine.ir):
        if var.name in symbol_map:
            new_var = symbol_map[var.name]
            if new_var is not None:
                var_map[var] = var.clone(name=new_var)
    if var_map:
        routine.spec = SubstituteExpressions(var_map).visit(routine.spec)
        routine.body = SubstituteExpressions(var_map).visit(routine.body)

    # remove duplicated variable declarations
    var_decls = FindNodes(VariableDeclaration).visit(routine.spec)
    already_declared = ()
    var_decl_map = {}
    for var_decl in var_decls:
        symbols = ()
        for symbol in var_decl.symbols:
            if symbol not in already_declared:
                symbols += (symbol,)
                already_declared += (symbol,)
        if symbols:
            if symbols != var_decl.symbols:
                var_decl_map[var_decl] = var_decl.clone(symbols=symbols)
        else:
            var_decl_map[var_decl] = None
    if var_decl_map:
        routine.spec = Transformer(var_decl_map).visit(routine.spec)

    # update symbol table - remove entries under the previous name
    var_map_names = [key.name.lower() for key in var_map]
    delete = [key for key in routine.symbol_attrs if key.lower() in var_map_names
              or key.split('%')[0].lower() in var_map_names]  # derived types
    for key in delete:
        del routine.symbol_attrs[key]

def used_names_from_symbol(symbol, modifier=str.lower):
    """
    Helper routine that yields the symbol names for the different types of symbols
    we may encounter.

    Parameters
    ----------
    symbol : str or :any:`TypedSymbol` or :any:`MetaSymbol` or :any:`SymbolAttributes`
             or :any:`DerivedType` or :any:`ProcedureType`
        The symbol (or type information) whose names should be collected.
    modifier : callable, optional
        Function applied to every collected name (default: ``str.lower``).

    Returns
    -------
    :any:`OrderedSet`
        The (modified) names used by :data:`symbol`. For typed/meta symbols this
        is the symbol's own name plus the names referenced by its type (the kind
        of an intrinsic type, or the name of a derived/procedure type).
        Unrecognised inputs (including ``None``) yield an empty set.
    """
    if isinstance(symbol, str):
        return OrderedSet([modifier(symbol)])

    if isinstance(symbol, (sym.TypedSymbol, sym.MetaSymbol)):
        return OrderedSet([modifier(symbol.name)]) | used_names_from_symbol(symbol.type, modifier=modifier)

    if isinstance(symbol, SymbolAttributes):
        if isinstance(symbol.dtype, BasicType) and symbol.kind is not None:
            # Intrinsic types only contribute their kind parameter. Return an
            # OrderedSet like every other branch, so callers can rely on
            # OrderedSet semantics (e.g. ``OrderedSet.union``, deterministic order).
            return OrderedSet([modifier(str(symbol.kind))])
        return used_names_from_symbol(symbol.dtype, modifier=modifier)

    if isinstance(symbol, (DerivedType, ProcedureType)):
        return OrderedSet([modifier(symbol.name)])

    return OrderedSet()


def eliminate_unused_imports(module_or_routine, used_symbols):
    """
    Eliminate any imported symbols (or imports altogether) that are not
    in the set of used symbols.

    Only imports with an explicit symbol list are touched. An import whose
    symbol list becomes empty is removed; an import that loses some symbols
    is replaced by a clone with the reduced list. Imports without a symbol
    list (``symbols`` is ``None`` or empty) are left as they are, since there
    is nothing in them that could be removed.

    Parameters
    ----------
    module_or_routine : :any:`Module` or :any:`Subroutine`
        The program unit whose spec is updated in-place.
    used_symbols : set of str
        Lower-cased names of all symbols that are used.
    """
    imports = FindNodes(Import).visit(module_or_routine.spec)
    imported_symbols = [s for im in imports for s in im.symbols or []]

    redundant_symbols = {s for s in imported_symbols if s.name.lower() not in used_symbols}

    if redundant_symbols:
        imprt_map = {}
        for im in imports:
            if not im.symbols:
                # No explicit symbol list: nothing to eliminate here. Previously an
                # empty tuple was mistaken for "all symbols removed" and the whole
                # import was deleted.
                continue
            symbols = tuple(s for s in im.symbols if s not in redundant_symbols)
            if not symbols:
                # Symbol list is empty: Remove the import
                imprt_map[im] = None
            elif len(symbols) < len(im.symbols):
                # Symbol list is shorter than before: We need to replace that import
                imprt_map[im] = im.clone(symbols=symbols)
        module_or_routine.spec = Transformer(imprt_map).visit(module_or_routine.spec)


def find_and_eliminate_unused_imports(routine):
    """
    Find all unused imported symbols and eliminate them from their import statements
    in the given routine and all contained members.
    Empty import statements are removed.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine to sanitise in-place (including its members).

    Returns
    -------
    :any:`OrderedSet`
        The accumulated lower-cased names of all symbols used in :data:`routine`
        and its members.
    """
    # We need a custom expression retriever that does not return symbols used in Imports
    class SymbolRetriever(ExpressionFinder):

        retriever = ExpressionRetriever(lambda e: isinstance(e, (TypedSymbol, MetaSymbol)))

        def visit_Import(self, o, **kwargs):  # pylint: disable=unused-argument
            return ()

    def _names_of(symbols):
        # Accumulate from an empty set: ``OrderedSet.union(*[])`` raises a TypeError
        # (missing ``self``) for routines or typedefs that contain no symbols at all
        names = OrderedSet()
        for s in symbols:
            names |= used_names_from_symbol(s)
        return names

    # Find all used symbols
    used_symbols = _names_of(SymbolRetriever().visit([routine.spec, routine.body]))
    used_symbols |= _names_of(routine.variables)
    for typedef in FindNodes(TypeDef).visit(routine.spec):
        used_symbols |= _names_of(typedef.variables)

    # Recurse for contained subroutines/functions
    for member in routine.members:
        used_symbols |= find_and_eliminate_unused_imports(member)

    eliminate_unused_imports(routine, used_symbols)
    return used_symbols


def sanitise_imports(module_or_routine):
    """
    Remove unused symbols from import statements and drop imports whose
    symbol list ends up empty.

    For a :class:`Subroutine`, the routine and its members are processed.
    For a :class:`Module`, every subroutine of the module is processed first,
    and the module's own imports are then reduced to the symbols used by
    those subroutines. Any other input is ignored.

    Note that this is currently limited to imports that are identified to be :class:`Scalar`,
    :class:`Array`, or :class:`ProcedureSymbol`.

    NOTE(review): for modules, symbols referenced only in the module's own
    spec do not appear to be counted as used here — verify before relying on
    this for modules with spec-level declarations.
    """
    if isinstance(module_or_routine, Subroutine):
        find_and_eliminate_unused_imports(module_or_routine)
        return

    if isinstance(module_or_routine, Module):
        module_used = OrderedSet()
        for contained in module_or_routine.subroutines:
            module_used |= find_and_eliminate_unused_imports(contained)
        eliminate_unused_imports(module_or_routine, module_used)


class IsoFortranEnvMapper:
    """
    Mapper to convert other Fortran kind specifications to their definitions
    from ``iso_fortran_env``.

    Calls to ``SELECTED_INT_KIND`` and ``SELECTED_REAL_KIND`` are replaced by
    the matching ``INT<bits>`` / ``REAL<bits>`` constant. Every constant handed
    out is recorded in :attr:`used_names`, so that callers can add the
    corresponding ``iso_fortran_env`` import.

    Parameters
    ----------
    arch : str, optional
        Machine architecture name (default: :func:`platform.machine`). It
        determines which extended-precision real kinds are assumed to exist.
    """

    selected_kind_calls = ('selected_int_kind', 'selected_real_kind')

    def __init__(self, arch=None):
        if arch is None:
            arch = platform.machine()
        self.arch = arch.lower()
        # Constants from ``iso_fortran_env`` handed out so far, keyed by name
        self.used_names = CaseInsensitiveDict()

    @classmethod
    def is_selected_kind_call(cls, call):
        """
        Return ``True`` if the given call is a transformational function to
        select the kind of an integer or real type.
        """
        return isinstance(call, sym.InlineCall) and call.name.lower() in cls.selected_kind_calls

    @staticmethod
    def _selected_int_kind(r):
        """
        Return number of bytes required by the smallest signed integer type that
        is able to represent all integers n in the range -10**r < n < 10**r.

        This emulates the behaviour of Fortran's ``SELECTED_INT_KIND(R)``.

        Source: numpy.f2py.crackfortran
        https://github.com/numpy/numpy/blob/9e26d1d2be7a961a16f8fa9ff7820c33b25415e2/numpy/f2py/crackfortran.py#L2431-L2444

        :returns int: the number of bytes or -1 if no such type exists.
        """
        m = 10 ** r
        if m <= 2 ** 8:
            return 1
        if m <= 2 ** 16:
            return 2
        if m <= 2 ** 32:
            return 4
        if m <= 2 ** 63:
            return 8
        if m <= 2 ** 128:
            return 16
        return -1

    def map_selected_int_kind(self, scope, r):
        """
        Return the kind of the smallest signed integer type defined in
        ``iso_fortran_env`` that is able to represent all integers n
        in the range -10**r < n < 10**r.

        Returns a variable for ``INT8``/``INT16``/``INT32``/``INT64`` (recorded
        in :attr:`used_names`), or ``IntLiteral(-1)`` if no such kind exists.
        """
        byte_kind_map = {b: f'INT{8 * b}' for b in [1, 2, 4, 8]}
        kind = self._selected_int_kind(r)
        if kind in byte_kind_map:
            kind_name = byte_kind_map[kind]
            self.used_names[kind_name] = sym.Variable(name=kind_name, scope=scope)
            return self.used_names[kind_name]
        return sym.IntLiteral(-1)

    def _selected_real_kind(self, p, r=0, radix=0):  # pylint: disable=unused-argument
        """
        Return number of bytes required by the smallest real type that fulfils
        the given requirements:

        - decimal precision at least ``p``;
        - decimal exponent range at least ``r``;
        - radix ``radix``.

        This resembles the behaviour of Fortran's ``SELECTED_REAL_KIND([P, R, RADIX])``.
        NB: This honors only ``p`` at the moment!

        Source: numpy.f2py.crackfortran
        https://github.com/numpy/numpy/blob/9e26d1d2be7a961a16f8fa9ff7820c33b25415e2/numpy/f2py/crackfortran.py#L2447-L2463

        :returns int: the number of bytes or -1 if no such type exists.
        """
        if p < 7:
            return 4
        if p < 16:
            return 8
        if self.arch.startswith(('aarch64', 'power', 'ppc', 'riscv', 's390x', 'sparc')):
            if p <= 20:
                return 16
        else:
            # Other architectures are assumed to provide an 80-bit extended type
            if p < 19:
                return 10
            if p <= 20:
                return 16
        return -1

    def map_selected_real_kind(self, scope, p, r=0, radix=0):
        """
        Return the kind of the smallest real type defined in
        ``iso_fortran_env`` that is able to fulfil the given requirements
        for decimal precision (``p``), decimal exponent range (``r``) and
        radix (``radix``).

        Returns a variable for ``REAL32``/``REAL64``/``REAL128`` (recorded in
        :attr:`used_names`), or ``IntLiteral(-1)`` if no such kind exists.
        """
        byte_kind_map = {b: f'REAL{8 * b}' for b in [4, 8, 16]}
        kind = self._selected_real_kind(p, r, radix)
        if kind in byte_kind_map:
            kind_name = byte_kind_map[kind]
            self.used_names[kind_name] = sym.Variable(name=kind_name, scope=scope)
            return self.used_names[kind_name]
        return sym.IntLiteral(-1)

    def map_call(self, call, scope):
        """
        Map a ``SELECTED_INT_KIND``/``SELECTED_REAL_KIND`` call to the
        corresponding ``iso_fortran_env`` constant.

        Any other expression is returned unchanged. All arguments must be
        convertible to :class:`int`. Keyword argument names are matched
        case-insensitively, e.g. ``SELECTED_REAL_KIND(P=15)`` is supported.
        """
        if not self.is_selected_kind_call(call):
            return call

        func = getattr(self, f'map_{call.name.lower()}')
        args = [int(arg) for arg in call.parameters]
        # Fortran keywords are case-insensitive, whereas the mapping methods
        # expect lower-case Python keyword names (``p``, ``r``, ``radix``)
        kwargs = {str(key).lower(): int(val) for key, val in call.kw_parameters.items()}

        return func(scope, *args, **kwargs)


def replace_selected_kind(routine):
    """
    Find all uses of ``selected_real_kind`` or ``selected_int_kind`` and
    replace them by their ``iso_fortran_env`` counterparts.

    This inserts imports for all used constants from ``iso_fortran_env``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine to update in-place. Calls are replaced in the spec and
        body, and in the kind and initial value of the routine's variables.

    Notes
    -----
    If the routine already imports ``iso_fortran_env`` with an explicit symbol
    list, missing constants are appended to that import (the import is cloned,
    so its other attributes are kept). An ``iso_fortran_env`` import without a
    symbol list already makes all constants available and is left untouched.
    Otherwise a new import is prepended to the spec.
    """
    mapper = IsoFortranEnvMapper()

    # Find all selected_x_kind calls in spec and body
    calls = [call for call in FindInlineCalls().visit(routine.ir)
             if mapper.is_selected_kind_call(call)]

    # Need to pick out kinds in Literals explicitly
    calls += [literal.kind for literal in FindLiterals().visit(routine.ir)
              if hasattr(literal, 'kind') and mapper.is_selected_kind_call(literal.kind)]

    map_call = {call: mapper.map_call(call, routine) for call in calls}

    # Flush mapping through spec and body
    routine.spec = SubstituteExpressions(map_call).visit(routine.spec)
    routine.body = SubstituteExpressions(map_call).visit(routine.body)

    # Replace calls and literals hidden in variable kinds and inits. All updates for
    # one variable are collected and applied in a single clone, so that replacing
    # both the kind and the initial value cannot override each other.
    for variable in routine.variables:
        vtype = variable.type
        updates = {}
        if vtype.kind is not None and mapper.is_selected_kind_call(vtype.kind):
            updates['kind'] = mapper.map_call(vtype.kind, routine)
        if vtype.initial is not None:
            if mapper.is_selected_kind_call(vtype.initial):
                updates['initial'] = mapper.map_call(vtype.initial, routine)
            else:
                init_calls = [literal.kind for literal in FindLiterals().visit(vtype.initial)
                              if hasattr(literal, 'kind') and mapper.is_selected_kind_call(literal.kind)]
                if init_calls:
                    init_map = {call: mapper.map_call(call, routine) for call in init_calls}
                    updates['initial'] = SubstituteExpressions(init_map).visit(vtype.initial)
        if updates:
            routine.symbol_attrs[variable.name] = vtype.clone(**updates)

    # Make sure iso_fortran_env symbols are imported
    if not mapper.used_names:
        return

    for imprt in FindNodes(Import).visit(routine.spec):
        if imprt.module.lower() != 'iso_fortran_env':
            continue
        if imprt.symbols:
            # Extend the existing symbol list by the constants it does not yet contain
            imported = {str(s).lower() for s in imprt.symbols}
            missing = tuple(
                var for name, var in mapper.used_names.items() if name.lower() not in imported
            )
            if missing:
                new_imprt = imprt.clone(symbols=tuple(imprt.symbols) + missing)
                routine.spec = Transformer({imprt: new_imprt}).visit(routine.spec)
        # An import without a symbol list already provides every constant;
        # adding an only-list would restrict it
        break
    else:
        # No iso_fortran_env import present, need to insert a new one
        imprt = Import('iso_fortran_env', symbols=as_tuple(mapper.used_names.values()))
        routine.spec.prepend(imprt)


def recursive_expression_map_update(expr_map, max_iterations=10, mapper_cls=SubstituteExpressionsMapper,
                                    case_sensitive=None):
    """
    Utility function to apply a substitution map for expressions to itself

    The expression substitution mechanism :any:`SubstituteExpressions` and the
    underlying mapper :any:`SubstituteExpressionsMapper` replace nodes that
    are found in the substitution map by their corresponding replacement.

    However, expression nodes can be nested inside other expression nodes,
    e.g. via the ``parent`` or ``dimensions`` properties of variables.
    In situations, where such expression nodes as well as expression nodes
    appearing inside such properties are marked for substitution, it may
    be necessary to apply the substitution map to itself first. This utility
    routine takes care of that.

    Parameters
    ----------
    expr_map : dict
        The substitution map that should be updated
    max_iterations : int
        Maximum number of iterations, corresponds to the maximum level of
        nesting that can be replaced.
    mapper_cls: :any:`SubstituteExpressionsMapper`
       The underlying mapper to be used (default: :any:`SubstituteExpressionsMapper`).
    case_sensitive: bool (optional)
        Whether to compare case-sensitively when checking for early termination
        (default: None, use the default/global case-sensitivity setting).

    Returns
    -------
    dict
        The updated substitution map (a new dict if at least one iteration ran).
    """
    def apply_to_init_arg(name, arg, expr, mapper):
        # Helper utility to apply the mapper only to expression arguments and
        # retain the scope while rebuilding the node
        if isinstance(arg, (tuple, Expression)):
            return mapper(arg)
        if name == 'scope':
            return expr.scope
        return arg

    # Configuration for the early termination check, either with case-sensitivity
    # being the default (`{}`) or with the provided value (`case-sensitive`).
    # It does not change between iterations, so it is built only once.
    _case_sensitive = {'case-sensitive': case_sensitive} if case_sensitive is not None else {}

    for _ in range(max_iterations):
        # We update the expression map by applying it to the children of each replacement
        # node, thus making sure node replacements are also applied to nested attributes,
        # e.g. call arguments or array subscripts etc.
        if issubclass(mapper_cls, SubstituteExpressionsMapper):
            # Need to check if we should pass the `expr_map` argument
            mapper = mapper_cls(expr_map)
        else:
            mapper = mapper_cls()
        prev_map, expr_map = expr_map, {
            expr: type(replacement)(**{
                name: apply_to_init_arg(name, arg, expr, mapper)
                for name, arg in zip(replacement.init_arg_names, replacement.__getinitargs__())
            })
            for expr, replacement in expr_map.items()
        }

        # Stop as soon as applying the map to itself no longer changes it
        with config_override(_case_sensitive):
            if prev_map == expr_map:
                break

    return expr_map


def get_integer_variable(routine, name):
    """
    Look up a variable in a routine, falling back to a new integer variable.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to look up the variable.
    name : string
        Name of the variable. Only the part before the first ``%`` is checked
        against the routine's symbol map, so derived-type member references
        such as ``"obj%idx"`` are resolved via
        :meth:`Subroutine.resolve_typebound_var`.

    Returns
    -------
    The existing variable if its (base) name is known in :data:`routine`,
    otherwise a new ``INTEGER`` variable named :data:`name` scoped to
    :data:`routine`.
    """
    known_symbols = routine.symbol_map
    base_name = name.split('%', maxsplit=1)[0]
    if base_name not in known_symbols:
        int_type = SymbolAttributes(BasicType.INTEGER)
        return sym.Variable(name=name, type=int_type, scope=routine)
    return routine.resolve_typebound_var(name, known_symbols)


def get_loop_bounds(routine, dimension):
    """
    Check loop bounds for a particular :any:`Dimension` in a
    :any:`Subroutine`.

    For each of the lower and upper bound, the candidate names of the
    dimension are tried in order. The first one that is either a numeric
    string or whose (base) name is a variable of :data:`routine` is used.

    Parameters
    ----------
    routine : :any:`Subroutine`
        Subroutine to perform checks on.
    dimension : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions
        used to define the data dimension and iteration space.

    Returns
    -------
    tuple
        ``(start, end)`` expressions for the iteration space.

    Raises
    ------
    RuntimeError
        If no candidate for the lower or upper bound can be found.
    """

    bounds = ()
    variable_map = routine.variable_map
    for name, _bounds in zip(['start', 'end'], [dimension.lower, dimension.upper]):
        candidates = as_tuple(_bounds)
        for bound in candidates:
            # Recognise numeric strings, eg. "1" in ``1:n``
            if isinstance(bound, str) and bound.isnumeric():
                bounds += (sym.Literal(int(bound)),)
                break

            # Recognise typebound bound variables
            if bound.split('%', maxsplit=1)[0] in variable_map:
                bounds += (routine.resolve_typebound_var(bound, variable_map),)
                break
        else:
            # Report the first candidate. ``_bounds`` may be a single string, in
            # which case ``_bounds[0]`` would only be its first character.
            first_candidate = candidates[0] if candidates else None
            raise RuntimeError(
                f'No {name} variable matching {first_candidate} found in {routine.name}'
            )

    return bounds


def is_pragma_driver_loop(loop):
    """
    Test whether a loop is marked as a *driver loop* by a Loki pragma.

    A loop qualifies if any of its attached pragmas is a Loki pragma whose
    content starts with ``driver-loop`` or ``loop driver``.

    Parameters
    ----------
    loop : :any:`Loop`
        The loop to test (pragmas are expected to be attached).

    Returns
    -------
    bool
    """
    return any(
        is_loki_pragma(pragma, starts_with='driver-loop')
        or is_loki_pragma(pragma, starts_with='loop driver')
        for pragma in (loop.pragma or ())
    )


def is_driver_loop(loop, targets):
    """
    Test/check whether a given loop is a *driver loop*.

    A loop is a driver loop if it is marked by a driver-loop pragma (see
    :any:`is_pragma_driver_loop`) or if its body contains a call to one of
    the :data:`targets`.

    Parameters
    ----------
    loop : :any: `Loop`
        The loop to test if it is a *driver loop*.
    targets : list or string
        List of subroutines that are to be considered as part of
        the transformation call tree.

    Returns
    -------
    bool
    """
    if is_pragma_driver_loop(loop):
        return True

    # Normalise the targets in the same way as `find_driver_loops`. This also makes
    # a single string count as one routine name instead of a sequence of characters.
    targets = [str(t).lower() for t in as_tuple(targets)]
    return any(call.name in targets for call in FindNodes(ir.CallStatement).visit(loop.body))


def find_driver_loops(section, targets):
    """
    Find and return all driver loops in a given `section`.

    A *driver loop* is specified either by a call to a routine within
    `targets` or by the pragma `!$loki driver-loop`.

    If there are nested loops, then the highest level pragma-marked driver
    loop will take precedence and be considered the driver loop. If there
    is no pragma-marked driver loop, then the highest level loop that does
    not contain a driver loop will be considered the driver loop.

    Parameters
    ----------
    section : :any:`Section` or tuple
        The section of IR (or tuple of IR nodes) in which to find the driver loops.
    targets : list or string
        List of subroutines that are to be considered as part of
        the transformation call tree.

    Returns
    -------
    list of :any:`Loop`
        The driver loops, in the order in which they are discovered.
    """
    targets = [str(t).lower() for t in as_tuple(targets)]

    class FindDriverLoops(Visitor):
        """
        A  visitor that collects all driver loops in a section of the IR.

        If there are nested loops, then the highest level pragma-marked driver
        loop will take precedence and be considered the driver loop. If there
        is no pragma-marked driver loop, then the highest level loop that does
        not contain a driver loop will be considered the driver loop.

        Each visit returns a triple ``(nested_pragma_loop, has_target_call,
        nested_target_loops)`` that is propagated up the loop nest.
        """

        @classmethod
        def default_retval(cls):
            return False, False, []

        def __init__(self):
            super().__init__()
            self.driver_loops = []

        def visit_tuple(self, o, **kwargs):
            nested_pragma_loop = False
            nested_target_loops = []
            has_target_call = False
            for i in o:
                retval = self.visit(i, **kwargs)
                nested_pragma_loop |= retval[0]
                has_target_call |= retval[1]
                nested_target_loops += retval[2]

            return nested_pragma_loop, has_target_call, nested_target_loops

        visit_list = visit_tuple

        def visit_CallStatement(self, call, **_kwargs):
            if call.name in targets:
                return False, True, []
            return self.default_retval()

        def visit_Loop(self, loop, **kwargs):
            depth = kwargs.pop('depth', 0)
            if is_pragma_driver_loop(loop):
                # Propagate the presence of the pragma only if inside a loop nest
                self.driver_loops.append(loop)
                return depth > 0, False, []

            # Recurse into the (potential) loop nest
            nested_pragma_loop, has_target_call, nested_target_loops = self.visit(
                loop.body, depth=depth+1, **kwargs
            )

            if nested_pragma_loop:
                # If there is a pragma-marked driver loop, this takes precedence and
                # we reset the list of nested target loops
                if has_target_call:
                    warning("[Loki::find_driver_loops] Nested pragma marked driver loop inside loop"
                            f" with target call (skipping {loop})")
                    nested_target_loops = []

            elif has_target_call:
                # If there is a target call directly inside the loop, the current loop
                # is a potential driver loop (unless a pragma-annotated loop is present,
                # which is why the target loops are collected into a list first and only
                # added to self.driver_loops once we're back at depth==0)
                nested_target_loops = [loop]

            elif nested_target_loops and depth > 0:
                # If this is simply another loop around a nested target loop, we raise
                # the driver loop one level further up
                nested_target_loops = [loop]

            if depth == 0:
                if nested_target_loops:
                    self.driver_loops.extend(nested_target_loops)
                return self.default_retval()

            return nested_pragma_loop, False, nested_target_loops

    find_driver = FindDriverLoops()
    find_driver.visit(section, targets=targets)
    return find_driver.driver_loops


def get_local_arrays(routine, section, unique=True):
    """
    Collect all local temporary array symbols in a given section.

    A local array is an :any:`Array` symbol without a ``parent`` (i.e. not a
    derived-type member) that is neither a dummy argument of :data:`routine`
    nor one of its imported symbols.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to find local arrays.
    section : :any:`Section` or tuple of :any:`Node`
        The section or list of nodes to scan for local temporary
        symbols.
    unique : bool, optional
        Flag whether to return unique instances of each symbol;
        default: ``True``

    Returns
    -------
    list of :any:`Array`
        The local arrays, in the order returned by :any:`FindVariables`.
    """
    imported_symbols = routine.imported_symbols
    arg_names = {a.lower() for a in routine._dummies}

    # Filter all variables by type, parent, argument name and imports in a single pass
    return [
        v for v in FindVariables(unique=unique).visit(section)
        if isinstance(v, sym.Array) and not v.parent
        and str(v.name).lower() not in arg_names
        and v.name not in imported_symbols
    ]


def check_routine_sequential(routine):
    """
    Determine whether a routine is marked as "sequential".

    The routine counts as sequential if any pragma in its IR is a Loki
    pragma whose content starts with ``routine seq``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        Subroutine to perform checks on.

    Returns
    -------
    bool
    """
    all_pragmas = FindNodes(ir.Pragma).visit(routine.ir)
    return any(is_loki_pragma(p, starts_with='routine seq') for p in all_pragmas)

def substitute_variables_for_definitions(routine, variables):
    """
    Replace variables by the expression they are assigned in the routine body.

    Parameters
    ----------
    routine : :any:`Subroutine`
        Subroutine whose body is searched for assignments.
    variables : :any:`Expression`, list, tuple
        Variable(s) to remap.

    Returns
    -------
    list or tuple
        If at least one of :data:`variables` is the left-hand side of an
        assignment in the body, a list in which each such variable is replaced
        by the right-hand side of the last matching assignment (others are
        kept). Otherwise :data:`variables` is returned unchanged as a tuple.
    """
    candidates = as_tuple(variables)

    definitions = {}
    for assign in FindNodes(ir.Assignment).visit(routine.body):
        if assign.lhs in candidates:
            definitions[assign.lhs] = assign.rhs

    if not definitions:
        return candidates
    return [definitions.get(var, var) for var in candidates]
loki-ecmwf-0.3.6/loki/transformations/transform_loop.py0000664000175000017500000011301115167130205023543 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utility routines that provide loop transformations.

"""
import functools
from collections import defaultdict
import operator as op
import numpy as np

from loki.analyse import (
    dataflow_analysis_attached, read_after_write_vars,
    loop_carried_dependencies
)
from loki.batch import Transformation
from loki.expression import (
    symbols as sym, simplify, is_constant, symbolic_op, parse_expr,
    IntLiteral, get_pyrange, LoopRange
)
from loki.ir import (
    Loop, Conditional, Comment, Pragma, FindNodes, Transformer,
    NestedMaskedTransformer, is_parent_of, is_loki_pragma,
    get_pragma_parameters, pragmas_attached, SubstituteExpressions,
    FindVariables
)
from loki.logging import info, warning
from loki.tools import (
    flatten, as_tuple, CaseInsensitiveDict, binary_insertion_sort,
    optional, OrderedSet
)
from loki.transformations.array_indexing import (
    promotion_dimensions_from_loop_nest, promote_nonmatching_variables
)


__all__ = ['do_loop_interchange', 'do_loop_fusion', 'do_loop_fission', 'do_loop_unroll',
           'TransformLoopsTransformation']


from loki.analyse.util_polyhedron import Polyhedron

def eliminate_variable(polyhedron, index_or_variable):
    """
    Eliminate a variable from the polyhedron.

    Mathematically, this is a projection of the polyhedron onto the hyperplane
    H:={x|x_j=0} with x_j the dimension corresponding to the eliminated variable.

    This is an implementation of the Fourier-Motzkin elimination: constraints
    not involving x_j are kept, and every pair of a lower and an upper bound on
    x_j is combined into one new constraint. The result can therefore have more
    constraints than the input (|L| * |U| can exceed |L| + |U|).

    :param :class:``Polyhedron`` polyhedron: the polyhedron to be reduced in dimension.
    :param index_or_variable: the index, name, or expression symbol that is to be
                              eliminated.
    :type index_or_variable: int or str or sym.Array or sym.Scalar

    :return: the reduced polyhedron.
    :rtype: :class:``Polyhedron``
    """
    if isinstance(index_or_variable, int):
        j = index_or_variable
    else:
        j = polyhedron.variable_to_index(index_or_variable)

    num_constraints = polyhedron.A.shape[0]
    # Indices of lower bounds on x_j
    L = [i for i in range(num_constraints) if polyhedron.A[i,j] < 0]
    # Indices of upper bounds on x_j
    U = [i for i in range(num_constraints) if polyhedron.A[i,j] > 0]
    # Indices of constraints not involving x_j
    bounded = set(L).union(U)
    Z = [i for i in range(num_constraints) if i not in bounded]
    # Cartesian product of lower and upper bounds
    R = [(l, u) for l in L for u in U]

    # Project polyhedron onto hyperplane H:={x|x_j = 0}
    # Allocate room for every retained and combined constraint: sizing this by the
    # input constraint count overflows, e.g. 3 lower + 3 upper bounds -> 9 rows.
    new_count = len(Z) + len(R)
    A = np.zeros((new_count, polyhedron.A.shape[1]), dtype=np.dtype(int))
    b = np.zeros((new_count,) + polyhedron.b.shape[1:], dtype=np.dtype(int))
    next_constraint = 0
    for idx in Z:
        A[next_constraint,:] = polyhedron.A[idx,:]
        b[next_constraint] = polyhedron.b[idx]
        next_constraint += 1
    for l, u in R:
        # Both multipliers are positive, so the inequality direction is preserved
        # and the coefficient of x_j cancels out
        A[next_constraint,:] = polyhedron.A[u,j] * polyhedron.A[l,:] - polyhedron.A[l,j] * polyhedron.A[u,:]
        b[next_constraint] = polyhedron.A[u,j] * polyhedron.b[l] - polyhedron.A[l,j] * polyhedron.b[u]
        next_constraint += 1

    # TODO: normalize rows

    # Eliminate j-th column
    A = np.delete(A, j, axis=1)
    variables = polyhedron.variables
    if variables is not None:
        variables = variables[:j] + variables[j+1:]
    return Polyhedron(A, b, variables)


def generate_loop_bounds(iteration_space, iteration_order):
    """
    Generate loop bounds according to a changed iteration order.

    This creates a new polyhedron representing the iteration space for the
    provided iteration order. Variables are projected out one by one (starting
    with the innermost of the new order), and the bounds found for each
    variable are collected into the new polyhedron.

    :param :class:``Polyhedron`` iteration_space: the iteration space that
            should be reordered.
    :param list iteration_order: the new iteration order as a list of
            indices of iteration variables.

    :return: the reordered iteration space.
    :rtype: :class:``Polyhedron``
    """
    assert iteration_space.variables is not None
    assert len(iteration_order) <= len(iteration_space.variables)

    lower_bounds= [None] * len(iteration_order)
    upper_bounds= [None] * len(iteration_order)
    # Maps original variable index -> column index in the reduced polyhedron
    # (None once the variable has been eliminated)
    index_map = list(range(len(iteration_order)))
    reduced_polyhedron = iteration_space

    # Find projected loop bounds
    constraint_count = 0
    for var_idx in reversed(iteration_order):
        # Get index of variable in reduced polyhedron
        idx = index_map[var_idx]
        assert idx is not None
        # Store bounds for variable
        lower_bounds[var_idx] = reduced_polyhedron.lower_bounds(idx)
        upper_bounds[var_idx] = reduced_polyhedron.upper_bounds(idx)
        constraint_count += len(lower_bounds[var_idx]) + len(upper_bounds[var_idx])
        # Eliminate variable from polyhedron
        reduced_polyhedron = eliminate_variable(reduced_polyhedron, idx)
        # Update index map after variable elimination: all remaining variables after
        # it shift one column to the left. Entries of variables eliminated earlier are
        # None and must stay None (``None - 1`` would raise a TypeError, e.g. for the
        # identity order [0, 1]).
        index_map[var_idx] = None
        index_map[var_idx+1:] = [None if i is None else i - 1 for i in index_map[var_idx+1:]]

    # Build new iteration space polyhedron
    variables = [iteration_space.variables[i] for i in iteration_order]
    variables += list(iteration_space.variables[len(iteration_order):])
    A = np.zeros([constraint_count, len(variables)], dtype=np.dtype(int))
    b = np.zeros([constraint_count], dtype=np.dtype(int))
    next_constraint = 0
    for new_idx, var_idx in enumerate(iteration_order):
        # TODO: skip lower/upper bounds already fulfilled
        for bound in lower_bounds[var_idx]:
            lhs, rhs = Polyhedron.generate_entries_for_lower_bound(bound, variables, new_idx)
            A[next_constraint,:] = lhs
            b[next_constraint] = rhs
            next_constraint += 1
        for bound in upper_bounds[var_idx]:
            # Upper bounds are generated like lower bounds and then negated
            lhs, rhs = Polyhedron.generate_entries_for_lower_bound(bound, variables, new_idx)
            A[next_constraint,:] = -lhs
            b[next_constraint] = -rhs
            next_constraint += 1
    A = A[:next_constraint,:]
    b = b[:next_constraint]
    return Polyhedron(A, b, variables)


def get_nested_loops(loop, depth):
    """
    Helper routine to extract all loops in a loop nest.

    Starting from :data:`loop`, descend ``depth - 1`` levels, each time into
    the single :any:`Loop` that appears directly in the current loop's body.
    An ``AssertionError`` is raised if a level does not contain exactly one loop.

    Returns
    -------
    tuple of :any:`Loop`
        The loops of the nest, outermost first.
    """
    nest = [loop]
    current = loop
    while len(nest) < depth:
        inner = [node for node in current.body if isinstance(node, Loop)]
        assert len(inner) == 1
        current = inner[0]
        nest.append(current)
    return as_tuple(nest)


def get_loop_components(loops):
    """
    Helper routine to extract loop variables, ranges and bodies of list of loops.

    Returns
    -------
    tuple
        ``(variables, ranges, bodies)``, each a tuple with one entry per loop.
    """
    per_loop = [(lp.variable, lp.bounds, lp.body) for lp in loops]
    variables, ranges, bodies = zip(*per_loop)
    return tuple(as_tuple(part) for part in (variables, ranges, bodies))


def do_loop_interchange(routine, project_bounds=False):
    """
    Search for loops annotated with the `loki loop-interchange` pragma and attempt
    to reorder them.

    The ``loop-interchange`` pragma parameter may give a comma-separated list of
    loop variable names specifying the new loop order (outermost first); the depth
    of the loop nest that is reordered is the length of that list. If no (or an
    empty) list is given, the two outermost loops of the nest are swapped.

    Note that this effectively just exchanges variable and bounds for each of the loops,
    leaving the rest (including bodies, pragmas, etc.) intact.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which loop interchange is applied (modified in place).
    project_bounds : bool, optional
        If `True`, derive the new loop bounds by projecting the iteration space
        polyhedron onto the new loop order (combining multiple bounds with
        ``max``/``min``) instead of reusing each loop's original bounds.
    """
    with pragmas_attached(routine, Loop):
        loop_map = {}
        for loop_nest in FindNodes(Loop).visit(routine.body):
            if not is_loki_pragma(loop_nest.pragma, starts_with='loop-interchange'):
                continue

            # Get variable order from pragma
            var_order = get_pragma_parameters(loop_nest.pragma).get('loop-interchange', None)
            if var_order:
                var_order = [var.strip().lower() for var in var_order.split(',')]
                depth = len(var_order)
            else:
                # No explicit order: swap the two outermost loops. Normalise to None so
                # that an empty parameter value is treated like an absent one, instead of
                # producing an empty loop order (and no rebuilt loop) below.
                var_order = None
                depth = 2

            # Extract loop nest
            loops = get_nested_loops(loop_nest, depth)
            loop_variables, loop_ranges, *_ = get_loop_components(loops)

            # Find the loop order from the variable order
            if var_order is None:
                var_order = [str(var).lower() for var in reversed(loop_variables)]
            loop_variable_names = [var.name.lower() for var in loop_variables]
            # Position (in the original nest) of each loop in the new order, outermost first
            loop_order = [loop_variable_names.index(var) for var in var_order]

            # Project iteration space
            if project_bounds:
                iteration_space = Polyhedron.from_loop_ranges(loop_variables, loop_ranges)
                iteration_space = generate_loop_bounds(iteration_space, loop_order)

            # Rebuild loops starting with innermost
            inner_loop_map = None
            for idx, (loop, loop_idx) in enumerate(zip(reversed(loops), reversed(loop_order))):
                if project_bounds:
                    # Position of this loop in the new order; the variables of loops
                    # nested further inside are ignored when extracting its bounds
                    new_idx = len(loop_order) - idx - 1
                    ignore_variables = list(range(new_idx+1, len(loop_order)))
                    lower_bounds = iteration_space.lower_bounds(new_idx, ignore_variables)
                    upper_bounds = iteration_space.upper_bounds(new_idx, ignore_variables)

                    if len(lower_bounds) == 1:
                        lower_bounds = lower_bounds[0]
                    else:
                        fct_symbol = sym.ProcedureSymbol('max', scope=routine)
                        lower_bounds = sym.InlineCall(fct_symbol, parameters=as_tuple(lower_bounds))

                    if len(upper_bounds) == 1:
                        upper_bounds = upper_bounds[0]
                    else:
                        fct_symbol = sym.ProcedureSymbol('min', scope=routine)
                        upper_bounds = sym.InlineCall(fct_symbol, parameters=as_tuple(upper_bounds))

                    bounds = sym.LoopRange((lower_bounds, upper_bounds))
                else:
                    bounds = loop_ranges[loop_idx]

                # Re-use the loop at this nesting level, swapping in variable and bounds,
                # and splice in the already rebuilt inner loop
                outer_loop = loop.clone(variable=loop_variables[loop_idx], bounds=bounds)
                if inner_loop_map is not None:
                    outer_loop = Transformer(inner_loop_map).visit(outer_loop)
                inner_loop_map = {loop: outer_loop}

            # Annotate loop-interchange in a comment
            old_vars = ', '.join(loop_variable_names)
            new_vars = ', '.join(var_order)
            comment = Comment(f'! Loki loop-interchange ({old_vars} <--> {new_vars})')

            # Strip loop-interchange pragma and register new loop nest in map
            pragmas = tuple(p for p in as_tuple(loops[0].pragma)
                            if not is_loki_pragma(p, starts_with='loop-interchange'))
            loop_map[loop_nest] = (comment, outer_loop.clone(pragma=pragmas))

        # Apply loop-interchange mapping
        if loop_map:
            routine.body = Transformer(loop_map).visit(routine.body)
            info('%s: interchanged %d loop nest(s)', routine.name, len(loop_map))


def pragma_ranges_to_loop_ranges(parameters, scope):
    """
    Parse the ``range`` pragma parameter into a tuple of :any:`LoopRange` objects.

    The parameter value is a comma-separated list of ranges, each made of
    colon-separated bounds (e.g. ``1:n,1:m:2``); every bound is parsed as an
    expression in :data:`scope`.

    Parameters
    ----------
    parameters : dict
        Pragma parameters as returned by :meth:`get_pragma_parameters`.
    scope : :any:`Scope`
        Scope used when parsing the bound expressions.

    Returns
    -------
    tuple of :any:`LoopRange` or None
        One range per comma-separated entry, or `None` if there is no
        ``range`` parameter.
    """
    if 'range' not in parameters:
        return None

    def _parse_range(text):
        # 'start:stop[:step]' -> LoopRange of the parsed bound expressions
        return sym.LoopRange(as_tuple([parse_expr(bound, scope=scope) for bound in text.split(':')]))

    return as_tuple([_parse_range(item) for item in parameters['range'].split(',')])


def do_loop_fusion(routine):
    """
    Search for loops annotated with the `loki loop-fusion` pragma and attempt
    to fuse them into a single loop.

    Annotated loops are sorted into fusion groups by the pragma's ``group``
    parameter (``default`` if absent). The following pragma parameters are
    recognised:

    * ``group(name)``: name of the fusion group.
    * ``collapse(n)``: depth of the loop nest to fuse (default 1); must be
      identical for all loops of a group.
    * ``range(start:stop[:step], ...)``: explicit iteration range(s) of the
      fused loop nest; must agree if given on several loops of a group.
    * ``insert-loc`` (without a value): place the fused loop at the position of
      the first loop carrying it, instead of at the group's first loop.

    Without an explicit ``range``, the fused iteration space is built from the
    smallest lower bounds (``min``) and largest upper bounds (``max``) of the
    loops in the group, and each original body is wrapped in a conditional
    wherever its own bounds differ from the fused ones. Only the inner-most
    loop bodies are retained.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which loop fusion is applied (modified in place).

    Raises
    ------
    RuntimeError
        If the ``collapse`` values or pragma-specified ranges of a group conflict.
    """
    fusion_groups = defaultdict(list)
    loop_map = {}
    with pragmas_attached(routine, Loop):
        # Extract all annotated loops and sort them into fusion groups
        for loop in FindNodes(Loop).visit(routine.body):
            if is_loki_pragma(loop.pragma, starts_with='loop-fusion'):
                parameters = get_pragma_parameters(loop.pragma, starts_with='loop-fusion')
                group = parameters.get('group', 'default')
                fusion_groups[group] += [(loop, parameters)]

        if not fusion_groups:
            return

        # Merge loops in each group and put them in the position of the group's first loop
        #  UNLESS 'insert-loc' location is specified for at least one of the group's fusion
        #  pragmas, in this case the position is the first occurrence of 'insert-loc' for each group
        for group, loop_parameter_lists in fusion_groups.items():
            loop_list, parameters = zip(*loop_parameter_lists)

            # First, determine the collapse depth and extract user-annotated loop ranges from pragmas
            collapse = [param.get('collapse', None) for param in parameters]
            insert_locs = [param.get('insert-loc', False) for param in parameters]
            if collapse != [collapse[0]] * len(collapse):
                raise RuntimeError(f'Conflicting collapse values in group "{group}"')
            collapse = int(collapse[0]) if collapse[0] is not None else 1

            pragma_ranges = [pragma_ranges_to_loop_ranges(param, routine) for param in parameters]

            # If we have a pragma somewhere with an explicit loop range, we use that for the fused loop
            range_set = {r for r in pragma_ranges if r is not None}
            if len(range_set) not in (0, 1):
                raise RuntimeError(f'Pragma-specified loop ranges in group "{group}" do not match')

            fusion_ranges = None
            if range_set:
                fusion_ranges = range_set.pop()

            # Next, extract loop ranges for all loops in group and convert to iteration space
            # polyhedrons for easier alignment
            loop_variables, loop_ranges, loop_bodies = \
                    zip(*[get_loop_components(get_nested_loops(loop, collapse)) for loop in loop_list])
            iteration_spaces = [Polyhedron.from_loop_ranges(variables, ranges)
                                for variables, ranges in zip(loop_variables, loop_ranges)]

            # Find the fused iteration space (if not given by a pragma): collect the smallest
            # lower bounds and largest upper bounds over all loops, per nesting level
            if fusion_ranges is None:
                fusion_ranges = []
                for level in range(collapse):
                    lower_bounds, upper_bounds = [], []
                    ignored_variables = list(range(level+1, collapse))

                    for p in iteration_spaces:
                        for bound in p.lower_bounds(level, ignored_variables):
                            # Decide if we learn something new from this bound, which could be because:
                            # (1) we don't have any bounds, yet
                            # (2) bound is smaller than existing lower bounds (i.e. diff < 0)
                            # (3) bound is not constant and none of the existing bounds are lower (i.e. diff >= 0)
                            diff = [simplify(bound - b) for b in lower_bounds]
                            is_any_negative = any(is_constant(d) and symbolic_op(d, op.lt, 0) for d in diff)
                            is_any_not_negative = any(is_constant(d) and symbolic_op(d, op.ge, 0) for d in diff)
                            is_new_bound = (not lower_bounds or is_any_negative or
                                            (not is_constant(bound) and not is_any_not_negative))
                            if is_new_bound:
                                # Remove any lower bounds made redundant by bound:
                                lower_bounds = [b for b, d in zip(lower_bounds, diff)
                                                if not (is_constant(d) and symbolic_op(d, op.lt, 0))]
                                lower_bounds += [bound]

                        for bound in p.upper_bounds(level, ignored_variables):
                            # Decide if we learn something new from this bound, which could be because:
                            # (1) we don't have any bounds, yet
                            # (2) bound is larger than existing upper bounds (i.e. diff > 0)
                            # (3) bound is not constant and none of the existing bounds are larger (i.e. diff <= 0)
                            diff = [simplify(bound - b) for b in upper_bounds]
                            is_any_positive = any(is_constant(d) and symbolic_op(d, op.gt, 0) for d in diff)
                            is_any_not_positive = any(is_constant(d) and symbolic_op(d, op.le, 0) for d in diff)
                            is_new_bound = (not upper_bounds or is_any_positive or
                                            (not is_constant(bound) and not is_any_not_positive))
                            if is_new_bound:
                                # Remove any upper bounds made redundant by bound:
                                upper_bounds = [b for b, d in zip(upper_bounds, diff)
                                                if not (is_constant(d) and symbolic_op(d, op.gt, 0))]
                                upper_bounds += [bound]

                    if len(lower_bounds) == 1:
                        lower_bounds = lower_bounds[0]
                    else:
                        # TODO: could/should be ProcedureSymbol, however refer to issue: #390
                        fct_symbol = sym.DeferredTypeSymbol(name='min', scope=routine)
                        lower_bounds = sym.InlineCall(fct_symbol, parameters=as_tuple(lower_bounds))

                    if len(upper_bounds) == 1:
                        upper_bounds = upper_bounds[0]
                    else:
                        # TODO: could/should be ProcedureSymbol, however refer to issue: #390
                        fct_symbol = sym.DeferredTypeSymbol(name='max', scope=routine)
                        upper_bounds = sym.InlineCall(fct_symbol, parameters=as_tuple(upper_bounds))

                    fusion_ranges += [sym.LoopRange((lower_bounds, upper_bounds))]

            # Align loop ranges and collect bodies
            fusion_bodies = []
            fusion_variables = loop_variables[0]
            for idx, (variables, ranges, bodies, p) in enumerate(
                    zip(loop_variables, loop_ranges, loop_bodies, iteration_spaces)):
                # TODO: This throws away anything that is not in the inner-most loop body.
                body = flatten([Comment(f'! Loki loop-fusion - body {idx} begin'),
                                bodies[-1],
                                Comment(f'! Loki loop-fusion - body {idx} end')])

                # Replace loop variables if necessary. Fortran names are case-insensitive,
                # so compare lower-cased names on *both* sides; otherwise loop variables
                # declared in upper or mixed case would never be renamed.
                var_map = {}
                for loop_variable, fusion_variable in zip(variables, fusion_variables):
                    if loop_variable != fusion_variable:
                        var_map.update({var: fusion_variable for var in FindVariables().visit(body)
                                        if var.name.lower() == loop_variable.name.lower()})
                if var_map:
                    body = SubstituteExpressions(var_map).visit(body)

                # Wrap in conditional if loop bounds are different
                conditions = []
                for loop_range, fusion_range, variable in zip(ranges, fusion_ranges, fusion_variables):
                    if symbolic_op(loop_range.start, op.ne, fusion_range.start):
                        conditions += [sym.Comparison(variable, '>=', loop_range.start)]
                    if symbolic_op(loop_range.stop, op.ne, fusion_range.stop):
                        conditions += [sym.Comparison(variable, '<=', loop_range.stop)]
                if conditions:
                    if len(conditions) == 1:
                        condition = conditions[0]
                    else:
                        condition = sym.LogicalAnd(as_tuple(conditions))
                    body = Conditional(condition=condition, body=as_tuple(body), else_body=())

                fusion_bodies += [body]

            # Create the nested fused loop and replace original loops
            fusion_loop = flatten(fusion_bodies)
            for fusion_variable, fusion_range in zip(reversed(fusion_variables), reversed(fusion_ranges)):
                fusion_loop = Loop(variable=fusion_variable, body=as_tuple(fusion_loop), bounds=fusion_range)

            # The fused loop replaces the loop at the insert location; all other loops of the
            # group are replaced by a comment. ``insert-loc`` given without a value yields None.
            comment = Comment(f'! Loki loop-fusion group({group})')
            insert_loc = insert_locs.index(None) if None in insert_locs else 0
            loop_map[loop_list[insert_loc]] = (comment, Pragma(keyword='loki',
                content=f'fused-loop group({group})'), fusion_loop)
            comment = Comment(f'! Loki loop-fusion group({group}) - loop hoisted')
            loop_map.update({loop: comment for i_loop, loop in enumerate(loop_list) if i_loop != insert_loc})

        # Apply transformation
        routine.body = Transformer(loop_map).visit(routine.body)
        info('%s: fused %d loops in %d groups.', routine.name,
             sum(len(loop_list) for loop_list in fusion_groups.values()), len(fusion_groups))


class FissionTransformer(NestedMaskedTransformer):
    """
    Bespoke transformer that splits loops or loop nests at
    ``!$loki loop-fission`` pragmas.

    For that, the subtree that makes up the loop body is traversed multiple
    times, capturing everything before, after or in-between fission pragmas
    in each traversal, using :class:`NestedMaskedTransformer`.
    Any intermediate nodes that define sections (e.g. conditionals) are
    reproduced in each subtree traversal.

    This works also for nested loops with individually different fission
    annotations.

    After a visit, :attr:`split_loops` maps each fission pragma to the list
    of loops produced by splitting its loop, starting with the loop that
    precedes that pragma.

    Parameters
    ----------
    loop_pragmas : dict of (:any:`Loop`, list of :any:`Pragma`)
        Mapping of all loops to the list of contained
        ``loop-fission`` pragmas at which they should be split.
    active : bool, optional
        Initial active state, passed on to :class:`NestedMaskedTransformer`.
    **kwargs :
        Further keyword arguments for :class:`NestedMaskedTransformer`
        (``require_all_start=True`` and ``greedy_stop=True`` are always set).
    """

    def __init__(self, loop_pragmas, active=True, **kwargs):
        super().__init__(active=active, require_all_start=True, greedy_stop=True, **kwargs)
        self.loop_pragmas = loop_pragmas
        # Filled by visit_Loop: fission pragma -> list of resulting loops (see class docstring)
        self.split_loops = {}

    def visit_Loop(self, o, **kwargs):
        """
        Split :data:`o` at its fission pragmas, or traverse it like any other
        internal node if it is not marked for fission.

        Returns a tuple of the rebuilt copies of :data:`o` (one per non-empty
        section), with comments marking the split points, or `None` if the
        loop lies in an inactive part of an enclosing fission subtree.
        """
        if o not in self.loop_pragmas:
            # loops that are not marked for fission can be handled as
            # in the regular NestedMaskedTransformer
            return super().visit_InternalNode(o, **kwargs)

        if not (self.active or self.start):
            # this happens if we encounter a loop marked for fission while
            # already traversing the subtree of an enclosing fission loop.
            # no more pragmas are marked to make this subtree active, thus
            # we can bail out here
            return None

        # Recurse for all children except the body
        body_index = o._traversable.index('body')
        visited = tuple(self.visit(c, **kwargs) for i, c in enumerate(o.children) if i != body_index)

        # Save current state so we can restore for each subtree
        _start, _stop, _active = self.start, self.stop, self.active

        def rebuild_fission_branch(start_node, stop_node, **kwargs):
            # Rebuild a copy of loop ``o`` whose body holds only the section between the
            # fission pragmas ``start_node`` and ``stop_node`` (None = start/end of body).
            # Returns a list of nodes to append to the result, or ``[()]`` for an empty section.
            if start_node is None:
                # This subtree is either active already or we have a fission pragma
                # with collapse in _start from an enclosing loop
                self.active = _active
                self.start = _start.copy()
            else:
                # We build a subtree after a fission pragma. Make sure that all
                # pragmas have been encountered before processing the subtree
                self.active = False
                self.start = _start.copy() | {start_node}
                # Map the start pragma to None for this traversal (removed again below),
                # presumably so that the pragma itself is not reproduced in the subtree
                self.mapper[start_node] = None
            # we stop when encountering this or any previously defined stop nodes
            self.stop = _stop.copy() | OrderedSet(as_tuple(stop_node))
            body = flatten(self.visit(o.body, **kwargs))
            if start_node is not None:
                self.mapper.pop(start_node)
            if not body:
                return [()]
            # inject a comment to mark where the loop was split
            comment = [] if start_node is None else [Comment(f'! Loki - {start_node.content}')]
            # Re-insert the rebuilt body at its original position among the loop's children
            return comment + [self._rebuild(o, visited[:body_index] + (body,) + visited[body_index:])]

        # Use masked transformer to build subtrees from/to pragma, i.e. the sections
        # [body start, p0), [p0, p1), ..., [p_last, body end)
        rebuilt = rebuild_fission_branch(None, self.loop_pragmas[o][0], **kwargs)
        for start, stop in zip(self.loop_pragmas[o][:-1], self.loop_pragmas[o][1:]):
            rebuilt += rebuild_fission_branch(start, stop, **kwargs)
        rebuilt += rebuild_fission_branch(self.loop_pragmas[o][-1], None, **kwargs)

        # Register the new loops in the mapping
        # NOTE(review): empty sections contribute no Loop to ``loops``, which shifts these
        # indices relative to the pragmas - confirm this is acceptable for consumers.
        loops = [l for l in rebuilt if isinstance(l, Loop)]
        self.split_loops.update({pragma: loops[i:] for i, pragma in enumerate(self.loop_pragmas[o])})

        # Restore original state (except for the active status because this has potentially
        # been changed when traversing the loop body)
        self.start, self.stop = _start, _stop

        # Drop the ``()`` placeholders of empty sections
        return as_tuple(i for i in rebuilt if i)


def do_loop_fission(routine, promote=True, warn_loop_carries=True):
    """
    Search for ``!$loki loop-fission`` pragmas in loops and split them.

    The expected pragma syntax is
    ``!$loki loop-fission [collapse(n)] [promote(var-name, var-name, ...)]``
    where ``collapse(n)`` gives the loop nest depth to be split (defaults to n=1)
    and ``promote`` optionally specifies a list of variable names to be promoted
    by the split iteration space dimensions.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which loop fission is to be applied.
    promote : bool, optional
        Try to automatically detect read-after-write across fission points
        and promote corresponding variables. Note that this does not affect
        promotion of variables listed directly in the pragma's ``promote``
        option.
    warn_loop_carries : bool, optional
        Try to automatically detect loop-carried dependencies and warn
        when the fission point sits after the initial read and before the
        final write.
    """
    pragma_loops = defaultdict(list)  # List of enclosing loops per fission pragmas
    loop_pragmas = defaultdict(list)  # List of pragmas splitting a loop
    # Variables to promote with new dimension. Fortran names are case-insensitive (and
    # are lower-cased before insertion below), hence the case-insensitive mapping.
    promotion_vars_dims = CaseInsensitiveDict()
    promotion_vars_index = {}  # Variable subscripts to promote with new indices
    loop_carried_vars = {}  # List of loop carried dependencies in original loop

    # First, find the loops enclosing each pragma
    for loop in FindNodes(Loop).visit(routine.body):
        for pragma in FindNodes(Pragma).visit(loop.body):
            if is_loki_pragma(pragma, starts_with='loop-fission'):
                pragma_loops[pragma] += [loop]

    if not pragma_loops:
        return

    with optional(promote or warn_loop_carries, dataflow_analysis_attached, routine):
        for pragma in pragma_loops:
            # Now, sort the loops enclosing each pragma from outside to inside and
            # keep only the ones relevant for fission
            loops = binary_insertion_sort(pragma_loops[pragma], lt=is_parent_of)
            collapse = int(get_pragma_parameters(pragma).get('collapse', 1))
            pragma_loops[pragma] = loops[-collapse:]

            # Attach the pragma to the list of pragmas to be processed for the
            # outermost loop
            loop_pragmas[loops[-collapse]] += [pragma]

            # Promote variables given in promotion list
            promote_vars = [var.strip().lower()
                            for var in get_pragma_parameters(pragma).get('promote', '').split(',') if var]

            # Automatically determine promotion variables
            if promote:
                promote_vars += [v.name.lower() for v in read_after_write_vars(loops[-1].body, pragma)
                                 if v.name.lower() not in promote_vars]
            promotion_vars_dims, promotion_vars_index = promotion_dimensions_from_loop_nest(
                promote_vars, pragma_loops[pragma], promotion_vars_dims, promotion_vars_index)

            # Store loop-carried dependencies for later analysis
            if warn_loop_carries:
                loop_carried_vars[pragma] = loop_carried_dependencies(pragma_loops[pragma][0])

    fission_trafo = FissionTransformer(loop_pragmas)
    routine.body = fission_trafo.visit(routine.body)
    info('%s: split %d loop(s) at %d loop-fission pragma(s).', routine.name, len(loop_pragmas), len(pragma_loops))

    # Warn about broken loop-carried dependencies
    if warn_loop_carries:
        # Dataflow information must be recomputed on the split loops
        with dataflow_analysis_attached(routine):
            for pragma, loop_carries in loop_carried_vars.items():
                loop, *remainder = fission_trafo.split_loops[pragma]
                if not remainder:
                    continue

                # The loop before the pragma has to read the variable ...
                broken_loop_carries = loop_carries & loop.uses_symbols
                # ... but it is written after the pragma
                broken_loop_carries &= OrderedSet.union(*[l.defines_symbols for l in remainder])

                if broken_loop_carries:
                    if pragma.source and pragma.source.lines:
                        line_info = f' at l. {pragma.source.lines[0]}'
                    else:
                        line_info = ''
                    warning(f'Loop-fission{line_info} potentially breaks loop-carried dependencies ' +
                            f'for variables: {", ".join(str(v) for v in broken_loop_carries)}')

    promote_nonmatching_variables(routine, promotion_vars_dims, promotion_vars_index)


class LoopUnrollTransformer(Transformer):
    """
    Transformer that unrolls loops or loop nests at
    ``!$loki loop-unroll`` pragmas.

    For loops to be unrolled, they must have literal bounds and step.
    If not, then they are simply ignored.

    This works also for nested loops with individually different unroll
    annotations. However, a child nested loop with a more restrictive depth
    will not be able to override its parent's depth.

    Parameters
    ----------
    warn_iterations_length : bool, optional
        Emit a warning when a loop with more than 32 iterations is unrolled.
    """

    def __init__(self, warn_iterations_length=True):
        self.warn_iterations_length = warn_iterations_length
        super().__init__()

    # depth is treated as an option of some depth or none, i.e. unroll all
    def visit_Loop(self, o, depth=None):
        """
        Unroll the loop :data:`o` and, up to :data:`depth` levels, its nested loops.

        Parameters
        ----------
        o : :any:`Loop`
            The loop to visit.
        depth : int, optional
            How many levels of the loop nest, including :data:`o` itself, may be
            unrolled. `None` means all nested levels.

        Returns
        -------
        tuple of :any:`Node` or :any:`Loop`
            The flattened, unrolled loop bodies if the loop has literal bounds and
            step; otherwise a copy of :data:`o` with its ``loop-unroll`` pragmas
            removed.
        """

        # If the step isn't explicitly given, then it's implicitly 1
        step = o.bounds.step if o.bounds.step is not None else IntLiteral(1)
        start, stop = o.bounds.start, o.bounds.stop

        # Consume one nesting level for this loop; nested loops are only visited
        # while at least one level remains (or if the depth is unlimited)
        depth = depth - 1 if depth is not None else None
        recurse = depth is None or depth >= 1

        # Only unroll if we have all literal bounds and step
        if is_constant(start) and is_constant(stop) and is_constant(step):

            # Concrete iteration values of the literal loop range
            unroll_range = get_pyrange(LoopRange((start, stop, step)))
            if self.warn_iterations_length and len(unroll_range) > 32:
                warning(f"Unrolling loop over 32 iterations ({len(unroll_range)}), this may take a long time & "
                        f"provide few performance benefits.")

            acc = functools.reduce(op.add,
                                   [
                                       # Create a copy of the loop body for every value of the iterator
                                       SubstituteExpressions({o.variable: sym.IntLiteral(i)}).visit(o.body)
                                       for i in unroll_range
                                   ],
                                   ())

            if recurse:
                acc = [self.visit(a, depth=depth) for a in acc]

            return as_tuple(flatten(acc))

        _pragma = tuple(
            p for p in o.pragma if not is_loki_pragma(p, starts_with='loop-unroll')
        ) if o.pragma else None
        _pragma_post = tuple(
            p for p in o.pragma_post if not is_loki_pragma(p, starts_with='loop-unroll')
        ) if o.pragma_post else None

        # Apply the same depth limit as in the unrolled case above; otherwise nested
        # loops below a non-unrollable loop would be unrolled past the requested depth
        body = self.visit(o.body, depth=depth) if recurse else o.body

        # clone() keeps every other attribute of the original loop; the former explicit
        # Loop(...) construction only carried over variable, bounds, body and pragmas
        return o.clone(body=body, pragma=_pragma, pragma_post=_pragma_post)


def do_loop_unroll(routine, warn_iterations_length=True):
    """
    Search for ``!$loki loop-unroll`` pragmas in loops and unroll them.

    The expected pragma syntax is
    ``!$loki loop-unroll [depth(n)]``
    where ``depth(n)`` controls the unrolling of nested loops. For instance,
    ``depth(1)`` will only unroll the top most loop of a set of nested loops.
    However, a child nested loop with a more restrictive depth will not be
    able to override its parent's depth. If ``depth(n)`` is not specified,
    then all loops nested under a parent with this pragma will be unrolled.
    Only loops with literal bounds and step are unrolled; other annotated
    loops are kept, with the ``loop-unroll`` pragma removed.

    E.g. the code sample below will only unroll A and B, but not C:

    .. code-block:: fortran

        ! Loop A
        !$loki loop-unroll depth(1)
        DO a = 1, 10
            ! Loop B
            !$loki loop-unroll
            DO b = 1, 10
                ...
            END DO
            ! Loop C - will not be unrolled
            DO c = 1, 10
                ...
            END DO
        END DO

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which loop unrolling is to be applied.
    warn_iterations_length : 'Boolean', optional
        This specifies if warnings should be generated when unrolling
        loops with a large number of iterations (32). It's mainly to
        disable warnings when loops are being unrolled for internal
        transformations and analysis.
    """

    class PragmaLoopUnrollTransformer(Transformer):
        """
        Transformer that applies :any:`LoopUnrollTransformer` to every loop
        carrying a ``loop-unroll`` pragma.
        """

        def __init__(self, warn_iterations_length=True):
            self.warn_iterations_length = warn_iterations_length
            super().__init__()

        def visit_Loop(self, o, *args, **kwargs):
            # Loops without the pragma are traversed like any other node
            if not is_loki_pragma(o.pragma, starts_with='loop-unroll'):
                return super().visit_Node(o, *args, **kwargs)

            parameters = get_pragma_parameters(o.pragma, starts_with='loop-unroll')

            # Get the depth (None means: unroll all nested loops)
            param = parameters.get('depth', None)
            depth = int(param) if param is not None else None

            # Unroll, then recurse into the result to handle further pragmas
            unrolled_loop = LoopUnrollTransformer(self.warn_iterations_length).visit(o, depth=depth)

            if isinstance(unrolled_loop, Loop):
                # The loop could not be unrolled (non-literal bounds) and was returned as a
                # Loop with its loop-unroll pragma stripped: continue traversing it. An explicit
                # type check replaces the former ``except TypeError``, which also swallowed
                # TypeErrors raised by the recursive visits.
                return self.visit(unrolled_loop, *args, **kwargs)

            # Otherwise we got a tuple of unrolled body nodes
            return as_tuple(flatten([self.visit(a) for a in as_tuple(flatten(unrolled_loop))]))

    with pragmas_attached(routine, Loop):
        routine.body = PragmaLoopUnrollTransformer(warn_iterations_length=warn_iterations_length).visit(routine.body)


class TransformLoopsTransformation(Transformation):
    """
    :any:`Transformation` bundling Loki's pragma-driven loop transformations so
    that they can be run as a single step of a :any:`Scheduler` pipeline.

    Each enabled utility is applied to the routine in this fixed order:

    * :any:`do_loop_interchange`
    * :any:`do_loop_fusion`
    * :any:`do_loop_fission`
    * :any:`do_loop_unroll`

    Parameters
    ----------
    loop_interchange : bool
        Whether to call ``do_loop_interchange``. Default: ``False``.
    loop_fusion : bool
        Whether to call ``do_loop_fusion``. Default: ``False``.
    loop_fission : bool
        Whether to call ``do_loop_fission``. Default: ``False``.
    loop_unroll : bool
        Whether to call ``do_loop_unroll``. Default: ``False``.
    interchange_project_bounds : bool
        Forwarded as ``project_bounds`` to ``do_loop_interchange``, i.e.
        project loop bounds during interchange. Default: ``False``.
    fission_promote : bool
        Forwarded as ``promote`` to ``do_loop_fission``: automatically detect
        read-after-write across fission points and promote the affected
        variables (variables listed in the pragma's ``promote`` option are
        promoted regardless). Default: ``True``.
    fission_warn_loop_carries : bool
        Forwarded as ``warn_loop_carries`` to ``do_loop_fission``: warn when a
        fission point sits after the initial read and before the final write
        of a loop-carried dependency. Default: ``True``.
    unroll_warn_iterations_length : bool
        Forwarded as ``warn_iterations_length`` to ``do_loop_unroll``: warn
        when unrolling loops with many (more than 32) iterations; mainly
        useful to silence warnings for internal transformations and analysis.
        Default: ``True``.
    """

    def __init__(
            self, loop_interchange=False, loop_fusion=False, loop_fission=False,
            loop_unroll=False, interchange_project_bounds=False, fission_promote=True,
            fission_warn_loop_carries=True, unroll_warn_iterations_length=True
    ):
        # Switches selecting which loop transformations are run
        self.loop_interchange = loop_interchange
        self.loop_fusion = loop_fusion
        self.loop_fission = loop_fission
        self.loop_unroll = loop_unroll

        # Options forwarded to the individual utilities
        self.interchange_project_bounds = interchange_project_bounds
        self.fission_promote = fission_promote
        self.fission_warn_loop_carries = fission_warn_loop_carries
        self.unroll_warn_iterations_length = unroll_warn_iterations_length

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the enabled loop transformations to :data:`routine`, in place.
        """
        # (enabled, action) pairs, listed in the order the steps must be applied
        pipeline = (
            (self.loop_interchange,
             lambda: do_loop_interchange(routine, project_bounds=self.interchange_project_bounds)),
            (self.loop_fusion,
             lambda: do_loop_fusion(routine)),
            (self.loop_fission,
             lambda: do_loop_fission(
                 routine, promote=self.fission_promote, warn_loop_carries=self.fission_warn_loop_carries
             )),
            (self.loop_unroll,
             lambda: do_loop_unroll(routine, warn_iterations_length=self.unroll_warn_iterations_length)),
        )
        for enabled, apply_step in pipeline:
            if enabled:
                apply_step()
loki-ecmwf-0.3.6/loki/transformations/parallel/0000775000175000017500000000000015167130205021724 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/parallel/__init__.py0000664000175000017500000000122015167130205024030 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Sub-package with utilities to remove, generate and manipulate parallel
regions.
"""

from loki.transformations.parallel.block_loop import * # noqa
from loki.transformations.parallel.field_views import * # noqa
from loki.transformations.parallel.openmp_region import * # noqa
loki-ecmwf-0.3.6/loki/transformations/parallel/tests/0000775000175000017500000000000015167130205023066 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/parallel/tests/test_openmp_region.py0000664000175000017500000002561015167130205027344 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Module, Dimension
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    nodes as ir, FindNodes, pragmas_attached, pragma_regions_attached,
    is_loki_pragma
)

from loki.transformations.parallel import (
    remove_openmp_regions, add_openmp_regions,
    remove_firstprivate_copies, add_firstprivate_copies
)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('insert_loki_parallel', (True, False))
def test_remove_openmp_regions(frontend, insert_loki_parallel):
    """
    Check that :any:`remove_openmp_regions` strips OpenMP pragmas and,
    when requested, replaces each outer ``!$omp parallel`` region with a
    ``!$loki parallel`` region while leaving other Loki pragmas intact.
    """
    fcode = """
subroutine test_driver_openmp(n, arr)
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: arr(n)
  integer :: i

  !$omp parallel private(i)
  !$omp do schedule dynamic(1)
  do i=1, n
    !$loki foo-bar
    arr(i) = arr(i) + 1.0
  end do
  !$omp end do
  !$omp end parallel


  !$OMP PARALLEL PRIVATE(i)
  !$OMP DO SCHEDULE DYNAMIC(1)
  do i=1, n
    !$loki foo-baz
    arr(i) = arr(i) + 1.0
    !$loki end foo-baz
  end do
  !$OMP END DO
  !$OMP END PARALLEL


  !$omp parallel do private(i)
  do i=1, n
    !$omp simd
    arr(i) = arr(i) + 1.0
  end do
  !$omp end parallel
end subroutine test_driver_openmp
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    num_loops = len(FindNodes(ir.Loop).visit(routine.body))
    num_pragmas = len(FindNodes(ir.Pragma).visit(routine.body))
    assert num_loops == 3
    assert num_pragmas == 14

    with pragma_regions_attached(routine):
        # Loop pragmas are not attached here, so every pragma pair forms a region
        num_regions = len(FindNodes(ir.PragmaRegion).visit(routine.body))
        assert num_regions == 6

    remove_openmp_regions(routine, insert_loki_parallel=insert_loki_parallel)

    # Loops survive untouched; only the pragma count changes
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 3
    remaining = FindNodes(ir.Pragma).visit(routine.body)
    expected_count = 9 if insert_loki_parallel else 3
    assert len(remaining) == expected_count

    if not insert_loki_parallel:
        return

    # Expected (opening, closing) Loki pragma content for each region, in order
    expected_regions = (
        ('parallel', 'end parallel'),
        ('parallel', 'end parallel'),
        ('foo-baz', 'end foo-baz'),
        ('parallel', 'end parallel'),
    )
    with pragma_regions_attached(routine):
        regions = FindNodes(ir.PragmaRegion).visit(routine.body)
        assert len(regions) == len(expected_regions)
        for region, (opening, closing) in zip(regions, expected_regions):
            assert is_loki_pragma(region.pragma, starts_with=opening)
            assert is_loki_pragma(region.pragma_post, starts_with=closing)


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI has trouble mixing Loop and Section pragmas')]
))
def test_add_openmp_regions(tmp_path, frontend):
    """
    A simple test for :any:`add_openmp_regions`.

    Only the ``!$loki parallel`` region is converted to an OpenMP parallel
    region with a dynamically scheduled ``DO`` loop; the
    ``!$loki not-so-parallel`` region must be left untouched.
    """
    fcode_type = """
module geom_mod
  type geom_type
    integer :: nproma, ngptot
  end type geom_type
end module geom_mod
"""

    fcode = """
subroutine test_add_openmp_loop(ydgeom, ydfields, arr)
  use geom_mod, only: geom_type, fld_type
  use kernel_mod, only: my_kernel, my_non_kernel
  implicit none
  type(geom_type), intent(in) :: ydgeom
  type(fld_type), intent(inout) :: ydfields
  type(fld_type), intent(inout) :: ylfields
  real(kind=8), intent(inout) :: arr(:,:,:)
  integer :: JKGLO, IBL, ICEND

  !$loki parallel

  ylfields = ydfields

  DO JKGLO=1,YDGEOM%NGPTOT,YDGEOM%NPROMA
    ICEND = MIN(YDGEOM%NPROMA, YDGEOM%NGPTOT - JKGLO + 1)
    IBL = (JKGLO - 1) / YDGEOM%NPROMA + 1

    CALL YDFIELDS%UPDATE_STUFF()

    CALL MY_KERNEL(ARR(:,:,IBL))
  END DO

  !$loki end parallel

  !$loki not-so-parallel

  DO JKGLO=1,YDGEOM%NGPTOT,YDGEOM%NPROMA
    call my_non_kernel(arr(1,1,1))
  END DO

  !$loki end not-so-parallel

end subroutine test_add_openmp_loop
"""
    _ = Module.from_source(fcode_type, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    assert len(FindNodes(ir.Pragma).visit(routine.body)) == 4
    with pragma_regions_attached(routine):
        regions = FindNodes(ir.PragmaRegion).visit(routine.body)
        assert len(regions) == 2
        assert is_loki_pragma(regions[0].pragma, starts_with='parallel')
        assert is_loki_pragma(regions[0].pragma_post, starts_with='end parallel')
        assert is_loki_pragma(regions[1].pragma, starts_with='not-so-parallel')
        assert is_loki_pragma(regions[1].pragma_post, starts_with='end not-so-parallel')

    block_dim = Dimension(index='JKGLO', size='YDGEOM%NGPBLK')
    add_openmp_regions(
        routine, dimension=block_dim,
        field_group_types=('fld_type',),
        shared_variables=('ydfields',)
    )

    # Ensure pragmas have been inserted: the first four are the new OpenMP
    # pragmas, the remaining two are the untouched `not-so-parallel` pair
    pragmas = FindNodes(ir.Pragma).visit(routine.body)
    assert len(pragmas) == 6
    assert all(p.keyword == 'OMP' for p in pragmas[0:4])
    assert all(p.keyword == 'loki' for p in pragmas[4:6])

    with pragmas_attached(routine, node_type=ir.Loop):
        with pragma_regions_attached(routine):
            # Ensure pragma region has been created
            regions = FindNodes(ir.PragmaRegion).visit(routine.body)
            assert len(regions) == 2
            assert regions[0].pragma.keyword == 'OMP'
            assert regions[0].pragma.content.startswith('PARALLEL')
            assert regions[0].pragma_post.keyword == 'OMP'
            assert regions[0].pragma_post.content == 'END PARALLEL'
            assert is_loki_pragma(regions[1].pragma, starts_with='not-so-parallel')
            assert is_loki_pragma(regions[1].pragma_post, starts_with='end not-so-parallel')

            # Ensure shared, private and firstprivate have been set right
            assert 'PARALLEL DEFAULT(SHARED)' in regions[0].pragma.content
            assert 'PRIVATE(JKGLO, IBL, ICEND)' in regions[0].pragma.content
            assert 'FIRSTPRIVATE(ylfields)' in regions[0].pragma.content

            # Ensure only the loop inside the parallel region has been annotated
            loops = FindNodes(ir.Loop).visit(routine.body)
            assert len(loops) == 2
            assert loops[0].pragma[0].keyword == 'OMP'
            assert loops[0].pragma[0].content == 'DO SCHEDULE(DYNAMIC,1)'
            assert not loops[1].pragma


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI needs full type definitions for derived types')]
))
def test_remove_firstprivate_copies(frontend):
    """
    Check that :any:`remove_firstprivate_copies` drops the explicit
    ``ydstate = state`` copy and rewrites uses of ``ydstate`` in the
    region to refer to ``state`` instead.
    """
    fcode = """
subroutine test_add_openmp_loop(ydgeom, state, arr)
  use geom_mod, only: geom_type
  use type_mod, only: state_type, flux_type, NewFlux
  implicit none
  type(geom_type), intent(in) :: ydgeom
  real(kind=8), intent(inout) :: arr(:,:,:)
  type(state_type), intent(in) :: state
  type(state_type) :: ydstate
  type(flux_type) :: ydflux
  integer :: jkglo, ibl, icend

  !$loki parallel

  ydstate = state

  ydflux = NewFlux()

  do jkglo=1,ydgeom%ngptot,ydgeom%nproma
    icend = min(ydgeom%nproma, ydgeom%ngptot - jkglo + 1)
    ibl = (jkglo - 1) / ydgeom%nproma + 1

    call ydstate%update_view(ibl)

    call my_kernel(ydstate%u(:,:), arr(:,:,ibl))
  end do

  !$loki end parallel
end subroutine test_add_openmp_loop
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _find(node_type):
        # Always look at the current body, which is replaced further down
        return FindNodes(node_type).visit(routine.body)

    # Before: the explicit copy exists and `ydstate` is used in the calls
    assignments = _find(ir.Assignment)
    assert len(assignments) == 4
    assert assignments[0].lhs == 'ydstate'
    assert assignments[0].rhs == 'state'
    call_stmts = _find(ir.CallStatement)
    assert len(call_stmts) == 2
    assert str(call_stmts[0].name).startswith('ydstate%')
    assert call_stmts[1].arguments[0].parent == 'ydstate'
    assert len(_find(ir.Loop)) == 1

    # Remove the explicit copy of `ydstate = state` and adjust symbols
    routine.body = remove_firstprivate_copies(
        region=routine.body, fprivate_map={'ydstate' : 'state'}, scope=routine
    )

    # After: the copy is gone and every use refers to `state`
    assignments = _find(ir.Assignment)
    expected_lhs = ('ydflux', 'icend', 'ibl')
    assert len(assignments) == len(expected_lhs)
    for assignment, lhs in zip(assignments, expected_lhs):
        assert assignment.lhs == lhs
    call_stmts = _find(ir.CallStatement)
    assert len(call_stmts) == 2
    assert str(call_stmts[0].name).startswith('state%')
    assert call_stmts[1].arguments[0].parent == 'state'
    assert len(_find(ir.Loop)) == 1


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI needs full type definitions for derived types')]
))
def test_add_firstprivate_copies(frontend):
    """
    Check that :any:`add_firstprivate_copies` re-inserts the explicit
    ``ydstate = state`` copy and redirects uses of ``state`` to the
    local copy ``ydstate``.
    """

    fcode = """
subroutine test_add_openmp_loop(ydgeom, state, arr)
  use geom_mod, only: geom_type
  implicit none
  type(geom_type), intent(in) :: ydgeom
  real(kind=8), intent(inout) :: arr(:,:,:)
  type(state_type), intent(in) :: state
  integer :: jkglo, ibl, icend

  !$loki parallel

  do jkglo=1,ydgeom%ngptot,ydgeom%nproma
    icend = min(ydgeom%nproma, ydgeom%ngptot - jkglo + 1)
    ibl = (jkglo - 1) / ydgeom%nproma + 1

    call state%update_view(ibl)

    call my_kernel(state%u(:,:), arr(:,:,ibl))
  end do

  !$loki end parallel

  !$loki not-so-parallel

  do jkglo=1,ydgeom%ngptot,ydgeom%nproma
    icend = min(ydgeom%nproma, ydgeom%ngptot - jkglo + 1)
    ibl = (jkglo - 1) / ydgeom%nproma + 1
  end do

  !$loki end not-so-parallel
end subroutine test_add_openmp_loop
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _find(node_type):
        return FindNodes(node_type).visit(routine.body)

    # Before: no copy exists and `state` is used directly
    assignments = _find(ir.Assignment)
    assert len(assignments) == 4
    assert assignments[0].lhs == 'icend'
    assert assignments[1].lhs == 'ibl'
    call_stmts = _find(ir.CallStatement)
    assert len(call_stmts) == 2
    assert str(call_stmts[0].name).startswith('state%')
    assert call_stmts[1].arguments[0].parent == 'state'
    assert len(_find(ir.Loop)) == 2

    # Put the explicit firstprivate copies back in
    add_firstprivate_copies(routine=routine, fprivate_map={'ydstate' : 'state'})

    # After: the copy heads the body and the calls use `ydstate`
    assignments = _find(ir.Assignment)
    assert len(assignments) == 5
    assert assignments[0].lhs == 'ydstate'
    assert assignments[0].rhs == 'state'
    call_stmts = _find(ir.CallStatement)
    assert len(call_stmts) == 2
    assert str(call_stmts[0].name).startswith('ydstate%')
    assert call_stmts[1].arguments[0].parent == 'ydstate'
    assert len(_find(ir.Loop)) == 2
loki-ecmwf-0.3.6/loki/transformations/parallel/tests/__init__.py0000664000175000017500000000057015167130205025201 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/parallel/tests/test_field_views.py0000664000175000017500000003127015167130205027002 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Dimension
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes
from loki.expression import symbols as sym
from loki.types import BasicType, SymbolAttributes, Scope
from loki.logging import WARNING

from loki.transformations.field_api import (
    get_field_type, field_get_device_data, field_get_host_data, field_sync_host, field_sync_device,
    FieldAPITransferType, field_create_device_data, field_delete_device_data,
    FieldAPIAccessorType
)
from loki.transformations.parallel import (
    remove_field_api_view_updates, add_field_api_view_updates
)


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI needs full type definitions for derived types')]
))
def test_field_api_remove_view_updates(caplog, frontend):
    """
    Check that :any:`remove_field_api_view_updates` removes dimension and
    view-update calls (including the guarded loop over ``FLUXES``). It
    must also warn about an LHS field-group assignment and about removing
    an update on a type that is not in ``field_group_types``.
    """

    fcode = """
subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes, ricks_fields)
  use type_module, only: dimension_type, state_type, aux_type, flux_type, ricks_type
  implicit none
  integer(kind=4), intent(in) :: ngptot, nproma, nflux
  type(dimension_type), intent(inout) :: dims
  type(STATE_TYPE), intent(inout) :: state
  type(aux_type), intent(inout) :: aux_fields
  type(FLUX_type), intent(inout) :: fluxes(nflux)
  type(ricks_type), intent(inout) :: ricks_fields

  integer :: JKGLO, IBL, ICEND, JK, JL, JF

  DO jkglo=1, ngptot, nproma
    icend = min(nproma, ngptot - JKGLO + 1)
    ibl = (jkglo - 1) / nproma + 1

    STATE = STATE%CLONE()

    CALL DIMS%UPDATE(IBL, ICEND, JKGLO)
    CALL STATE%update_VIEW(IBL)
    CALL AUX_FIELDS%UPDATE_VIEW(block_index=IBL)
    IF (NFLUX > 0) THEN
      DO jf=1, nflux
        CALL FLUXES(JF)%UPDATE_VIEW(IBL)
      END DO
    END IF
    CALL RICKS_FIELDS%UPDATE_VIEW(IBL)

    CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES(1)%FOO, FLUXES(2)%BAR)
  END DO
end subroutine test_remove_block_loop
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _count(node_type):
        return len(FindNodes(node_type).visit(routine.body))

    assert _count(ir.CallStatement) == 6
    assert _count(ir.Conditional) == 1
    assert _count(ir.Loop) == 2

    expected_warnings = (
        '[Loki::ControlFlow] Found LHS field group assign: Assignment:: STATE = STATE%CLONE()',
        '[Loki::ControlFlow] Removing RICKS_FIELDS%UPDATE_VIEW call, but not in field group types!',
    )
    with caplog.at_level(WARNING):
        remove_field_api_view_updates(
            routine, field_group_types=['state_type', 'aux_type', 'flux_type'],
            dim_object='DIMS'
        )

        assert len(caplog.records) == len(expected_warnings)
        for record, warning in zip(caplog.records, expected_warnings):
            assert warning in record.message

    # Only the kernel call remains, and the guarded flux loop is gone
    remaining_calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(remaining_calls) == 1
    assert remaining_calls[0].name == 'MY_KERNEL'

    assert _count(ir.Conditional) == 0
    remaining_loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(remaining_loops) == 1
    assert remaining_loops[0].variable == 'jkglo'


@pytest.mark.parametrize('frontend', available_frontends(
    skip=[(OMNI, 'OMNI needs full type definitions for derived types')]
))
def test_field_api_add_view_updates(frontend):
    """
    A simple test for :any:`add_field_api_view_updates`.

    The dimension-object update and one ``UPDATE_VIEW`` call per field
    group argument must be inserted into the block loop, ahead of the
    original kernel call.
    """

    fcode = """
subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes)
  implicit none
  integer(kind=4), intent(in) :: ngptot, nproma, nflux
  type(dimension_type), intent(inout) :: dims
  type(state_type), intent(inout) :: state
  type(aux_type), intent(inout) :: aux_fields
  type(flux_type), intent(inout) :: fluxes

  integer :: JKGLO, IBL, ICEND, JK, JL, JF

  DO jkglo=1, ngptot, nproma
    icend = min(nproma, ngptot - jkglo + 1)
    ibl = (jkglo - 1) / nproma + 1

    CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES%FOO, FLUXES%BAR)
  END DO
end subroutine test_remove_block_loop
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1
    assert len(FindNodes(ir.Loop).visit(routine.body)) == 1

    block = Dimension(
        index=('jkglo', 'ibl'), step='NPROMA',
        lower=('1', 'ICST'), upper=('NGPTOT', 'ICEND')
    )
    field_group_types = ['state_type', 'aux_type', 'flux_type']
    add_field_api_view_updates(
        routine, dimension=block, field_group_types=field_group_types,
        dim_object='DIMS'
    )

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 5
    assert calls[0].name == 'DIMS%UPDATE' and calls[0].arguments == ('IBL', 'ICEND', 'JKGLO')
    assert calls[1].name == 'AUX_FIELDS%UPDATE_VIEW' and calls[1].arguments == ('IBL',)
    assert calls[2].name == 'FLUXES%UPDATE_VIEW' and calls[2].arguments == ('IBL',)
    assert calls[3].name == 'STATE%UPDATE_VIEW' and calls[3].arguments == ('IBL',)
    # The original kernel call must still be present, after all view updates
    assert calls[4].name == 'MY_KERNEL'

    assert len(FindNodes(ir.Loop).visit(routine.body)) == 1


def test_get_field_type():
    """
    Test that :any:`get_field_type` maps arrays of rank 1 to 3 with a
    given kind parameter to the matching ``field_<rank><suffix>`` derived
    type, independent of the case of the kind name.
    """
    type_map = ["jprb",
                "jpit",
                "jpis",
                "jpim",
                "jpib",
                "jpia",
                "jprt",
                "jprs",
                "jprm",
                "jprd",
                "jplm"]
    field_types = [
                    "field_1rb", "field_2rb", "field_3rb",
                    "field_1it", "field_2it", "field_3it",
                    "field_1is", "field_2is", "field_3is",
                    "field_1im", "field_2im", "field_3im",
                    "field_1ib", "field_2ib", "field_3ib",
                    "field_1ia", "field_2ia", "field_3ia",
                    "field_1rt", "field_2rt", "field_3rt",
                    "field_1rs", "field_2rs", "field_3rs",
                    "field_1rm", "field_2rm", "field_3rm",
                    "field_1rd", "field_2rd", "field_3rd",
                    "field_1lm", "field_2lm", "field_3lm",
                  ]

    def generate_fields(types):
        """Return the field type for a rank-1, -2 and -3 array of each kind."""
        generated = []
        for type_name in types:
            for dim in range(1, 4):
                shape = tuple(None for _ in range(dim))
                array = sym.Variable(name='test_array',
                                     type=SymbolAttributes(BasicType.REAL,
                                                           shape=shape,
                                                           kind=sym.Variable(name=type_name)))
                generated.append(get_field_type(array))
        return generated

    for kind_names in (type_map, [t.upper() for t in type_map]):
        generated = generate_fields(kind_names)
        # zip() silently truncates, so make sure every expected type is checked
        assert len(generated) == len(field_types)
        for field, field_name in zip(generated, field_types):
            assert isinstance(field, sym.DerivedType) and field.name == field_name


@pytest.mark.parametrize('transfer_type', list(FieldAPITransferType))
@pytest.mark.parametrize('accessor_type', list(FieldAPIAccessorType))
def test_field_api_call_namegen(transfer_type, accessor_type):
    """
    Test the correct generation of FIELD_API calls.
    """

    mode_str = {
        'READ_ONLY': 'rdonly',
        'READ_WRITE': 'rdwr',
        'WRITE_ONLY': 'wronly',
        'FORCE': 'force'
    }

    def _check_get_call(call, target, mode, accessor_type):
        # Type-bound accessors are methods on the field; others are free SGET_* routines
        expected = (
            f'field%get_{target}_data_{mode_str[mode.name]}'
            if accessor_type == FieldAPIAccessorType.TYPE_BOUND
            else f'sget_{target}_data_{mode_str[mode.name]}'
        )
        assert call.name.name.lower() == expected

    def _check_sync_call(call, target, mode):
        assert call.name.name.lower() == f'field%sync_{target}_{mode_str[mode.name]}'

    routine = Subroutine('routine')
    field_object = sym.Variable(name='field')
    access_ptr = sym.Variable(name='field_access_ptr')
    # Host-side transfers have no write-only variant
    host_allowed = transfer_type != FieldAPITransferType.WRITE_ONLY

    # Device/host data accessor calls
    _check_get_call(
        field_get_device_data(field_object, access_ptr, transfer_type, scope=routine,
                              accessor_type=accessor_type),
        'device', transfer_type, accessor_type
    )
    if host_allowed:
        _check_get_call(
            field_get_host_data(field_object, access_ptr, transfer_type, scope=routine,
                                accessor_type=accessor_type),
            'host', transfer_type, accessor_type
        )

    # Device/host synchronisation calls
    _check_sync_call(field_sync_device(field_object, transfer_type, scope=routine),
                     'device', transfer_type)
    if host_allowed:
        _check_sync_call(field_sync_host(field_object, transfer_type, scope=routine),
                         'host', transfer_type)

    # Device allocation and deallocation calls
    create_call = field_create_device_data(field_object, scope=routine)
    assert create_call.name.name.lower() == 'field%create_device_data'

    delete_call = field_delete_device_data(field_object, scope=routine)
    assert delete_call.name.name.lower() == 'field%delete_device_data'


@pytest.mark.parametrize('accessor_type', list(FieldAPIAccessorType))
@pytest.mark.parametrize("field_get_fn", [field_get_device_data, field_get_host_data])
def test_field_get_data(field_get_fn, accessor_type):
    """
    Test argument validation and call generation of the FIELD_API
    ``get_*_data`` helpers for every :any:`FieldAPITransferType`.
    """
    scope = Scope()
    fptr = sym.Variable(name='fptr_var')
    dev_ptr = sym.Variable(name='data_var')
    queue = sym.IntLiteral(1)
    blk_bounds = sym.Array(name='blk_bounds', scope=scope, dimensions=(None, None))

    for fttype in FieldAPITransferType:
        if fttype == FieldAPITransferType.FORCE:
            get_call = field_get_fn(fptr, dev_ptr, fttype, scope, queue, blk_bounds,
                    accessor_type=accessor_type)
            assert isinstance(get_call, ir.CallStatement)
            if accessor_type == FieldAPIAccessorType.TYPE_BOUND:
                # Type-bound accessors are called as methods of the field object
                assert get_call.name.parent == fptr
            else:
                # Other accessors use the free SGET_* routines
                assert str(get_call.name).lower().startswith('sget_')
        else:
            # queue and blk_bounds can only be used for FORCE transfers
            with pytest.raises(ValueError):
                _ = field_get_fn(fptr, dev_ptr, fttype, scope, queue, blk_bounds,
                        accessor_type=accessor_type)
            # field_get_host_data has no write-only option
            if fttype == FieldAPITransferType.WRITE_ONLY and field_get_fn.__name__ == 'field_get_host_data':
                with pytest.raises(TypeError):
                    _ = field_get_fn(fptr, dev_ptr, fttype, scope)
            else:
                get_call = field_get_fn(fptr, dev_ptr, fttype, scope)
                assert isinstance(get_call, ir.CallStatement)
                assert get_call.name.parent == fptr

    with pytest.raises(TypeError):
        _ = field_get_fn(fptr, dev_ptr, "none_transfer_type", scope)


@pytest.mark.parametrize("field_sync_fn", [field_sync_device, field_sync_host])
def test_field_sync(field_sync_fn):
    """
    Test argument validation and call generation of the FIELD_API
    ``sync_*`` helpers for every :any:`FieldAPITransferType`.
    """
    scope = Scope()
    fptr = sym.Variable(name='fptr_var')
    queue = sym.IntLiteral(1)
    blk_bounds = sym.Array(name='blk_bounds', scope=scope, dimensions=(None, None))

    for fttype in FieldAPITransferType:
        is_force = fttype == FieldAPITransferType.FORCE

        if not is_force:
            # Only FORCE transfers accept a queue and block bounds
            with pytest.raises(ValueError):
                _ = field_sync_fn(fptr, fttype, scope, queue, blk_bounds)

        if is_force:
            call = field_sync_fn(fptr, fttype, scope, queue, blk_bounds)
        elif fttype == FieldAPITransferType.WRITE_ONLY and field_sync_fn.__name__ == 'field_sync_host':
            # Host synchronisation has no write-only variant
            with pytest.raises(TypeError):
                _ = field_sync_fn(fptr, fttype, scope)
            continue
        else:
            call = field_sync_fn(fptr, fttype, scope)

        assert isinstance(call, ir.CallStatement)
        assert call.name.parent == fptr

    with pytest.raises(TypeError):
        _ = field_sync_fn(fptr, "none_transfer_type", scope)
loki-ecmwf-0.3.6/loki/transformations/parallel/tests/test_block_loop.py0000664000175000017500000001210515167130205026621 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Module, Dimension
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes
from loki.tools import flatten

from loki.transformations.parallel import (
    remove_block_loops, add_block_loops
)


@pytest.mark.parametrize('frontend', available_frontends())
def test_remove_block_loops(tmp_path, frontend):
    """
    Check that :any:`remove_block_loops` strips both outer ``JKGLO``
    block loops, together with their ``ICEND``/``IBL`` index
    computations, while leaving the inner compute loops intact.
    """
    fcode_type = """
module geom_mod
  type geom_type
    integer :: nproma, ngptot
  end type geom_type
end module geom_mod
"""

    fcode = """
subroutine test_remove_block_loop(ydgeom, npoints, nlev, arr)
  use geom_mod, only: geom_type
  implicit none
  type(geom_type), intent(in) :: ydgeom
  integer(kind=4), intent(in) :: npoints, nlev
  real(kind=8), intent(inout) :: arr(:,:,:)
  integer :: JKGLO, IBL, ICEND, JK, JL

  DO JKGLO=1,YDGEOM%NGPTOT,YDGEOM%NPROMA
    ICEND = MIN(YDGEOM%NPROMA, YDGEOM%NGPTOT - JKGLO + 1)
    IBL = (JKGLO - 1) / YDGEOM%NPROMA + 1

    CALL MY_KERNEL(ARR(:,:,IBL))
  END DO


  DO JKGLO=1,YDGEOM%NGPTOT,YDGEOM%NPROMA
    ICEND = MIN(YDGEOM%NPROMA, YDGEOM%NGPTOT - JKGLO + 1)
    IBL = (JKGLO - 1) / YDGEOM%NPROMA + 1

    DO JK=1, NLEV
      DO JL=1, NPOINTS
        ARR(JL, JK, IBL) = 42.0
      END DO
    END DO
  END DO
end subroutine test_remove_block_loop
"""
    _ = Module.from_source(fcode_type, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Before: two block loops, the second one nesting the compute loops
    loops_before = FindNodes(ir.Loop).visit(routine.body)
    expected_before = ('jkglo', 'jkglo', 'jk', 'jl')
    assert len(loops_before) == len(expected_before)
    for loop, var_name in zip(loops_before, expected_before):
        assert loop.variable == var_name
    assert len(FindNodes(ir.Assignment).visit(loops_before[0].body)) == 2
    assert len(FindNodes(ir.Assignment).visit(loops_before[1].body)) == 3

    block = Dimension(
        'block', index=('jkglo', 'ibl'), step='YDGEOM%NPROMA',
        lower=('1', 'ICST'), upper=('YDGEOM%NGPTOT', 'ICEND')
    )
    remove_block_loops(routine, dimension=block)

    # After: only the compute loops and the innermost assignment remain
    loops_after = FindNodes(ir.Loop).visit(routine.body)
    expected_after = ('jk', 'jl')
    assert len(loops_after) == len(expected_after)
    for loop, var_name in zip(loops_after, expected_after):
        assert loop.variable == var_name
    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1


@pytest.mark.parametrize('frontend', available_frontends())
def test_add_block_loops(frontend):
    """
    A simple test for :any:`add_block_loops`.

    Only the two ``!$loki parallel`` regions get wrapped in a ``JKGLO``
    block loop, including the ``ICEND``/``IBL`` index computations. The
    ``!$omp parallel`` region must be left untouched, and the missing
    block-loop variables must be declared.
    """
    fcode = """
subroutine test_add_block_loop(ydgeom, npoints, nlev, arr)
  integer(kind=4), intent(in) :: npoints, nlev
  real(kind=8), intent(inout) :: arr(:,:,:)
  integer :: JKGLO

!$loki parallel
  call my_kernel(arr(:,:,ibl))
!$loki end parallel

!$loki parallel
  do jk=1, nlev
    do jl=1, npoints
      arr(jl, jk, ibl) = 42.0
    end do
  end do
!$loki end parallel

!$omp parallel
  do jk=1, nlev
    do jl=1, npoints
      arr(jl, jk, ibl) = 42.0
    end do
  end do
!$omp end parallel

end subroutine test_add_block_loop
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 4
    assert loops[0].variable == 'jk'
    assert loops[1].variable == 'jl'
    assert len(FindNodes(ir.Assignment).visit(routine.body)) == 2

    blocking = Dimension(
        name='block', index=('JKGLO', 'IBL'),
        lower=('1', 'ICST'),
        upper=('YDGEOM%NGPTOT', 'ICEND'),
        step='YDGEOM%NPROMA',
        size='YDGEOM%NGPBLKS',
    )

    add_block_loops(routine, dimension=blocking)

    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 6
    assert loops[0].variable == 'jkglo'
    assert loops[1].variable == 'jkglo'
    assert loops[2].variable == 'jk'
    assert loops[3].variable == 'jl'
    assert loops[4].variable == 'jk'
    assert loops[5].variable == 'jl'

    assigns1 = FindNodes(ir.Assignment).visit(loops[0].body)
    assert len(assigns1) == 2
    assert assigns1[0].lhs == 'ICEND'
    assert str(assigns1[0].rhs) == 'MIN(YDGEOM%NPROMA, YDGEOM%NGPTOT - JKGLO + 1)'
    assert assigns1[1].lhs == 'IBL'
    assert assigns1[1].rhs == '(JKGLO - 1) / YDGEOM%NPROMA + 1'

    assigns2 = FindNodes(ir.Assignment).visit(loops[1].body)
    assert len(assigns2) == 3
    assert assigns2[0].lhs == 'ICEND'
    assert str(assigns2[0].rhs) == 'MIN(YDGEOM%NPROMA, YDGEOM%NGPTOT - JKGLO + 1)'
    assert assigns2[1].lhs == 'IBL'
    assert assigns2[1].rhs == '(JKGLO - 1) / YDGEOM%NPROMA + 1'
    assert assigns2[2].lhs == 'arr(jl, jk, ibl)'
    assert assigns2[2].rhs == 42.0

    decls = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    # Parenthesise the conditional: without it, the non-OMNI branch asserts
    # the constant ``5`` (always truthy) instead of comparing the count.
    # OMNI adds declarations for all implicitly typed variables, hence 9.
    assert len(decls) == (9 if frontend == OMNI else 5)
    decl_symbols = tuple(flatten(d.symbols for d in decls))
    for v in ['JKGLO', 'IBL', 'ICEND']:
        assert v in decl_symbols
loki-ecmwf-0.3.6/loki/transformations/parallel/openmp_region.py0000664000175000017500000002572515167130205025152 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Module with utilities to remove, add and manipulate parallel OpenMP regions.
"""

from loki.analyse import dataflow_analysis_attached
from loki.expression import symbols as sym, parse_expr
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, Transformer,
    SubstituteStringExpressions, is_loki_pragma, pragmas_attached,
    pragma_regions_attached
)
from loki.tools import dict_override, flatten
from loki.types import DerivedType


__all__ = [
    'remove_openmp_regions', 'add_openmp_regions',
    'remove_firstprivate_copies', 'add_firstprivate_copies'
]


def remove_openmp_regions(routine, insert_loki_parallel=False):
    """
    Remove any OpenMP parallel annotations (``!$omp parallel``).

    Optionally, this can replace ``!$omp parallel`` with ``!$loki
    parallel`` pragmas.

    Inside an outer ``!$omp parallel`` region all other OpenMP pragmas
    and OpenMP pragma regions are stripped as well, including those that
    are nested inside other pragma regions. The body of ``routine`` is
    replaced with the transformed body.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine from which to strip all OpenMP annotations.
    insert_loki_parallel : bool
        Flag for the optional insertion of ``!$loki parallel`` pragmas
    """

    class RemoveOpenMPRegionTransformer(Transformer):
        """
        Remove OpenMP pragmas from "parallel" regions and remove all
        contained OpenMP pragmas and pragma regions.

        Optionally replaces outer ``!$omp parallel`` region with
        ``!$loki parallel`` region.
        """

        def visit_PragmaRegion(self, region, **kwargs):
            """
            Perform the filtering and removal of OpenMP pragma regions.

            Parameters
            ----------
            active : bool
                Flag to indicate whether we're actively traversing an
                outer OpenMP region.
            """
            is_omp = region.pragma.keyword.lower() == 'omp'

            if not is_omp:
                # Keep non-OpenMP regions, but still process their bodies
                # so that nested OpenMP annotations are handled as well
                region._update(body=self.visit(region.body, **kwargs))
                return region

            if kwargs['active']:
                # Remove other OpenMP pragma regions when in active mode and
                # recurse, so that OpenMP pragmas nested inside are removed too
                body = self.visit(region.body, **kwargs)
                region._update(body=body, pragma=None, pragma_post=None)
                return region

            if 'parallel' in region.pragma.content.lower():
                # Replace or remove pragmas
                pragma = None
                pragma_post = None
                if insert_loki_parallel:
                    pragma = ir.Pragma(keyword='loki', content='parallel')
                    pragma_post = ir.Pragma(keyword='loki', content='end parallel')

                with dict_override(kwargs, {'active': True}):
                    body = self.visit(region.body, **kwargs)

                region._update(body=body, pragma=pragma, pragma_post=pragma_post)

            return region

        def visit_Pragma(self, pragma, **kwargs):
            """ Remove other OpenMP pragmas if in active region """

            if kwargs['active'] and pragma.keyword.lower() == 'omp':
                return None

            return pragma

    with pragma_regions_attached(routine):
        routine.body = RemoveOpenMPRegionTransformer().visit(routine.body, active=False)


def add_openmp_regions(
        routine, dimension, shared_variables=None, field_group_types=None
):
    """
    Add the OpenMP directives for a parallel driver region with an
    outer block loop.

    Every ``!$loki parallel`` pragma region of ``routine`` is annotated
    with ``!$OMP PARALLEL DEFAULT(SHARED) PRIVATE(...) FIRSTPRIVATE(...)``
    and ``!$OMP END PARALLEL``. Every loop over ``dimension.index`` inside
    such a region receives ``!$OMP DO SCHEDULE(DYNAMIC,1)`` and
    ``!$OMP END DO``. Pragma regions that are not ``!$loki parallel``
    regions are skipped.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine to which to add OpenMP parallel regions.
    dimension : :any:`Dimension`
        The dimension object describing the block loop variables.
    shared_variables : tuple of str
        Names of variables that should neither be private nor firstprivate
    field_group_types : tuple of str
        Names of types designating "field groups", which should be
        treated as firstprivate
    """
    shared_variables = shared_variables or {}
    field_group_types = field_group_types or {}

    # First get local variables and separate scalars and arrays
    routine_arguments = routine.arguments
    local_variables = tuple(
        v for v in routine.variables if v not in routine_arguments
    )
    local_scalars = tuple(
        v for v in local_variables if isinstance(v, sym.Scalar)
    )
    # Filter arrays by block-dim size, as these are global
    local_arrays = tuple(
        v for v in local_variables
        if isinstance(v, sym.Array) and not v.dimensions[-1] == dimension.size
    )

    with pragma_regions_attached(routine):
        with dataflow_analysis_attached(routine):
            for region in FindNodes(ir.PragmaRegion).visit(routine.body):
                if not is_loki_pragma(region.pragma, starts_with='parallel'):
                    # Skip unrelated pragma regions but keep processing the
                    # remaining ones; returning here would silently ignore
                    # any parallel region that follows
                    continue

                # Accumulate the set of locally used symbols and chase parents
                symbols = tuple(region.uses_symbols | region.defines_symbols)
                symbols = tuple(dict.fromkeys(flatten(
                    s.parents if s.parent else s for s in symbols
                )))

                # Start with loop variables and add local scalars and arrays
                local_vars = tuple(dict.fromkeys(flatten(
                    loop.variable for loop in FindNodes(ir.Loop).visit(region.body)
                )))

                local_vars += tuple(v for v in local_scalars if v.name in symbols)
                local_vars += tuple(v for v in local_arrays if v.name in symbols)

                # Also add used symbols that might be field groups
                local_vars += tuple(dict.fromkeys(
                    v for v in routine_arguments
                    if v.name in symbols and str(v.type.dtype) in field_group_types
                ))

                # Filter out known global variables
                local_vars = tuple(v for v in local_vars if v.name not in shared_variables)

                # Make field group types firstprivate
                firstprivates = tuple(dict.fromkeys(
                    v.name for v in local_vars if v.type.dtype.name in field_group_types
                ))
                # Also make values that have an initial value firstprivate
                firstprivates += tuple(v.name for v in local_vars if v.type.initial)
                # Avoid listing a variable twice (e.g. a field group that also
                # has an initial value) in the FIRSTPRIVATE clause
                firstprivates = tuple(dict.fromkeys(firstprivates))

                # Mark all other variables as private
                privates = tuple(dict.fromkeys(
                    v.name for v in local_vars if v.name not in firstprivates
                ))

                s_fp_vars = ", ".join(str(v) for v in firstprivates)
                s_firstprivate = f'FIRSTPRIVATE({s_fp_vars})' if firstprivates else ''
                s_private = f'PRIVATE({", ".join(str(v) for v in privates)})' if privates else ''
                pragma_parallel = ir.Pragma(
                    keyword='OMP', content=f'PARALLEL DEFAULT(SHARED) {s_private} {s_firstprivate}'
                )
                region._update(
                    pragma=pragma_parallel,
                    pragma_post=ir.Pragma(keyword='OMP', content='END PARALLEL')
                )

                # And finally mark all block-dimension loops as parallel
                with pragmas_attached(routine, node_type=ir.Loop):
                    for loop in FindNodes(ir.Loop).visit(region.body):
                        # Add OpenMP DO directives onto block loops
                        if loop.variable == dimension.index:
                            loop._update(
                                pragma=ir.Pragma(keyword='OMP', content='DO SCHEDULE(DYNAMIC,1)'),
                                pragma_post=ir.Pragma(keyword='OMP', content='END DO'),
                            )


def remove_firstprivate_copies(region, fprivate_map, scope):
    """
    Undo an IFS-specific workaround in which complex derived-type objects
    are copied by hand into local objects to avoid erroneous
    firstprivatisation in OpenMP loops.

    Derived-type assignments of the form ``local = global``, where the
    pair matches ``fprivate_map``, are deleted. All remaining uses of the
    mapped local names are then substituted by their global counterparts.

    Parameters
    ----------
    region : tuple of :any:`Node`
        The code region from which to remove firstprivate copies
    fprivate_map : dict of (str, str)
        String mapping of local-to-global names for explicitly
        privatised objects
    scope : :any:`Scope`
        Scope to use for symbol substitution

    Returns
    -------
    The transformed region.
    """

    class RemoveExplicitCopyTransformer(Transformer):
        """ Remove assignments that match the firstprivatisation map """

        def visit_Assignment(self, assign, **kwargs):  # pylint: disable=unused-argument
            # Only derived-type assignments can be explicit private copies
            if isinstance(assign.lhs.type.dtype, DerivedType):
                target = assign.lhs.name
                is_private_copy = target in fprivate_map and assign.rhs == fprivate_map[target]
                if is_private_copy:
                    return None
            return assign

    # Drop the copy assignments, then map local names back to globals
    stripped = RemoveExplicitCopyTransformer().visit(region)
    substitute = SubstituteStringExpressions(fprivate_map, scope=scope)
    return substitute.visit(stripped)


def add_firstprivate_copies(routine, fprivate_map):
    """
    Inject IFS-specific thread-local copies of named complex derived-type
    objects into ``!$loki parallel`` regions, to prevent issues with
    firstprivate variables in OpenMP loops.

    Any local copy missing from ``routine`` is declared with the type of
    its global counterpart (with ``intent`` cleared). Inside each marked
    region, global names are replaced by their local names. An assignment
    ``local = global`` is prepended for every global object used in the
    region.

    Parameters
    ----------
    routine : :any:`Subroutine`
        Subroutine in which to insert privatisation copies
    fprivate_map : dict of (str, str)
        String mapping of local-to-global names for explicitly
        privatised objects
    """
    global_to_local = {gbl: lcl for lcl, gbl in fprivate_map.items()}

    # Declare any local object copy the routine does not know about yet
    for local_name, global_name in fprivate_map.items():
        local_copy = parse_expr(local_name, scope=routine)
        global_obj = parse_expr(global_name, scope=routine)
        if local_copy not in routine.variable_map:
            new_decl = local_copy.clone(type=global_obj.type.clone(intent=None))
            routine.variables += (new_decl,)

    class InjectExplicitCopyTransformer(Transformer):
        """ Inject assignments that match the firstprivate map in parallel regions """

        def visit_PragmaRegion(self, region, **kwargs):  # pylint: disable=unused-argument
            # Only pragma-marked "parallel" regions are modified
            if not is_loki_pragma(region.pragma, starts_with='parallel'):
                return region

            # Build one copy assignment per global object used in the region
            used_vars = FindVariables(unique=True).visit(region.body)
            copies = []
            for local_name, global_name in fprivate_map.items():
                local_copy = parse_expr(local_name, scope=routine)
                global_obj = parse_expr(global_name, scope=routine)
                if global_obj in used_vars:
                    copies.append(ir.Assignment(lhs=local_copy, rhs=global_obj))

            # Use the local names inside the region, then prepend the copies
            substitute = SubstituteStringExpressions(global_to_local, scope=routine)
            region = substitute.visit(region)
            region.prepend(tuple(copies))
            return region

    with pragma_regions_attached(routine):
        routine.body = InjectExplicitCopyTransformer().visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/parallel/block_loop.py0000664000175000017500000001147415167130205024430 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Transformation utilities to remove and generate parallel block loops.
"""

from loki.expression import symbols as sym
from loki.ir import (
    nodes as ir, FindNodes, Transformer, pragma_regions_attached,
    is_loki_pragma
)
from loki.tools import as_tuple
from loki.types import BasicType, SymbolAttributes


__all__ = ['remove_block_loops', 'add_block_loops']


def remove_block_loops(routine, dimension):
    """
    Remove any outer block :any:`Loop` from a given :any:`Subroutine`.

    Block loops are those whose loop variable equals ``dimension.index``.
    Each such loop is replaced by its (recursively transformed) body. Any
    top-level assignment in that body to one of ``dimension.indices``,
    ``dimension.lower`` or ``dimension.upper`` is dropped, as these are
    the auxiliary index and bound assignments of IFS-style block loops.

    Parameters
    ----------
    routine: :any:`Subroutine`
        Subroutine from which to remove block loops
    dimension : :any:`Dimension`
        The dimension object describing loop variables
    """
    block_index = dimension.index
    setup_vars = (
        as_tuple(dimension.indices) + as_tuple(dimension.lower) + as_tuple(dimension.upper)
    )

    class RemoveBlockLoopTransformer(Transformer):
        """
        :any:`Transformer` to remove driver-level block loops.
        """

        def visit_Loop(self, loop, **kwargs):
            new_body = self.visit(loop.body, **kwargs)

            if loop.variable == block_index:
                # Unwrap the block loop and drop its setup assignments
                setup = tuple(
                    assign for assign in FindNodes(ir.Assignment).visit(new_body)
                    if assign.lhs in setup_vars
                )
                return tuple(node for node in new_body if node not in setup)

            return loop._rebuild(body=new_body)

    routine.body = RemoveBlockLoopTransformer().visit(routine.body)


def add_block_loops(routine, dimension, default_type=None):
    """
    Insert IFS-style (NPROMA) driver block-loops in ``!$loki
    parallel`` regions.

    The provided :any:`Dimension` object describes the variables used
    when generating the loop and its setup assignments. Following an
    IFS-specific convention, the body of each marked region is wrapped in
    a loop over ``dimension.index`` from ``1`` to ``dimension.upper[0]``
    with stride ``dimension.step``. The loop body starts with two
    assignments:

    * ``dimension.upper[1] = MIN(step, upper[0] - index + 1)``
    * ``dimension.indices[1] = (index - 1) / step + 1``

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine in which to add block loops.
    dimension : :any:`Dimension`
        The dimension object describing the block loop variables.
    default_type : :any:`SymbolAttributes`, optional
        Type used when declaring missing loop variables; defaults to
        ``integer(kind=JPIM)``.
    """
    dtype = default_type or SymbolAttributes(BasicType.INTEGER, kind='JPIM')

    loop_index = routine.parse_expr(dimension.index)
    block_index = routine.parse_expr(dimension.indices[1])
    block_upper = routine.parse_expr(dimension.upper[1])

    # Declare the integer loop/block variables if they are not yet known
    for var in (loop_index, block_upper, block_index):
        if var not in routine.variable_map:
            routine.variables += (var.clone(type=dtype),)

    def _create_block_loop(body, scope):
        """ Build the strided block loop, including its index preamble """
        step = scope.parse_expr(dimension.step)
        npoints = scope.parse_expr(dimension.upper[0])
        bounds = sym.LoopRange((sym.Literal(1), npoints, step))

        remaining = scope.parse_expr(f'{npoints}-{loop_index}+1')
        min_call = sym.InlineCall(
            function=sym.ProcedureSymbol('MIN', scope=scope), parameters=(step, remaining)
        )
        block_id = scope.parse_expr(f'({loop_index}-1)/{step}+1')
        preamble = (
            ir.Assignment(lhs=block_upper, rhs=min_call),
            ir.Assignment(lhs=block_index, rhs=block_id),
        )
        return ir.Loop(variable=loop_index, bounds=bounds, body=preamble + body)

    class InsertBlockLoopTransformer(Transformer):
        """ Wrap the body of ``!$loki parallel`` regions in a block loop """

        def visit_PragmaRegion(self, region, **kwargs):
            """
            (Re-)insert driver-level block loops into marked parallel region.
            """
            if is_loki_pragma(region.pragma, starts_with='parallel'):
                block_loop = _create_block_loop(body=region.body, scope=kwargs.get('scope'))
                region._update(body=(ir.Comment(''), block_loop))
            return region

    with pragma_regions_attached(routine):
        routine.body = InsertBlockLoopTransformer().visit(routine.body, scope=routine)
loki-ecmwf-0.3.6/loki/transformations/parallel/field_views.py0000664000175000017500000001400215167130205024573 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Transformation utilities to manage and inject FIELD-API boilerplate code.
"""

from loki.expression import symbols as sym
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, Transformer
)
from loki.logging import warning
from loki.tools import as_tuple


__all__ = [
    'remove_field_api_view_updates', 'add_field_api_view_updates'
]


def remove_field_api_view_updates(routine, field_group_types, dim_object=None):
    """
    Remove FIELD API boilerplate calls for view updates of derived types.

    This utility is intended to remove the IFS-specific group type
    objects that provide block-scope view pointers to deep kernel
    trees. It will remove all calls to ``UPDATE_VIEW`` on derived-type
    objects. A warning is issued for calls on objects whose type is not
    listed in ``field_group_types``. Loops that become empty are removed,
    as are conditionals whose branches all become empty.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine from which to remove FIELD API update calls
    field_group_types : tuple of str
        List of names of the derived types of "field group" objects to remove
    dim_object : str, optional
        Optional name of the "dimension" object; if provided it will remove the
        call to ``%UPDATE(...)`` accordingly.
    """
    field_group_types = as_tuple(str(fgt).lower() for fgt in field_group_types)

    class RemoveFieldAPITransformer(Transformer):

        def visit_CallStatement(self, call, **kwargs):  # pylint: disable=unused-argument

            if '%update_view' in str(call.name).lower():
                if not str(call.name.parent.type.dtype).lower() in field_group_types:
                    warning(f'[Loki::ControlFlow] Removing {call.name} call, but not in field group types!')

                return None

            if dim_object and f'{dim_object}%update'.lower() in str(call.name).lower():
                return None

            return call

        def visit_Assignment(self, assign, **kwargs):  # pylint: disable=unused-argument
            if str(assign.lhs.type.dtype).lower() in field_group_types:
                warning(f'[Loki::ControlFlow] Found LHS field group assign: {assign}')
            return assign

        def visit_Loop(self, loop, **kwargs):
            loop = super().visit_Node(loop, **kwargs)
            # Drop loops that became empty after removing the calls
            return loop if loop.body else None

        def visit_Conditional(self, cond, **kwargs):
            cond = super().visit_Node(cond, **kwargs)
            # Only drop the conditional if both branches are empty; checking
            # ``body`` alone would silently discard a non-empty ELSE branch
            return cond if (cond.body or cond.else_body) else None

    routine.body = RemoveFieldAPITransformer().visit(routine.body)


def add_field_api_view_updates(routine, dimension, field_group_types, dim_object=None):
    """
    Adds FIELD API boilerplate calls for view updates.

    The provided :any:`Dimension` object describes the local loop variables to
    pass to the respective update calls. In particular, ``dimension.indices[1]``
    is used to denote the block loop index that is passed to ``UPDATE_VIEW()``
    calls on field group object. The list of type names ``field_group_types``
    is used to identify for which objects the view update calls get added.

    The calls are inserted into every loop over ``dimension.index``, including
    block loops nested in other loops. They go directly after the last
    top-level assignment to one of the dimension's index or bound variables,
    or at the start of the loop body if there is no such assignment.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine to which to add FIELD API update calls
    dimension : :any:`Dimension`
        The dimension object describing the block loop variables.
    field_group_types : tuple of str
        List of names of the derived types of "field group" objects for which
        ``UPDATE_VIEW`` calls are added (matched case-insensitively)
    dim_object : str, optional
        Optional name of the "dimension" object; if provided a call to
        ``%UPDATE(...)`` on it is added accordingly.
    """
    # Fortran type names are case-insensitive; match the removal utility
    field_group_types = as_tuple(str(fgt).lower() for fgt in as_tuple(field_group_types))

    def _create_dim_update(scope, dim_object):
        index = scope.parse_expr(dimension.index)
        upper = scope.parse_expr(dimension.upper[1])
        bindex = scope.parse_expr(dimension.indices[1])
        idims = scope.get_symbol(dim_object)
        csym = sym.ProcedureSymbol(name='UPDATE', parent=idims, scope=idims.scope)
        return ir.CallStatement(name=csym, arguments=(bindex, upper, index), kwarguments=())

    def _create_view_updates(section, scope):
        bindex = scope.parse_expr(dimension.indices[1])

        fgroup_vars = sorted(tuple(
            v for v in FindVariables(unique=True).visit(section)
            if str(v.type.dtype).lower() in field_group_types
        ), key=str)
        calls = ()
        for fgvar in fgroup_vars:
            fgsym = scope.get_symbol(fgvar.name)
            csym = sym.ProcedureSymbol(name='UPDATE_VIEW', parent=fgsym, scope=fgsym.scope)
            calls += (ir.CallStatement(name=csym, arguments=(bindex,), kwarguments=()),)

        return calls

    class InsertFieldAPIViewsTransformer(Transformer):
        """ Injects FIELD-API view updates into block loops """

        def visit_Loop(self, loop, **kwargs):
            # Identify block loops via the dimension rather than a hard-coded name
            if not loop.variable == dimension.index:
                # Not a block loop: descend to find nested block loops
                return super().visit_Node(loop, **kwargs)

            scope = kwargs.get('scope')

            # Find the top-level loop-setup assignments; nested ones cannot
            # be located in ``loop.body`` and must not determine the position
            _loop_symbols = as_tuple(dimension.indices)
            _loop_symbols += as_tuple(dimension.lower) + as_tuple(dimension.upper)
            loop_setup = tuple(
                a for a in FindNodes(ir.Assignment).visit(loop.body)
                if a.lhs in _loop_symbols and a in loop.body
            )
            # Insert after the last setup assignment, or at the start if none
            idx = max((loop.body.index(a) for a in loop_setup), default=-1) + 1

            # Prepend FIELD API boilerplate
            preamble = (
                ir.Comment(''), ir.Comment('! Set up thread-local view pointers')
            )
            if dim_object:
                preamble += (_create_dim_update(scope, dim_object=dim_object),)
            preamble += _create_view_updates(loop.body, scope)

            loop._update(body=loop.body[:idx] + preamble + loop.body[idx:])
            return loop

    routine.body = InsertFieldAPIViewsTransformer().visit(routine.body, scope=routine)
loki-ecmwf-0.3.6/loki/transformations/routine_signatures.py0000664000175000017500000002161715167130205024442 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utilities and transformations altering routine signatures.
"""

import os
import itertools as it
from loki.batch import Transformation, ProcedureItem
from loki.ir import (
    VariableDeclaration, FindVariables,
    Transformer, FindNodes, CallStatement,
    SubstituteExpressions
)
from loki.tools import as_tuple, flatten
from loki.types import BasicType

__all__ = ['RemoveDuplicateArgs', 'remove_duplicate_args_from_calls',
           'modify_variable_declarations']


class RemoveDuplicateArgs(Transformation):
    """
    Transformation that removes duplicate arguments from calls and from
    the corresponding called routines.

    .. warning::
        this won't work properly for multiple calls to the same routine
        with differing duplicate arguments

    Parameters
    ----------
    recurse_to_kernels : bool, optional
        Remove duplicate arguments only at the driver level or recurse to
        (nested) kernels (Default: `True`).
    rename_common : bool, optional
        Try to rename dummy arguments in called routines that received the same argument
        on the caller side, by finding a common name pattern in those names (Default: `False`).
    """

    # This trafo only operates on procedures
    item_filter = (ProcedureItem,)

    def __init__(self, recurse_to_kernels=True, rename_common=False):
        self.rename_common = rename_common
        self.recurse_to_kernels = recurse_to_kernels

    def transform_subroutine(self, routine, **kwargs):
        # Drivers are always processed; kernels only when recursion is enabled
        is_driver = kwargs['role'] == 'driver'
        if not (is_driver or self.recurse_to_kernels):
            return
        remove_duplicate_args_from_calls(routine, rename_common=self.rename_common)

def remove_duplicate_args_from_calls(routine, rename_common=False):
    """
    Utility to remove duplicate arguments from calls in :data:`routine`

    This updates the calls as well as the called routines. It requires calls
    to be enriched with interprocedural information.

    .. warning::
        this won't work properly for multiple calls to the same routine
        with differing duplicate arguments

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine where calls should be transformed.
    rename_common : bool, optional
        Try to rename dummy arguments in called routines that received the same argument
        on the caller side, by finding a common name pattern in those names (Default: `False`).
    """

    def remove_duplicate_args_call(call):
        """
        Remove repeated actual arguments from ``call`` (in place) and return a
        map from each actual argument to the callee dummy arguments it binds to.
        """
        arg_map = {}
        for routine_arg, call_arg in call.arg_iter():
            arg_map.setdefault(call_arg, []).append(routine_arg)
        # Keep only the first keyword argument per distinct value, and drop
        # keyword arguments whose value is already passed positionally.
        # NB: ``itertools.groupby`` only merges *adjacent* equal values, which
        # would leave non-adjacent duplicates in the call while the callee
        # still drops the corresponding dummy argument.
        seen_values = []
        keep = []
        for _, kw_value in call.kwarguments:
            keep.append(kw_value not in seen_values and kw_value not in call.arguments)
            seen_values.append(kw_value)
        new_kwargs = tuple(it.compress(call.kwarguments, keep))
        # (filter duplicate arguments and) update call
        call._update(arguments=as_tuple(dict.fromkeys(call.arguments)), kwarguments=new_kwargs)
        return arg_map

    def modify_callee(callee, callee_arg_map):
        """
        Remove redundant dummy arguments from ``callee``, optionally renaming
        the kept argument to a common name, and return the applied rename map.
        """

        def allowed_rename(routine, rename):
            # check whether rename is already "used" in routine
            if rename in routine.arguments or rename in routine.variables:
                return False
            return True

        combine = [routine_args for routine_args in callee_arg_map.values() if len(routine_args) > 1]
        if rename_common:
            matches = [
                os.path.commonprefix([str(elem.name) for elem in args]).rstrip('_') or
                os.path.commonprefix([str(elem.name)[::-1] for elem in args]).rstrip('_')[::-1]
                for args in combine
            ]
            rename_common_map = {c[0].name: m for c, m in zip(combine, matches) if m}
            # check whether found rename is already "used" in routine
            unallowed_renames = ()
            for name, rename in rename_common_map.items():
                if not allowed_rename(callee, rename):
                    unallowed_renames += (name,)
            # and if already "used", remove and use instead default
            for key in unallowed_renames:
                del rename_common_map[key]
        else:
            rename_common_map = {}
        redundant = flatten([routine_args[1:] for routine_args in combine])
        combine_map = {routine_args[0]: as_tuple(routine_args[1:]) for routine_args in combine}
        arg_map = {arg.name: rename_common_map.get(common_arg.name, common_arg.name)
                   for common_arg, redundant_args in combine_map.items() for arg in redundant_args}
        # remove duplicates from callee.arguments
        new_routine_args = tuple(arg for arg in callee.arguments if arg not in redundant)
        # rename if common name is possible
        new_routine_args = as_tuple(arg.clone(name=rename_common_map[arg.name])
                if arg.name in rename_common_map else arg for arg in new_routine_args)
        callee.arguments = new_routine_args

        # rename usage/occurences in callee.body
        variables = FindVariables(unique=False).visit(callee.body)
        var_map = {var: var.clone(name=arg_map[var.name]) for var in variables if var.name in arg_map}
        var_map.update({var: var.clone(name=rename_common_map[var.name]) for var in variables
            if var.name in rename_common_map})
        callee.body = SubstituteExpressions(var_map).visit(callee.body)
        # modify the variable declarations, thus remove redundant variable declarations and possibly rename
        modify_variable_declarations(callee, remove_symbols=redundant, rename_symbols=rename_common_map)
        # store the information for possibly later renaming kwarguments on caller side
        return rename_common_map

    def rename_kwarguments(relevant_calls, rename_common_map_routine):
        """ Apply the callee-side renames to keyword arguments on the caller side """
        for call in relevant_calls:
            kwarguments = call.kwarguments
            if kwarguments:
                call_name = str(call.routine.name).lower()
                new_kwargs = as_tuple((rename_common_map_routine[call_name][kw[0]], kw[1])
                        if kw[0] in rename_common_map_routine[call_name] else kw for kw in kwarguments)
                call._update(kwarguments=new_kwargs)

    calls = FindNodes(CallStatement).visit(routine.body)
    call_arg_map = {}
    relevant_calls = []
    # adapt call statements (and remove duplicate args/kwargs)
    for call in calls:
        if call.routine is BasicType.DEFERRED:
            continue
        call_arg_map[call.routine] = remove_duplicate_args_call(call)
        relevant_calls.append(call)
    rename_common_map_routine = {}
    # modify/adapt callees
    for callee, callee_arg_map in call_arg_map.items():
        rename_common_map_routine[str(callee.name).lower()] = modify_callee(callee, callee_arg_map)
    # handle possibly renamed kwarguments on caller side
    if rename_common:
        rename_kwarguments(relevant_calls, rename_common_map_routine)


def modify_variable_declarations(routine, remove_symbols=(), rename_symbols=None):
    """
    Utility to modify variable declarations by either removing symbols or renaming
    symbols.

    Symbols are removed by case-insensitive name match. Renaming uses the
    symbol's name as given in ``rename_symbols``. If a (renamed) symbol name
    is already declared earlier in the spec, the later occurrence is dropped.
    Declarations left without symbols are removed. Declarations whose symbols
    are unchanged are left untouched.

    .. note::
        This utility only works on the variable declarations itself and
        won't modify variable/symbol usages elsewhere!

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine to be transformed.
    remove_symbols : list, tuple
        List of symbols for which their declaration should be removed.
    rename_symbols : dict
        Dict/Map of symbols for which their declaration should be renamed.
    """
    rename_symbols = rename_symbols if rename_symbols is not None else {}
    remove_symbol_names = {var.name.lower() for var in remove_symbols}
    already_declared = set()
    decl_map = {}
    for decl in FindNodes(VariableDeclaration).visit(routine.spec):
        symbols = [symbol for symbol in decl.symbols if symbol.name.lower() not in remove_symbol_names]
        symbols = [symbol.clone(name=rename_symbols[symbol.name])
                   if symbol.name in rename_symbols else symbol for symbol in symbols]
        # Drop symbols whose (possibly renamed) name has already been declared
        symbols = [symbol for symbol in symbols if symbol.name.lower() not in already_declared]
        already_declared.update(symbol.name.lower() for symbol in symbols)

        if not symbols:
            decl_map[decl] = None
        elif len(symbols) != len(decl.symbols) or any(
                new is not old for new, old in zip(symbols, decl.symbols)
        ):
            # Compare by identity: a ``list != tuple`` check is always True, and
            # ``==`` on symbols may ignore case-only renames
            decl_map[decl] = decl.clone(symbols=as_tuple(symbols))
    routine.spec = Transformer(decl_map).visit(routine.spec)
loki-ecmwf-0.3.6/loki/transformations/parametrise.py0000664000175000017500000003422015167130205023017 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Parametrise variables.

E.g., parametrise

.. code-block:: fortran

    subroutine driver(a, b)
        integer, intent(in) :: a
        integer, intent(in) :: b
        call kernel(a, b)
    end subroutine driver

    subroutine kernel(a, b)
        integer, intent(in) :: a
        integer, intent(in) :: b
        real :: array(a)
        ...
    end subroutine kernel

using the transformation

.. code-block:: python

    dic2p = {'a': 10}
    scheduler.process(transformation=ParametriseTransformation(dic2p=dic2p))

to

.. code-block:: fortran

    subroutine driver(parametrised_a, b)
        integer, parameter :: a = 10
        integer, intent(in) :: parametrised_a
        integer, intent(in) :: b
        IF (parametrised_a /= 10) THEN
            PRINT *, "Variable a parametrised to value 10, but subroutine driver received another value: ", parametrised_a
            STOP 1
        END IF
        call kernel(b)
    end subroutine driver

    subroutine kernel(b)
        integer, parameter :: a = 10
        integer, intent(in) :: b
        real :: array(a)
        ...
    end subroutine kernel

or

.. code-block:: fortran

    subroutine driver(parametrised_a, b)
        integer, intent(in) :: parametrised_a
        integer, intent(in) :: b
        IF (parametrised_a /= 10) THEN
            PRINT *, "Variable a parametrised to value 10, but subroutine driver received another value: ", parametrised_a
            STOP 1
        END IF
        call kernel(b)
    end subroutine driver

    subroutine kernel(b)
        integer, intent(in) :: b
        real :: array(10)
        ...
    end subroutine kernel

using the transformation

.. code-block:: python

    dic2p = {'a': 10}
    scheduler.process(transformation=ParametriseTransformation(dic2p=dic2p, replace_by_value=True))
"""

from loki.batch import Transformation
from loki.expression import symbols as sym
from loki.ir import nodes as ir, Transformer, FindNodes, FindInlineCalls
from loki.tools.util import as_tuple, CaseInsensitiveDict

from loki.transformations.utilities import single_variable_declaration
from loki.transformations.inline import inline_constant_parameters


__all__ = ['ParametriseTransformation', 'declare_fixed_value_scalars_as_constants']


def declare_fixed_value_scalars_as_constants(routine):
    """
    Turn local scalars that receive exactly one constant value into ``PARAMETER`` declarations.

    A variable qualifies when it is not a dummy argument, is not an array, is not passed to any
    call statement or inline call, is the left-hand side of exactly one assignment in the routine
    body, and that assignment's right-hand side is a literal, or a sum/product whose direct
    children are all literals. The qualifying assignment is removed from the body and its
    right-hand side becomes the initial value of the (individually re-declared) parameter.

    This is not really sophisticated and will eventually be superseded by
    a constant propagation transformation.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine to modify in-place.
    """
    def _is_literal(node):
        return issubclass(type(node), sym._Literal)

    def is_constant_rhs(expr):
        # a plain literal, or a flat Sum/Product made only of literals
        if _is_literal(expr):
            return True
        if isinstance(expr, (sym.Product, sym.Sum)):
            return all(_is_literal(child) for child in expr.children)
        return False

    # only local, non-array variables are candidates
    candidates = [
        var for var in routine.variables
        if var not in routine.arguments and not isinstance(var, sym.Array)
    ]

    # ignore anything handed to a (inline) call, regardless of its intent
    all_calls = (
        as_tuple(FindNodes(ir.CallStatement).visit(routine.body))
        + as_tuple(FindInlineCalls().visit(routine.body))
    )
    passed_to_calls = set()
    for call in all_calls:
        passed_to_calls.update(call.arguments)
        passed_to_calls.update(kwarg[1] for kwarg in call.kwarguments)
    candidates = [var for var in candidates if var not in passed_to_calls]

    # group assignments by the candidate variable they write to
    writes = {}
    for assignment in FindNodes(ir.Assignment).visit(routine.body):
        if assignment.lhs in candidates:
            writes.setdefault(assignment.lhs, []).append(assignment)

    # keep variables that are written exactly once, with a constant value
    parametrise_map = {
        var: assigns[0]
        for var, assigns in writes.items()
        if len(assigns) == 1 and is_constant_rhs(assigns[0].rhs)
    }

    # each new parameter needs its own declaration before its attributes can change
    single_variable_declaration(routine, [str(var.name) for var in parametrise_map])
    for var, assignment in parametrise_map.items():
        routine.symbol_attrs[str(var.name)] = var.type.clone(parameter=True, initial=assignment.rhs)

    # the original assignments now live on as the parameters' initial values
    routine.body = Transformer({assignment: None for assignment in parametrise_map.values()}).visit(routine.body)


class ParametriseTransformation(Transformation):
    """
    Parametrise variables with provided values.

    This transformation checks for each subroutine (defined as driver or entry point) the arguments to be parametrised
    according to :attr:`dic2p` and passes this information down the calltree.

    Variable names are matched case-insensitively (as in Fortran), so the keys of :attr:`dic2p` do not need to
    match the capitalisation used in the source code.

    .. note::

        A sanity run-time check will be inserted at each entry point to check consistency of the provided value
        and argument value at this point!

    .. warning::

        The subroutine/call signature(s) may be altered as arguments are converted to local parameters or int literals.
        Therefore, consistency must be ensured, meaning all parts of the code calling subroutines that are transformed
        and all possibly differing names of variables at the entry points must be included, otherwise the resulting
        code will not compile correctly!

    E.g., use this class like this:

    .. code-block:: python

        def error_stop(**kwargs):
            msg = kwargs.get("msg")
            return ir.Intrinsic(text=f'error stop "{msg}"'),

        dic2p = {'a': 12, 'b': 11}

        transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop,
                                entry_points=("driver1", "driver2"))

        scheduler.process(transformation=transformation)

    Parameters
    ----------
    dic2p: dict
        Dictionary of variable names and corresponding (integer) values to be parametrised.
    replace_by_value: bool
        Replace variables entirely by value (default: `False`)
    entry_points: None or tuple
        Subroutine names to be used as entry points for parametrisation. Default `None` uses driver(s) as
        entry points.
    abort_callback:
        Callback routine used for error on sanity check. It should return the IR node(s) to execute when the
        check fails: a single node, or a tuple/list of nodes.
        Available arguments via ``kwargs``:

        * ``msg`` - predefined error message
        * ``routine`` - the routine executing the sanity check
        * ``var`` - the variable getting checked
        * ``value`` - the value the variable should have (according to :attr:`dic2p`)
    key : str
        Access identifier/key for the ``item.trafo_data`` dictionary. Only necessary to provide if several of
        these transformations are carried out in succession.
    """

    _key = "ParametriseTransformation"

    def __init__(self, dic2p, replace_by_value=False, entry_points=None, abort_callback=None, key=None):
        self.dic2p = dic2p
        self.replace_by_value = replace_by_value
        self.entry_points = tuple(entry_point.upper() for entry_point in as_tuple(entry_points)) or None
        self.abort_callback = abort_callback
        if key is not None:
            self._key = key

    def transform_subroutine(self, routine, **kwargs):
        """
        Transformation applied to :any:`Subroutine` item.

        Parametrises all variables as defined by :attr:`dic2p` either to be a parameter or by replacing the
        variable with the value itself.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to be transformed.
        **kwargs : optional
            Keyword arguments for the transformation; ``item``, ``role`` and ``sub_sgraph`` are used.
        """

        item = kwargs.get('item', None)
        role = kwargs.get('role', None)
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = as_tuple(sub_sgraph.successors(item)) if sub_sgraph is not None else ()

        successor_map = CaseInsensitiveDict(
            (successor.local_name, successor)
            for successor in successors
        )

        process_entry_point, dic2p = self._get_dic2p(routine, role, item)

        # nothing to do if no variables are to be parametrised in this routine
        if not dic2p:
            return

        vars2p = list(dic2p)
        # lower-case name -> value lookup, since Fortran names are case-insensitive
        values = {str(name).lower(): value for name, value in dic2p.items()}

        if process_entry_point:
            # rename arguments that are parametrised (to allow for sanity checks)
            routine.arguments = [
                arg.clone(name=f'parametrised_{arg.name}') if arg.name.lower() in values else arg
                for arg in routine.arguments
            ]
            self._insert_sanity_checks(routine, dic2p)
        else:
            routine.arguments = [arg for arg in routine.arguments if arg.name.lower() not in values]

        # remove variables to be parametrised from all call statements
        self._update_calls(routine, vars2p, values, successor_map)

        # replace the declarations of parametrised variables by parameter declarations
        self._declare_parameters(routine, vars2p, values)

        # replace all parameter variables with their corresponding value (inline constant parameters)
        if self.replace_by_value:
            inline_constant_parameters(routine=routine, external_only=False)

    def _get_dic2p(self, routine, role, item):
        """
        Decide whether ``routine`` is an entry point and which variables to parametrise in it.

        Returns
        -------
        tuple
            ``(is_entry_point, dic2p)``. Entry points use :attr:`dic2p`; all other routines use the mapping
            stored in ``item.trafo_data`` by their caller, or an empty dict if there is none.
        """
        if self.entry_points is None:
            is_entry_point = role == "driver"
        else:
            is_entry_point = routine.name.upper() in self.entry_points

        if is_entry_point:
            return True, self.dic2p
        if item is not None and self._key in item.trafo_data:
            return False, item.trafo_data[self._key]
        return False, {}

    @staticmethod
    def _as_node_tuple(nodes):
        """Normalise the return value of :attr:`abort_callback` to a tuple of IR nodes."""
        if nodes is None:
            return ()
        if isinstance(nodes, (tuple, list)):
            return tuple(nodes)
        return (nodes,)

    def _insert_sanity_checks(self, routine, dic2p):
        """
        Prepend, for every renamed ``parametrised_<name>`` argument, a run-time check to the routine body.
        The check aborts if the value passed in differs from the parametrised value.
        """
        # iterate in reverse since every check is prepended, so the final order follows ``dic2p``
        for key, value in reversed(dic2p.items()):
            parametrised_name = f'parametrised_{key}'
            if parametrised_name not in routine.variable_map:
                continue
            parametrised_var = routine.variable_map[parametrised_name]
            error_msg = f"Variable {key} parametrised to value {value}, but subroutine {routine.name} " \
                        f"received another value"
            condition = sym.Comparison(parametrised_var, '!=', sym.IntLiteral(value))
            comment = ir.Comment(f"! Stop execution: {error_msg}")
            if self.abort_callback is None:
                # use default abort mechanism
                abort = (ir.Intrinsic(text=f'PRINT *, "{error_msg}: ", {parametrised_var.name}'),
                         ir.Intrinsic(text="STOP 1"))
            else:
                # use user-defined abort/warn mechanism
                callback_kwargs = {"msg": error_msg, "routine": routine.name, "var": parametrised_var,
                                   "value": value}
                abort = self._as_node_tuple(self.abort_callback(**callback_kwargs))
            conditional = ir.Conditional(condition=condition, body=(comment,) + abort, else_body=None)
            routine.body.prepend(conditional)
            routine.body.prepend(ir.Comment(f"! Sanity check for parametrised variable: {key}"))

    def _update_calls(self, routine, vars2p, values, successor_map):
        """
        Drop parametrised variables from the positional arguments of calls to known successors.
        For each successor, record in its ``trafo_data`` which dummy arguments received a parametrised value.
        """
        call_map = {}
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            name = str(call.name)
            if name not in successor_map:
                continue
            successor_dic2p = {}
            successor_map[name].trafo_data[self._key] = successor_dic2p
            # Only positional arguments are removed below, so only positional (dummy, actual) pairs are
            # propagated. Visiting every pair handles a variable passed to several dummies of one call.
            # NOTE: relies on ``arg_iter`` yielding the positional pairs first, in call order.
            positional_pairs = tuple(call.arg_iter())[:len(call.arguments)]
            for dummy, actual in positional_pairs:
                if actual in vars2p:
                    successor_dic2p[str(dummy)] = values[actual.name.lower()]
            arguments = tuple(arg for arg in call.arguments if arg not in vars2p)
            call_map[call] = call.clone(arguments=arguments)
        routine.body = Transformer(call_map).visit(routine.body)

    @staticmethod
    def _declare_parameters(routine, vars2p, values):
        """
        Remove parametrised variables from their declarations and re-declare them as parameters.
        Each parameter keeps the symbol's type, drops ``intent``, and is initialised with an integer literal.
        The new declarations are inserted before the first remaining declaration.
        """
        parameter_declarations = []
        decl_map = {}
        for decl in FindNodes(ir.VariableDeclaration).visit(routine.spec):
            symbols = []
            for smbl in decl.symbols:
                if smbl in vars2p:
                    # use the symbol's own type, not the first symbol's, which may differ in multi-symbol decls
                    param_type = smbl.type.clone(
                        parameter=True, intent=None, initial=sym.IntLiteral(values[smbl.name.lower()])
                    )
                    parameter_declarations.append(decl.clone(symbols=(smbl.clone(type=param_type),)))
                else:
                    symbols.append(smbl.clone())
            decl_map[decl] = decl.clone(symbols=as_tuple(symbols)) if symbols else None
        routine.spec = Transformer(decl_map).visit(routine.spec)

        if not parameter_declarations:
            return

        # introduce parameter declarations
        declarations = FindNodes(ir.VariableDeclaration).visit(routine.spec)
        if declarations:
            # the anchor's index is recomputed after each insert, which keeps the original order
            anchor = declarations[0]
            for parameter_declaration in parameter_declarations:
                routine.spec.insert(routine.spec.body.index(anchor), parameter_declaration)
        else:
            # every declaration was removed, so there is no anchor left: append instead
            for parameter_declaration in parameter_declarations:
                routine.spec.append(parameter_declaration)
loki-ecmwf-0.3.6/loki/transformations/array_indexing/0000775000175000017500000000000015167130205023133 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/array_indexing/demote.py0000664000175000017500000000706215167130205024767 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

""" Utilities for demoting the rank of array variables. """

from loki.expression import symbols as sym
from loki.ir import (
    nodes as ir, FindNodes, Transformer, FindVariables,
    SubstituteExpressions
)
from loki.logging import info
from loki.tools import as_tuple, CaseInsensitiveDict


__all__ = ['demote_variables']


def demote_variables(routine, variable_names, dimensions):
    """
    Demote a list of array variables by removing any occurrence of a
    provided set of dimension symbols.

    Every occurrence of the named variables in the routine's spec and body
    has the matching entries removed from its shape and from its index
    tuple. Declarations using a ``DIMENSION`` attribute are updated as well.
    If their symbols no longer share a common shape, they are split into one
    declaration per symbol.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which the variables should be demoted.
    variable_names : str or list of str
        The name(s) of variables to be demoted. Matching of variables against
        names is case-insensitive.
    dimensions : :py:class:`pymbolic.Expression` or tuple
        Symbol name or tuple of symbol names representing the dimension
        to remove from all occurrences of the named variables.
    """
    dimensions = as_tuple(dimensions)
    # accept a single name as well; ``as_tuple`` wraps a string instead of iterating over its characters
    variable_names = as_tuple(variable_names)

    # Compare lower-case only, since we're not comparing symbols
    vnames = tuple(name.lower() for name in variable_names)

    variables = FindVariables(unique=False).visit(routine.ir)
    variables = tuple(
        v for v in variables
        if v.name.lower() in vnames and hasattr(v, 'shape')
    )

    if not variables:
        return

    # Record original array shapes
    shape_map = CaseInsensitiveDict({v.name: v.shape for v in variables})

    # Remove shape and dimension entries from each variable in the list
    vmap = {}
    for v in variables:
        old_shape = shape_map[v.name]
        new_shape = tuple(s for s in old_shape if s not in dimensions)
        # keep only the indices whose corresponding shape entry survives
        new_dims = tuple(d for d, s in zip(v.dimensions, old_shape) if s in new_shape)

        new_type = v.type.clone(shape=new_shape or None)
        vmap[v] = v.clone(dimensions=new_dims or None, type=new_type)

    # Propagate the new dimensions to declarations and routine bodies
    routine.body = SubstituteExpressions(vmap).visit(routine.body)
    routine.spec = SubstituteExpressions(vmap).visit(routine.spec)

    # Ensure all declarations with `DIMENSION` keywords are modified too!
    decls = tuple(
        d for d in FindNodes(ir.VariableDeclaration).visit(routine.spec)
        if d.dimensions and any(s.name.lower() in vnames for s in d.symbols)
    )
    decl_map = {}
    for decl in decls:
        decl_shapes = tuple(s.shape if isinstance(s, sym.Array) else None for s in decl.symbols)
        if all(shape == decl_shapes[0] for shape in decl_shapes):
            # All symbols share the same shape (after demotion): keep a single declaration
            decl_map[decl] = decl.clone(dimensions=decl_shapes[0])
        else:
            # If not, split into multiple declarations
            decl_map[decl] = tuple(
                decl.clone(symbols=(s,), dimensions=shape) for s, shape in zip(decl.symbols, decl_shapes)
            )
    routine.spec = Transformer(decl_map).visit(routine.spec)

    info(f'[Loki::Transform] Demoted variables in {routine.name}: {", ".join(variable_names)}')
loki-ecmwf-0.3.6/loki/transformations/array_indexing/__init__.py0000664000175000017500000000134215167130205025244 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utility routines to deal with common array indexing conversions.
"""

from loki.transformations.array_indexing.array_indices import * # noqa
from loki.transformations.array_indexing.demote import * # noqa
from loki.transformations.array_indexing.promote import * # noqa
from loki.transformations.array_indexing.vector_notation import * # noqa
loki-ecmwf-0.3.6/loki/transformations/array_indexing/tests/0000775000175000017500000000000015167130205024275 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/array_indexing/tests/__init__.py0000664000175000017500000000057015167130205026410 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/array_indexing/tests/test_array_promote.py0000664000175000017500000001275515167130205030603 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Subroutine
from loki.jit_build import jit_compile
from loki.expression import symbols as sym
from loki.frontend import available_frontends

from loki.transformations.array_indexing.promote import promote_variables


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_promote_variable_scalar(tmp_path, frontend):
    """
    Promote a single scalar to a rank-1 array and check that the routine
    computes the same result before and after the promotion.
    """
    fcode = """
subroutine transform_promote_variable_scalar(ret)
  implicit none
  integer, intent(out) :: ret
  integer :: tmp, jk

  ret = 0
  do jk=1,10
    tmp = jk
    ret = ret + tmp
  end do
end subroutine transform_promote_variable_scalar
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Reference run of the unmodified routine
    reference_fn = jit_compile(
        routine, filepath=tmp_path/f'{routine.name}_{frontend}.f90', objname=routine.name
    )
    assert reference_fn() == 55

    # Promote ``tmp`` along ``jk`` into an array of size 10
    assert isinstance(routine.variable_map['tmp'], sym.Scalar)
    promote_variables(routine, ['TMP'], pos=0, index=routine.variable_map['JK'], size=sym.Literal(10))
    promoted_tmp = routine.variable_map['tmp']
    assert isinstance(promoted_tmp, sym.Array)
    assert promoted_tmp.shape == (sym.Literal(10),)

    # The promoted routine must still produce the same sum
    promoted_fn = jit_compile(
        routine, filepath=tmp_path/f'{routine.name}_promoted_{frontend}.f90', objname=routine.name
    )
    assert promoted_fn() == 55


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_promote_variables(tmp_path, frontend):
    """
    Promote scalar, vector and matrix variables in two steps and check the
    resulting shapes as well as the numerical output before and after.
    """
    fcode = """
subroutine transform_promote_variables(scalar, vector, n)
  implicit none
  integer, intent(in) :: n
  integer, intent(inout) :: scalar, vector(n)
  integer :: tmp_scalar, tmp_vector(n), tmp_matrix(n,n)
  integer :: jl, jk

  do jl=1,n
    ! a bit of a hack to create initialized meaningful output
    tmp_vector(:) = 0
  end do

  do jl=1,n
    tmp_scalar = jl
    tmp_vector(jl) = jl

    do jk=1,n
      tmp_matrix(jk, jl) = jl + jk
    end do
  end do

  scalar = 0
  do jl=1,n
    scalar = scalar + tmp_scalar
    vector = tmp_matrix(:,jl) + tmp_vector(:)
  end do
end subroutine transform_promote_variables
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def var(name):
        # look the symbol up afresh every time, since promotion replaces it in the routine
        return routine.variable_map[name]

    def check_array_ranks(expected_ranks):
        # every named variable must be an array of shape (n, ..., n) with the given rank
        for name, rank in expected_ranks.items():
            assert isinstance(var(name), sym.Array)
            assert var(name).shape == (var('n'),) * rank

    n = 10

    # Reference run of the unmodified routine
    reference_fn = jit_compile(routine, filepath=tmp_path/f'{routine.name}_{frontend}.f90', objname=routine.name)
    scalar = np.array(0)
    vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
    reference_fn(scalar, vector, n)
    assert scalar == n*n
    assert np.all(vector == np.array(list(range(1, 2*n+1, 2)), order='F', dtype=np.int32) + n + 1)

    # Shapes before any promotion
    assert isinstance(var('tmp_scalar'), sym.Scalar)
    check_array_ranks({'tmp_vector': 1, 'tmp_matrix': 2})

    # Promote scalar and vector along ``jl`` as trailing dimension
    promote_variables(routine, ['tmp_scalar', 'tmp_vector'], pos=-1, index=var('JL'), size=var('n'))
    check_array_ranks({'tmp_scalar': 1, 'tmp_vector': 2, 'tmp_matrix': 2})

    # Promote matrix along ``jl`` as middle dimension
    promote_variables(routine, ['tmp_matrix'], pos=1, index=var('JL'), size=var('n'))
    check_array_ranks({'tmp_scalar': 1, 'tmp_vector': 2, 'tmp_matrix': 3})

    # Run the promoted routine
    promoted_fn = jit_compile(
        routine, filepath=tmp_path/f'{routine.name}_promoted_{frontend}.f90', objname=routine.name
    )
    scalar = np.array(0)
    vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
    promoted_fn(scalar, vector, n)
    assert scalar == n*(n+1)//2
    assert np.all(vector[:-1] == np.array(list(range(n + 1, 2*n)), order='F', dtype=np.int32))
    assert vector[-1] == 3*n
loki-ecmwf-0.3.6/loki/transformations/array_indexing/tests/test_array_indexing.py0000664000175000017500000007555415167130205030731 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import platform
import pytest
import numpy as np

from loki import Module, Subroutine, fgen
from loki.jit_build import jit_compile, jit_compile_lib, clean_test, Builder, Obj
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes, FindVariables

from loki.transformations.array_indexing.array_indices import (
    invert_array_indices, flatten_arrays,
    normalize_array_shape_and_access, shift_to_zero_indexing,
    LowerConstantArrayIndices
)
from loki.transformations.transpile import (
    FortranCTransformation, FortranISOCWrapperTransformation
)


@pytest.fixture(scope='function', name='builder')
def fixture_builder(tmp_path):
    """
    Provide a ``Builder`` that uses the test's ``tmp_path`` as both source and
    build directory, and clear the ``Obj`` cache during teardown.
    """
    test_builder = Builder(source_dirs=tmp_path, build_dir=tmp_path)
    yield test_builder
    Obj.clear_cache()



@pytest.mark.skipif(platform.system() == 'Darwin', reason='Unclear issue causing problems on MacOS (#352)')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('start_index', (0, 1, 5))
def test_transform_normalize_array_shape_and_access(tmp_path, frontend, start_index):
    """
    Test normalization of array shape and access, thus changing arrays with start
    index different than "1" to have start index "1".

    E.g., ``x1(5:len)`` -> ``x1(1:len-4)``
    """
    # NOTE: ``x1`` is zeroed over its own declared extent (``l1``); a larger extent
    # would write past the end of the array.
    fcode = f"""
    module norm_arr_shape_access_mod
    implicit none

    contains

    subroutine norm_arr_shape_access(x1, x2, x3, x4, assumed_x1, l1, l2, l3, l4)
        ! use nested_routine_mod, only : nested_routine
        implicit none
        integer :: i1, i2, i3, i4, c1, c2, c3, c4
        integer, intent(in) :: l1, l2, l3, l4
        integer, intent(inout) :: x1({start_index}:l1+{start_index}-1)
        integer, intent(inout) :: x2({start_index}:l2+{start_index}-1, &
         & {start_index}:l1+{start_index}-1)
        integer, intent(inout) :: x3({start_index}:l3+{start_index}-1, &
         & {start_index}:l2+{start_index}-1, {start_index}:l1+{start_index}-1)
        integer, intent(inout) :: x4({start_index}:l4+{start_index}-1, &
         & {start_index}:l3+{start_index}-1, {start_index}:l2+{start_index}-1, &
         & {start_index}:l1+{start_index}-1)
        integer, intent(inout) :: assumed_x1(l1)
        c1 = 1
        c2 = 1
        c3 = 1
        c4 = 1
        do i1=1,l1
            assumed_x1(i1) = c1
            call nested_routine(assumed_x1, l1, c1)
        end do
        x1({start_index}:l1+{start_index}-1) = 0
        do i1={start_index},l1+{start_index}-1
            x1(i1) = c1
            do i2={start_index},l2+{start_index}-1
                x2(i2, i1) = c2*10 + c1
                do i3={start_index},l3+{start_index}-1
                    x3(i3, i2, i1) = c3*100 + c2*10 + c1
                    do i4={start_index},l4+{start_index}-1
                        x4(i4, i3, i2, i1) = c4*1000 + c3*100 + c2*10 + c1
                        c4 = c4 + 1
                    end do
                    c3 = c3 + 1
                end do
                c2 = c2 + 1
            end do
            c1 = c1 + 1
        end do
    end subroutine norm_arr_shape_access

    subroutine nested_routine(nested_x1, l1, c1)
        implicit none
        integer, intent(in) :: l1, c1
        integer, intent(inout) :: nested_x1(:)
        integer :: i1
        do i1=1,l1
            nested_x1(i1) = c1
        end do
    end subroutine nested_routine

    end module norm_arr_shape_access_mod
    """

    def init_arguments(l1, l2, l3, l4):
        x1 = np.zeros(shape=(l1,), order='F', dtype=np.int32)
        assumed_x1 = np.zeros(shape=(l1,), order='F', dtype=np.int32)
        x2 = np.zeros(shape=(l2,l1,), order='F', dtype=np.int32)
        x3 = np.zeros(shape=(l3,l2,l1,), order='F', dtype=np.int32)
        x4 = np.zeros(shape=(l4,l3,l2,l1,), order='F', dtype=np.int32)
        return x1, x2, x3, x4, assumed_x1

    def validate_routine(routine):
        # after normalization no array in the body may still carry a range-based shape
        arrays = [var for var in FindVariables().visit(routine.body) if isinstance(var, sym.Array)]
        for arr in arrays:
            assert all(not isinstance(shape, sym.RangeIndex) for shape in arr.shape)

    l1 = 2
    l2 = 3
    l3 = 4
    l4 = 5
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    filepath = tmp_path/(f'norm_arr_shape_access_{frontend}.f90')
    # compile and test "original" module/function
    mod = jit_compile(module, filepath=filepath, objname='norm_arr_shape_access_mod')
    function = getattr(mod, 'norm_arr_shape_access')
    orig_x1, orig_x2, orig_x3, orig_x4, orig_assumed_x1 = init_arguments(l1, l2, l3, l4)
    function(orig_x1, orig_x2, orig_x3, orig_x4, orig_assumed_x1, l1, l2, l3, l4)
    clean_test(filepath)

    # apply `normalize_array_shape_and_access`
    for routine in module.routines:
        normalize_array_shape_and_access(routine)

    filepath = tmp_path/(f'norm_arr_shape_access_normalized_{frontend}.f90')
    # compile and test "normalized" module/function
    mod = jit_compile(module, filepath=filepath, objname='norm_arr_shape_access_mod')
    function = getattr(mod, 'norm_arr_shape_access')
    x1, x2, x3, x4, assumed_x1 = init_arguments(l1, l2, l3, l4)
    function(x1, x2, x3, x4, assumed_x1, l1, l2, l3, l4)
    clean_test(filepath)
    # validate the routine "norm_arr_shape_access"
    validate_routine(module.subroutines[0])
    # validate the nested routine to see whether the assumed size array got correctly handled
    assert module.subroutines[1].variable_map['nested_x1'] == 'nested_x1(:)'

    # check whether results generated by the "original" and "normalized" version agree
    assert (x1 == orig_x1).all()
    assert (assumed_x1 == orig_assumed_x1).all()
    assert (x2 == orig_x2).all()
    assert (x3 == orig_x3).all()
    assert (x4 == orig_x4).all()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('start_index', (0, 1, 5))
def test_transform_flatten_arrays(tmp_path, frontend, builder, start_index):
    """
    Test flattening of arrays, meaning converting multi-dimensional
    arrays to one-dimensional arrays including corresponding
    index arithmetic.

    The kernel is compiled and run in four variants: the original, flattened
    in Fortran order, flattened in C order (after index inversion), and
    transpiled to C. Every variant must produce the same data as the original
    when the original arrays are flattened in Fortran order.
    """
    fcode = f"""
    subroutine transf_flatten_arr(x1, x2, x3, x4, l1, l2, l3, l4)
        implicit none
        integer :: i1, i2, i3, i4, c1, c2, c3, c4
        integer, intent(in) :: l1, l2, l3, l4
        integer, intent(inout) :: x1({start_index}:l1+{start_index}-1)
        integer, intent(inout) :: x2({start_index}:l2+{start_index}-1, &
         & {start_index}:l1+{start_index}-1)
        integer, intent(inout) :: x3({start_index}:l3+{start_index}-1, &
         & {start_index}:l2+{start_index}-1, {start_index}:l1+{start_index}-1)
        integer, intent(inout) :: x4({start_index}:l4+{start_index}-1, &
         & {start_index}:l3+{start_index}-1, {start_index}:l2+{start_index}-1, &
         & {start_index}:l1+{start_index}-1)
        c1 = 1
        c2 = 1
        c3 = 1
        c4 = 1
        do i1={start_index},l1+{start_index}-1
            x1(i1) = c1
            do i2={start_index},l2+{start_index}-1
                x2(i2, i1) = c2*10 + c1
                do i3={start_index},l3+{start_index}-1
                    x3(i3, i2, i1) = c3*100 + c2*10 + c1
                    do i4={start_index},l4+{start_index}-1
                        x4(i4, i3, i2, i1) = c4*1000 + c3*100 + c2*10 + c1
                        c4 = c4 + 1
                    end do
                    c3 = c3 + 1
                end do
                c2 = c2 + 1
            end do
            c1 = c1 + 1
        end do

    end subroutine transf_flatten_arr
    """
    def init_arguments(l1, l2, l3, l4, flattened=False):
        # Zero-initialised int32 arguments, either in their multi-dimensional
        # shape or as flat 1D arrays holding the same number of elements
        x1 = np.zeros(shape=(l1,), order='F', dtype=np.int32)
        x2 = np.zeros(shape=(l2*l1,) if flattened else (l2,l1,), order='F', dtype=np.int32)
        x3 = np.zeros(shape=(l3*l2*l1,) if flattened else (l3,l2,l1,), order='F', dtype=np.int32)
        x4 = np.zeros(shape=(l4*l3*l2*l1,) if flattened else (l4,l3,l2,l1,), order='F', dtype=np.int32)
        return x1, x2, x3, x4

    def validate_routine(routine):
        # After flattening, every array access and array shape in the body must be 1D
        arrays = [var for var in FindVariables().visit(routine.body) if isinstance(var, sym.Array)]
        assert all(len(arr.dimensions) == 1 for arr in arrays)
        assert all(len(arr.shape) == 1 for arr in arrays)

    l1 = 2
    l2 = 3
    l3 = 4
    l4 = 5
    # Test the original implementation
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/(f'{routine.name}_{start_index}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    orig_x1, orig_x2, orig_x3, orig_x4 = init_arguments(l1, l2, l3, l4)
    function(orig_x1, orig_x2, orig_x3, orig_x4, l1, l2, l3, l4)
    clean_test(filepath)

    # Test flattening order='F'
    f_routine = Subroutine.from_source(fcode, frontend=frontend)
    normalize_array_shape_and_access(f_routine)
    flatten_arrays(routine=f_routine, order='F', start_index=1)
    filepath = tmp_path/(f'{f_routine.name}_{start_index}_flattened_F_{frontend}.f90')
    function = jit_compile(f_routine, filepath=filepath, objname=routine.name)
    f_x1, f_x2, f_x3, f_x4 = init_arguments(l1, l2, l3, l4, flattened=True)
    function(f_x1, f_x2, f_x3, f_x4, l1, l2, l3, l4)
    validate_routine(f_routine)
    clean_test(filepath)

    assert (f_x1 == orig_x1.flatten(order='F')).all()
    assert (f_x2 == orig_x2.flatten(order='F')).all()
    assert (f_x3 == orig_x3.flatten(order='F')).all()
    assert (f_x4 == orig_x4.flatten(order='F')).all()

    # Test flattening order='C'
    c_routine = Subroutine.from_source(fcode, frontend=frontend)
    normalize_array_shape_and_access(c_routine)
    invert_array_indices(c_routine)
    flatten_arrays(routine=c_routine, order='C', start_index=1)
    filepath = tmp_path/(f'{c_routine.name}_{start_index}_flattened_C_{frontend}.f90')
    function = jit_compile(c_routine, filepath=filepath, objname=routine.name)
    c_x1, c_x2, c_x3, c_x4 = init_arguments(l1, l2, l3, l4, flattened=True)
    function(c_x1, c_x2, c_x3, c_x4, l1, l2, l3, l4)
    validate_routine(c_routine)
    clean_test(filepath)

    # Inverting the indices and flattening in C order must yield the same IR
    assert f_routine.body == c_routine.body

    assert (c_x1 == orig_x1.flatten(order='F')).all()
    assert (c_x2 == orig_x2.flatten(order='F')).all()
    assert (c_x3 == orig_x3.flatten(order='F')).all()
    assert (c_x4 == orig_x4.flatten(order='F')).all()

    # Test C transpilation (which includes flattening)
    f2c_routine = Subroutine.from_source(fcode, frontend=frontend)
    f2c = FortranCTransformation()
    f2c.apply(source=f2c_routine, path=tmp_path)
    f2cwrap = FortranISOCWrapperTransformation()
    f2cwrap.apply(source=f2c_routine, path=tmp_path)
    libname = f'fc_{f2c_routine.name}_{start_index}_{frontend}'
    c_kernel = jit_compile_lib(
        [tmp_path/f'{f2c_routine.name}_fc.F90', tmp_path/f'{f2c_routine.name}_c.c'],
        path=tmp_path, name=libname, builder=builder
    )
    fc_function = c_kernel.transf_flatten_arr_fc_mod.transf_flatten_arr_fc
    f2c_x1, f2c_x2, f2c_x3, f2c_x4 = init_arguments(l1, l2, l3, l4, flattened=True)
    fc_function(f2c_x1, f2c_x2, f2c_x3, f2c_x4, l1, l2, l3, l4)

    # The transpiled kernel is verified through its numerical output
    # (`c_routine` has already been validated above)
    assert (f2c_x1 == orig_x1.flatten(order='F')).all()
    assert (f2c_x2 == orig_x2.flatten(order='F')).all()
    assert (f2c_x3 == orig_x3.flatten(order='F')).all()
    assert (f2c_x4 == orig_x4.flatten(order='F')).all()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('ignore', ((), ('i2',), ('i4', 'i1')))
def test_shift_to_zero_indexing(frontend, ignore):
    """
    Verify that ``shift_to_zero_indexing`` rewrites every array index ``dim``
    into ``dim - 1``, except for the index names listed in ``ignore``.
    The output is not valid Fortran; it is an intermediate step of the
    F2C transpilation.
    """
    fcode = """
    subroutine transform_shift_indexing(x1, x2, x3, x4, l1, l2, l3, l4)
        implicit none
        integer :: i1, i2, i3, i4, c1, c2, c3, c4
        integer, intent(in) :: l1, l2, l3, l4
        integer, intent(inout) :: x1(l1)
        integer, intent(inout) :: x2(l2, l1)
        integer, intent(inout) :: x3(l3, l2, l1)
        integer, intent(inout) :: x4(l4, l3, l2, l1)
        c1 = 1
        c2 = 1
        c3 = 1
        c4 = 1
        do i1=1,l1
            x1(i1) = c1
            do i2=1,l2
                x2(i2, i1) = c2*10 + c1
                do i3=1,l3
                    x3(i3, i2, i1) = c3*100 + c2*10 + c1
                    do i4=1,l4
                        x4(i4, i3, i2, i1) = c4*1000 + c3*100 + c2*10 + c1
                        c4 = c4 + 1
                    end do
                    c3 = c3 + 1
                end do
                c2 = c2 + 1
            end do
            c1 = c1 + 1
        end do

    end subroutine transform_shift_indexing
    """

    def _body_arrays(rtn):
        # Every array access that appears in the routine body
        return [item for item in FindVariables().visit(rtn.body) if isinstance(item, sym.Array)]

    def _shifted(index_name):
        # Expected form of a shifted index: ``index_name + (-1 * 1)``
        return sym.Sum((sym.Scalar(name=index_name), sym.Product((-1, sym.IntLiteral(1)))))

    index_names = {
        'x1': ('i1',),
        'x2': ('i2', 'i1'),
        'x3': ('i3', 'i2', 'i1'),
        'x4': ('i4', 'i3', 'i2', 'i1'),
    }
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Sanity check of the index expressions before shifting
    for arr in _body_arrays(routine):
        assert arr.dimensions == index_names[arr.name]

    shift_to_zero_indexing(routine, ignore=ignore)

    # Each index not listed in `ignore` must now be offset by -1
    for arr in _body_arrays(routine):
        reference = tuple(
            name if name in ignore else _shifted(name)
            for name in index_names[arr.name]
        )
        assert fgen(arr.dimensions) == fgen(reference)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('explicit_dimensions', [True, False])
def test_transform_flatten_arrays_call(tmp_path, frontend, builder, explicit_dimensions):
    """
    Test flattening of arrays, meaning converting multi-dimensional
    arrays to one-dimensional arrays including corresponding
    index arithmetic (for calls).

    Depending on ``explicit_dimensions`` the driver passes ``a`` and ``b``
    to the kernel either as whole arrays or as explicit ``(:,:)`` slices.
    """
    array_dims = '(:,:)' if explicit_dimensions else ''
    fcode_driver = f"""
SUBROUTINE driver_routine(nlon, nlev, a, b)
    use kernel_mod, only: kernel_routine
    INTEGER, INTENT(IN)    :: nlon, nlev
    INTEGER, INTENT(INOUT) :: a(nlon,nlev)
    INTEGER, INTENT(INOUT)  :: b(nlon,nlev)

    call kernel_routine(nlon, nlev, a{array_dims}, b{array_dims})

END SUBROUTINE driver_routine
    """
    # The loop bounds follow the declared extents: `i` indexes the first
    # dimension (nlon) and `j` the second (nlev). Swapping them would access
    # out of bounds whenever nlon /= nlev.
    fcode_kernel = """
module kernel_mod
IMPLICIT NONE
CONTAINS
SUBROUTINE kernel_routine(nlon, nlev, a, b)
    INTEGER, INTENT(IN)    :: nlon, nlev
    INTEGER, INTENT(INOUT) :: a(nlon,nlev)
    INTEGER, INTENT(INOUT) :: b(nlon,nlev)
    INTEGER :: i, j

    do j=1, nlev
      do i=1, nlon
        a(i,j) = i*10 + j
        b(i,j) = i*10 + j + 1
      end do
    end do
END SUBROUTINE kernel_routine
end module kernel_mod
    """
    def init_arguments(nlon, nlev, flattened=False):
        # Zero-initialised int32 arguments, either 2D or flat 1D of the same size
        a = np.zeros(shape=(nlon*nlev,) if flattened else (nlon,nlev,), order='F', dtype=np.int32)
        b = np.zeros(shape=(nlon*nlev,) if flattened else (nlon,nlev,), order='F', dtype=np.int32)
        return a, b

    def validate_routine(routine):
        # Arrays must be 1D after flattening; whole-array call arguments carry no dimensions
        arrays = [var for var in FindVariables().visit(routine.body) if isinstance(var, sym.Array)]
        assert all(len(arr.dimensions) == 1 or not arr.dimensions for arr in arrays)
        assert all(len(arr.shape) == 1 for arr in arrays)

    kernel_module = Module.from_source(fcode_kernel, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, xmods=[tmp_path],
            definitions=kernel_module)
    kernel = kernel_module.subroutines[0]

    # check for a(:,:) and b(:,:) if "explicit_dimensions"
    call = FindNodes(ir.CallStatement).visit(driver.body)[0]
    if explicit_dimensions:
        assert call.arguments[-2].dimensions == (sym.RangeIndex((None, None)), sym.RangeIndex((None, None)))
        assert call.arguments[-1].dimensions == (sym.RangeIndex((None, None)), sym.RangeIndex((None, None)))
    else:
        assert call.arguments[-2].dimensions == ()
        assert call.arguments[-1].dimensions == ()

    # compile and test reference
    refname = f'ref_{driver.name}_{frontend}'
    reference = jit_compile_lib([kernel_module, driver], path=tmp_path, name=refname, builder=builder)
    ref_function = reference.driver_routine

    nlon = 10
    nlev = 12
    a_ref, b_ref = init_arguments(nlon, nlev)
    ref_function(nlon, nlev, a_ref, b_ref)
    builder.clean()

    # flatten all the arrays in the kernel and driver
    flatten_arrays(routine=kernel, order='F', start_index=1)
    flatten_arrays(routine=driver, order='F', start_index=1)

    # check whether all the arrays are 1-dimensional
    validate_routine(kernel)
    validate_routine(driver)

    # compile and test the flattened variant
    flattenedname = f'flattened_{driver.name}_{frontend}'
    flattened = jit_compile_lib([kernel_module, driver], path=tmp_path, name=flattenedname, builder=builder)
    flattened_function = flattened.driver_routine

    a_flattened, b_flattened = init_arguments(nlon, nlev, flattened=True)
    flattened_function(nlon, nlev, a_flattened, b_flattened)

    # check whether reference and flattened variant(s) produce same result
    assert (a_flattened == a_ref.flatten(order='F')).all()
    assert (b_flattened == b_ref.flatten(order='F')).all()

    builder.clean()

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('recurse_to_kernels', (False, True))
@pytest.mark.parametrize('inline_external_only', (False, True))
@pytest.mark.parametrize('pass_as_kwarg', (False, True,))
def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, inline_external_only, pass_as_kwarg):
    """
    Check ``LowerConstantArrayIndices`` on a driver -> kernel -> compute chain.

    The driver passes slices of ``var`` with constant (parameter) indices to
    the kernel, either positionally or as keyword arguments. Depending on
    ``inline_external_only`` (and the frontend), the constant indices either
    stay at the call site or are lowered into the kernel. ``recurse_to_kernels``
    controls whether lowering continues into the nested ``compute`` routine.
    """
    fcode_driver = f"""
subroutine driver(nlon,nlev,nb,var)
  use kernel_mod, only: kernel
  implicit none
  integer, parameter :: param_1 = 1
  integer, parameter :: param_2 = 2
  integer, parameter :: param_3 = 5
  integer, intent(in) :: nlon,nlev,nb
  real, intent(inout) :: var(nlon,nlev,param_3,nb)
  integer :: ibl
  integer :: offset
  integer :: some_val
  integer :: loop_start, loop_end
  loop_start = 2
  loop_end = nb
  some_val = 0
  offset = 1
  !$omp test
  do ibl=loop_start, loop_end
    call kernel(nlon,nlev,{'var=' if pass_as_kwarg else ''}var(:,:,param_1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,param_2:param_3,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end)
    call kernel(nlon,nlev,{'var=' if pass_as_kwarg else ''}var(:,:,param_1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,param_2:param_3,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end)
    ! call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end)
  enddo
end subroutine driver
"""

    fcode_kernel = """
module kernel_mod
implicit none
contains
subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend)
  use compute_mod, only: compute
  implicit none
  integer, intent(in) :: nlon,nlev,icend,lstart,lend
  real, intent(inout) :: var(nlon,nlev)
  real, intent(inout) :: another_var(nlon,nlev,4)
  integer :: jk, jl, jt
  var(:,:) = 0.
  do jk = 1,nlev
    do jl = 1, nlon
      var(jl, jk) = 0.
      do jt= 1,4
        another_var(jl, jk, jt) = 0.0
      end do
    end do
  end do
  call compute(nlon,nlev,var)
  call compute(nlon,nlev,var)
end subroutine kernel
end module kernel_mod
"""

    fcode_nested_kernel = """
module compute_mod
implicit none
contains
subroutine compute(nlon,nlev,var)
  implicit none
  integer, intent(in) :: nlon,nlev
  real, intent(inout) :: var(nlon,nlev)
  var(:,:) = 0.
end subroutine compute
end module compute_mod
"""

    nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path])
    kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path])

    # Apply the transformation top-down: driver, kernel, then nested kernel
    trafo_kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only}
    LowerConstantArrayIndices(**trafo_kwargs).apply(driver, role='driver', targets=('kernel',))
    LowerConstantArrayIndices(**trafo_kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',))
    LowerConstantArrayIndices(**trafo_kwargs).apply(nested_kernel_mod['compute'], role='kernel')

    kernel = kernel_mod['kernel']
    nested_kernel = nested_kernel_mod['compute']

    # For OMNI, `inline_external_only` does not change the expected outcome
    keeps_call_indices = inline_external_only and frontend != OMNI
    lowers_into_nested = recurse_to_kernels and not keeps_call_indices

    def _has_range(var):
        # True if any index of this array access is a range (slice)
        return any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions)

    # --- driver: indices of the arguments passed to the kernel
    if keeps_call_indices:
        expected_arg1 = (':', ':', 'param_1', 'ibl')
        expected_arg2 = (':', ':', 'param_2:param_3', 'ibl')
    else:
        expected_arg1 = expected_arg2 = (':', ':', ':', 'ibl')
    for kernel_call in FindNodes(ir.CallStatement).visit(driver.body):
        if pass_as_kwarg:
            arg1, arg2 = kernel_call.kwarguments[0][1], kernel_call.kwarguments[1][1]
        else:
            arg1, arg2 = kernel_call.arguments[2], kernel_call.arguments[3]
        assert arg1.dimensions == expected_arg1
        assert arg2.dimensions == expected_arg2

    # --- kernel: declarations and array accesses
    if keeps_call_indices:
        var_shape, another_var_shape = ('nlon', 'nlev'), ('nlon', 'nlev', 4)
        var_access, another_var_access = ('jl', 'jk'), ('jl', 'jk', 'jt')
    else:
        var_shape = another_var_shape = ('nlon', 'nlev', 5)
        var_access, another_var_access = ('jl', 'jk', 1), ('jl', 'jk', 'jt + 2 + -1')

    kernel_vars = kernel.variable_map
    for decl_name, decl_shape in (('var', var_shape), ('another_var', another_var_shape)):
        assert kernel_vars[decl_name].shape == decl_shape
        assert kernel_vars[decl_name].dimensions == decl_shape

    for var in FindVariables().visit(kernel.body):
        lname = var.name.lower()
        if lname == 'var' and not _has_range(var):
            assert var.dimensions == var_access
        if lname == 'another_var' and not _has_range(var):
            assert tuple(str(dim) for dim in var.dimensions) == another_var_access

    # --- kernel: arguments passed on to the nested kernel
    if keeps_call_indices:
        compute_arg_access = (':', ':')
    elif recurse_to_kernels:
        compute_arg_access = (':', ':', ':')
    else:
        compute_arg_access = (':', ':', '1')
    for compute_call in FindNodes(ir.CallStatement).visit(kernel.body):
        for arg in compute_call.arguments:
            if arg.name.lower() == 'var':
                assert arg.dimensions == compute_arg_access

    # --- nested kernel: declaration and accesses of `var`
    if lowers_into_nested:
        nested_shape, nested_access = ('nlon', 'nlev', 5), (':', ':', 1)
    else:
        nested_shape, nested_access = ('nlon', 'nlev'), (':', ':')
    nested_kernel_var = nested_kernel.variable_map['var']
    assert nested_kernel_var.shape == nested_shape
    assert nested_kernel_var.dimensions == nested_shape
    for var in FindVariables().visit(nested_kernel.body):
        if var.name.lower() == 'var':
            assert var.dimensions == nested_access


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('recurse_to_kernels', (False, True,))
@pytest.mark.parametrize('inline_external_only', (False, True,))
def test_lower_constant_array_indices_academic(tmp_path, frontend, recurse_to_kernels, inline_external_only):
    """
    Lower constant array indices for a valid, if somewhat academic, example.

    The driver passes slices of a 6D array that mix loop indices, constant
    indices and ranges. The transformation handles this, although such code
    is hopefully rare in practice.
    """
    fcode_driver = """
subroutine driver(nlon,nlev,nb,var)
  use kernel_mod, only: kernel
  implicit none
  integer, parameter :: param_1 = 1
  integer, parameter :: param_2 = 2
  integer, parameter :: param_3 = 5
  integer, intent(in) :: nlon,nlev,nb
  real, intent(inout) :: var(nlon,4,3,nlev,param_3,nb)
  ! real, intent(inout) :: var(nlon,3,nlev,param_3,nb)
  integer :: ibl, j
  integer :: offset
  integer :: some_val
  integer :: loop_start, loop_end
  loop_start = 2
  loop_end = nb
  some_val = 0
  offset = 1
  !$omp test
  do ibl=loop_start, loop_end
    do j=1,4
      call kernel(nlon,nlev,var(:,j,1,:,param_1,ibl), var(:,j,2:3,:,param_2:param_3,ibl), offset, loop_start, loop_end)
      call kernel(nlon,nlev,var(:,j,1,:,param_1,ibl), var(:,j,2:3,:,param_2:param_3,ibl), offset, loop_start, loop_end)
    end do
  enddo
end subroutine driver
"""

    fcode_kernel = """
module kernel_mod
implicit none
contains
subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend)
  use compute_mod, only: compute
  implicit none
  integer, intent(in) :: nlon,nlev,icend,lstart,lend
  real, intent(inout) :: var(nlon,nlev)
  real, intent(inout) :: another_var(nlon,2,nlev,4)
  integer :: jk, jl, jt
  var(:,:) = 0.
  do jk = 1,nlev
    do jl = 1, nlon
      var(jl, jk) = 0.
      do jt= 1,4
        another_var(jl, 1, jk, jt) = 0.0
      end do
    end do
  end do
  call compute(nlon,nlev,var)
  call compute(nlon,nlev,var)
end subroutine kernel
end module kernel_mod
"""

    fcode_nested_kernel = """
module compute_mod
implicit none
contains
subroutine compute(nlon,nlev,var)
  implicit none
  integer, intent(in) :: nlon,nlev
  real, intent(inout) :: var(nlon,nlev)
  var(:,:) = 0.
end subroutine compute
end module compute_mod
"""

    nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path])
    kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path])

    # Apply the transformation top-down: driver, kernel, then nested kernel
    trafo_kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only}
    LowerConstantArrayIndices(**trafo_kwargs).apply(driver, role='driver', targets=('kernel',))
    LowerConstantArrayIndices(**trafo_kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',))
    LowerConstantArrayIndices(**trafo_kwargs).apply(nested_kernel_mod['compute'], role='kernel')

    kernel = kernel_mod['kernel']
    nested_kernel = nested_kernel_mod['compute']

    # For OMNI, `inline_external_only` does not change the expected outcome
    keeps_call_indices = bool(inline_external_only and frontend != OMNI)

    def _has_range(var):
        # True if any index of this array access is a range (slice)
        return any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions)

    # --- driver: indices of the arguments passed to the kernel
    if keeps_call_indices:
        expected_arg2 = (':', 'j', ':', ':', 'param_1', 'ibl')
        expected_arg3 = (':', 'j', ':', ':', 'param_2:param_3', 'ibl')
    else:
        expected_arg2 = expected_arg3 = (':', 'j', ':', ':', ':', 'ibl')
    for kernel_call in FindNodes(ir.CallStatement).visit(driver.body):
        assert kernel_call.arguments[2].dimensions == expected_arg2
        assert kernel_call.arguments[3].dimensions == expected_arg3

    # --- kernel: declarations and array accesses
    if keeps_call_indices:
        var_shape, another_var_shape = ('nlon', 3, 'nlev'), ('nlon', 3, 'nlev', 4)
        var_access = ('jl', 1, 'jk')
        another_var_access = ('jl', '1 + 2 + -1', 'jk', 'jt')
    else:
        var_shape = another_var_shape = ('nlon', '3', 'nlev', 5)
        var_access = ('jl', 1, 'jk', 1)
        another_var_access = ('jl', '1 + 2 + -1', 'jk', 'jt + 2 + -1')

    kernel_vars = kernel.variable_map
    for decl_name, decl_shape in (('var', var_shape), ('another_var', another_var_shape)):
        assert kernel_vars[decl_name].shape == decl_shape
        assert kernel_vars[decl_name].dimensions == decl_shape

    for var in FindVariables().visit(kernel.body):
        lname = var.name.lower()
        if lname == 'var' and not _has_range(var):
            assert var.dimensions == var_access
        if lname == 'another_var' and not _has_range(var):
            assert tuple(str(dim) for dim in var.dimensions) == another_var_access

    # --- kernel: arguments passed on to the nested kernel,
    # keyed by (keeps_call_indices, recurse_to_kernels)
    compute_arg_access = {
        (True, True): (':', ':', ':'),
        (True, False): (':', 1, ':'),
        (False, True): (':', ':', ':', ':'),
        (False, False): (':', 1, ':', '1'),
    }[(keeps_call_indices, bool(recurse_to_kernels))]
    for compute_call in FindNodes(ir.CallStatement).visit(kernel.body):
        for arg in compute_call.arguments:
            if arg.name.lower() == 'var':
                assert arg.dimensions == compute_arg_access

    # --- nested kernel: declaration and accesses of `var`
    if recurse_to_kernels and not keeps_call_indices:
        nested_shape, nested_access = ('nlon', 3, 'nlev', 5), (':', 1, ':', 1)
    elif recurse_to_kernels:
        nested_shape, nested_access = ('nlon', 3, 'nlev'), (':', 1, ':')
    else:
        nested_shape, nested_access = ('nlon', 'nlev'), (':', ':')
    nested_kernel_var = nested_kernel.variable_map['var']
    assert nested_kernel_var.shape == nested_shape
    assert nested_kernel_var.dimensions == nested_shape
    for var in FindVariables().visit(nested_kernel.body):
        if var.name.lower() == 'var':
            assert var.dimensions == nested_access
loki-ecmwf-0.3.6/loki/transformations/array_indexing/tests/test_array_demote.py0000664000175000017500000001474215167130205030371 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Subroutine
from loki.jit_build import jit_compile
from loki.expression import symbols as sym
from loki.frontend import available_frontends

from loki.transformations.array_indexing.demote import demote_variables


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_demote_variables(tmp_path, frontend):
    """
    Demote local temporaries along the ``m`` dimension and check that both
    the declared shapes and the computed results are as expected.
    """
    fcode = """
subroutine transform_demote_variables(scalar, vector, matrix, n, m)
  implicit none
  integer, intent(in) :: n, m
  integer, intent(inout) :: scalar, vector(n), matrix(n, n)
  integer :: tmp_scalar, tmp_vector(n, m), tmp_matrix(n, m, n)
  integer :: jl, jk, jm

  do jl=1,n
    do jm=1,m
      tmp_vector(jl, jm) = scalar + jl
    end do
  end do

  do jm=1,m
    do jl=1,n
      scalar = jl
      vector(jl) = tmp_vector(jl, jm) + tmp_vector(jl, jm)

      do jk=1,n
        tmp_matrix(jk, jm, jl) = vector(jl) + jk
      end do
    end do
  end do

  do jk=1,n
    do jm=1,m
      do jl=1,n
        matrix(jk, jl) = tmp_matrix(jk, jm, jl)
      end do
    end do
  end do
end subroutine transform_demote_variables
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _run_and_check(kernel_fn):
        # Run a compiled variant of the kernel and verify the reference results
        n, m = 3, 2
        scalar = np.array(0)
        vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
        matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32)
        kernel_fn(scalar, vector, matrix, n, m)

        assert scalar == 3
        assert np.all(vector == np.arange(1, n + 1)*2)
        assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0))

    # Reference run of the unmodified routine
    filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
    _run_and_check(jit_compile(routine, filepath=filepath, objname=routine.name))

    # Demote the temporaries along the `m` dimension
    demote_variables(routine, ['tmp_vector', 'tmp_matrix'], ['m'])

    vmap = routine.variable_map
    n_var = vmap['n']
    assert isinstance(vmap['scalar'], sym.Scalar)
    for var_name, var_shape in (
        ('vector', (n_var,)),
        ('tmp_vector', (n_var,)),
        ('matrix', (n_var, n_var)),
        ('tmp_matrix', (n_var, n_var)),
    ):
        assert isinstance(vmap[var_name], sym.Array)
        assert vmap[var_name].shape == var_shape

    # The demoted routine must still produce the reference results
    demoted_filepath = tmp_path/(f'{routine.name}_demoted_{frontend}.f90')
    _run_and_check(jit_compile(routine, filepath=demoted_filepath, objname=routine.name))

    # Requesting demotion of a scalar must not fail and must leave the IR unchanged
    demoted_fcode = routine.to_fortran()
    demote_variables(routine, ['jl'], ['m'])
    assert routine.to_fortran() == demoted_fcode


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_demote_dimension_arguments(tmp_path, frontend):
    """
    Demote array arguments and a local temporary that are declared via
    ``DIMENSION`` attributes along the ``n`` dimension, then compare the
    compiled results before and after demotion.
    """
    fcode = """
subroutine transform_demote_dimension_arguments(vec1, vec2, matrix, n, m)
    implicit none
    integer, intent(in) :: n, m
    integer, dimension(n), intent(inout) :: vec1, vec2
    integer, dimension(n, m), intent(inout) :: matrix
    integer, dimension(n) :: vec_tmp
    integer :: i, j

    do i=1,n
        do j=1,m
        vec_tmp(i) = vec1(i) + vec2(i)
        matrix(i, j) = matrix(i, j) + vec_tmp(i)
        end do
    end do
end subroutine transform_demote_dimension_arguments
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Reference: compile the unmodified routine
    filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    vmap = routine.variable_map
    n_var, m_var = vmap['n'], vmap['m']
    for var_name, var_shape in (('vec1', (n_var,)), ('vec2', (n_var,)), ('matrix', (n_var, m_var))):
        assert isinstance(vmap[var_name], sym.Array)
        assert vmap[var_name].shape == var_shape

    n, m = 3, 2
    vec1 = np.zeros(shape=(n,), order='F', dtype=np.int32) + 3
    vec2 = np.zeros(shape=(n,), order='F', dtype=np.int32) + 2
    matrix = np.zeros(shape=(n, m), order='F', dtype=np.int32) + 1
    function(vec1, vec2, matrix, n, m)

    assert np.all(vec1 == 3) and np.sum(vec1) == 9
    assert np.all(vec2 == 2) and np.sum(vec2) == 6
    assert np.all(matrix == 6) and np.sum(matrix) == 36

    # Demote along `n`: vec1 becomes a scalar, matrix keeps only its `m` extent
    demote_variables(routine, ['vec1', 'vec_tmp', 'matrix'], ['n'])

    vmap = routine.variable_map
    assert isinstance(vmap['vec1'], sym.Scalar)
    for var_name, var_shape in (('vec2', (vmap['n'],)), ('matrix', (vmap['m'],))):
        assert isinstance(vmap[var_name], sym.Array)
        assert vmap[var_name].shape == var_shape

    # Compile and run the demoted routine
    demoted_filepath = tmp_path/(f'{routine.name}_demoted_{frontend}.f90')
    demoted_function = jit_compile(routine, filepath=demoted_filepath, objname=routine.name)

    vec1 = np.array(3)
    vec2 = np.zeros(shape=(n,), order='F', dtype=np.int32) + 2
    matrix = np.zeros(shape=(m, ), order='F', dtype=np.int32) + 1
    demoted_function(vec1, vec2, matrix, n, m)

    assert vec1 == 3
    assert np.all(vec2 == 2) and np.sum(vec2) == 6
    assert np.all(matrix == 16) and np.sum(matrix) == 32
loki-ecmwf-0.3.6/loki/transformations/array_indexing/tests/test_vector_notation.py0000664000175000017500000004465515167130205031141 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Module, Subroutine, Dimension
from loki.jit_build import jit_compile, jit_compile_lib, Builder, Obj
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes, FindVariables

from loki.transformations.array_indexing.vector_notation import (
    resolve_vector_notation, resolve_vector_dimension,
    remove_explicit_array_dimensions, add_explicit_array_dimensions
)


@pytest.fixture(scope='function', name='builder')
def fixture_builder(tmp_path):
    """
    Provide a :any:`Builder` that uses ``tmp_path`` as both source and build
    directory, and clear the :any:`Obj` cache once the test has finished.
    """
    build = Builder(source_dirs=tmp_path, build_dir=tmp_path)
    yield build
    # Teardown: drop cached object information so tests do not leak into each other
    Obj.clear_cache()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_promote_resolve_vector_notation(tmp_path, frontend):
    """
    Apply and test resolve vector notation utility.

    The whole-array assignments ``ret1(:, :)`` and ``ret2(:, :)`` must be
    resolved into explicit loop nests, and the compiled routine must still
    produce the same values.
    """
    fcode = """
subroutine transform_resolve_vector_notation(ret1, ret2)
  implicit none
  integer, parameter :: param1 = 3
  integer, parameter :: param2 = 5
  integer, intent(out) :: ret1(param1, param1), ret2(param1, param2)
  integer :: tmp, jk

  ret1(:, :) = 11
  ret2(:, :) = 42

end subroutine transform_resolve_vector_notation
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    resolve_vector_notation(routine)

    loops = FindNodes(ir.Loop).visit(routine.body)
    arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)]

    # The expected loop bounds differ for OMNI. The expected value is computed
    # up front because ``assert a == b if cond else c`` parses as
    # ``assert (a == b) if cond else c``, which would assert a non-empty string
    # (and thus always pass) for OMNI.
    bounds_param1 = '1:param1' if frontend != OMNI else '1:3:1'
    bounds_param2 = '1:param2' if frontend != OMNI else '1:5:1'

    assert len(loops) == 4
    assert loops[0].variable == 'i_ret1_1'
    assert loops[0].bounds == bounds_param1
    assert loops[1].variable == 'i_ret1_0'
    assert loops[1].bounds == bounds_param1
    assert loops[2].variable == 'i_ret2_1'
    assert loops[2].bounds == bounds_param2
    assert loops[3].variable == 'i_ret2_0'
    assert loops[3].bounds == bounds_param1

    assert len(arrays) == 2
    assert arrays[0].dimensions == ('i_ret1_0', 'i_ret1_1')
    assert arrays[1].dimensions == ('i_ret2_0', 'i_ret2_1')

    ret1 = np.zeros(shape=(3, 3), order='F', dtype=np.int32)
    ret2 = np.zeros(shape=(3, 5), order='F', dtype=np.int32)

    filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)
    function(ret1, ret2)

    assert np.all(ret1 == 11)
    assert np.all(ret2 == 42)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('kidia_loop', (True, False))
def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend, kidia_loop):
    """
    Apply and test resolve vector notation utility with already
    available/appropriate loops.

    Both the original and the resolved routine are compiled and executed,
    and must produce the same values for ``scalar``, ``vector``,
    ``vector_2`` and ``matrix``.
    """
    fcode = f"""
subroutine transform_resolve_vector_notation_common_loops(scalar, vector, vector_2, matrix, n, m, l, kidia, kfdia)
  implicit none
  integer, intent(in) :: n, m, l, kidia, kfdia
  integer, intent(inout) :: scalar, vector(n), vector_2(n), matrix(l, n)
  integer :: tmp_scalar, tmp_vector(n, m), tmp_matrix(l, m, n), tmp_dummy(n, 0:4)
  integer :: jl, jk, jm

  tmp_dummy(:,:) = 0
  tmp_vector(:, 1) = tmp_dummy(:, 1)
  tmp_vector(:, :) = 0
  tmp_matrix(:, :, :) = 0
  matrix(:, :) = 0

  do jl={'kidia,kfdia' if kidia_loop else '1,n'}
    do jm=1,m
      tmp_vector(jl, jm) = scalar + jl
    end do
  end do

  do jm=1,m
    do jl=1,n
      scalar = jl
      vector(jl) = tmp_vector(jl, jm) + tmp_vector(jl, jm)

      do jk=1,l
        tmp_matrix(jk, jm, jl) = vector(jl) + jk
      end do
    end do
  end do


  do jk=1,l
    matrix(jk, :) = 0
    do jm=1,m
      do jl=1,n
        matrix(jk, jl) = tmp_matrix(jk, jm, jl)
      end do
    end do
  end do

  vector_2(:) = 1
  vector_2(kidia:kfdia) = 2

end subroutine transform_resolve_vector_notation_common_loops
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    # Test the original implementation
    filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    n = 3
    m = 2
    l = 3
    kidia = 1
    kfdia = n
    scalar = np.array(0)
    vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
    vector_2 = np.zeros(shape=(n,), order='F', dtype=np.int32)
    matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32)
    function(scalar, vector, vector_2, matrix, n, m, l, kidia, kfdia)

    assert scalar == 3
    assert np.all(vector == np.arange(1, n + 1)*2)
    # vector_2(:) = 1 is overwritten by vector_2(kidia:kfdia) = 2, which covers
    # the whole array for kidia=1, kfdia=n
    assert np.all(vector_2 == 2)
    assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0))

    resolve_vector_notation(routine)
    loops = FindNodes(ir.Loop).visit(routine.body)
    arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)]
    assert len(loops) == 21
    assert loops[0].variable == 'i_tmp_dummy_1' and loops[0].bounds == '0:4'
    assert loops[1].variable == 'jl' and loops[1].bounds == '1:n'
    assert loops[2].variable == 'jl' and loops[2].bounds == '1:n'
    assert loops[3].variable == 'jm' and loops[3].bounds == '1:m'
    assert loops[4].variable == 'jl' and loops[4].bounds == '1:n'
    assert loops[5].variable == 'jl' and loops[5].bounds == '1:n'
    assert loops[6].variable == 'jm' and loops[6].bounds == '1:m'
    assert loops[7].variable == 'jk' and loops[7].bounds == '1:l'
    assert loops[8].variable == 'jl' and loops[8].bounds == '1:n'
    assert loops[9].variable == 'jk' and loops[9].bounds == '1:l'
    assert loops[10].variable == 'jl'
    if kidia_loop:
        assert loops[10].bounds == 'kidia:kfdia'
    else:
        assert loops[10].bounds == '1:n'
    assert loops[11].variable == 'jm' and loops[11].bounds == '1:m'
    assert loops[12].variable == 'jm' and loops[12].bounds == '1:m'
    assert loops[13].variable == 'jl' and loops[13].bounds == '1:n'
    assert loops[14].variable == 'jk' and loops[14].bounds == '1:l'
    assert loops[15].variable == 'jk' and loops[15].bounds == '1:l'
    assert loops[16].variable == 'jl' and loops[16].bounds == '1:n'
    assert loops[17].variable == 'jm' and loops[17].bounds == '1:m'
    assert loops[18].variable == 'jl' and loops[18].bounds == '1:n'
    assert loops[19].variable == 'jl' and loops[19].bounds == '1:n'
    if kidia_loop:
        assert loops[20].variable == 'jl'
        assert loops[20].bounds == 'kidia:kfdia'
    else:
        assert loops[20].variable == 'i_vector_2_0'
        assert loops[20].bounds == 'kidia:kfdia'

    assert len(arrays) == 17
    assert arrays[0].name.lower() == 'tmp_dummy' and arrays[0].dimensions == ('jl', 'i_tmp_dummy_1')
    assert arrays[1].name.lower() == 'tmp_vector' and arrays[1].dimensions == ('jl', 1)
    assert arrays[2].name.lower() == 'tmp_dummy' and arrays[2].dimensions == ('jl', 1)
    assert arrays[3].name.lower() == 'tmp_vector' and arrays[3].dimensions == ('jl', 'jm')
    assert arrays[4].name.lower() == 'tmp_matrix' and arrays[4].dimensions == ('jk', 'jm', 'jl')
    assert arrays[15].name.lower() == 'vector_2' and arrays[15].dimensions == ('jl',)
    assert arrays[16].name.lower() == 'vector_2'
    if kidia_loop:
        assert arrays[16].dimensions == ('jl',)
    else:
        assert arrays[16].dimensions == ('i_vector_2_0',)

    # Test resolved routine
    resolved_filepath = tmp_path/(f'{routine.name}_resolved_{frontend}.f90')
    resolved_function = jit_compile(routine, filepath=resolved_filepath, objname=routine.name)

    n = 3
    m = 2
    l = 3
    kidia = 1
    kfdia = n
    scalar = np.array(0)
    vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
    vector_2 = np.zeros(shape=(n,), order='F', dtype=np.int32)
    matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32)
    resolved_function(scalar, vector, vector_2, matrix, n, m, l, kidia, kfdia)

    assert scalar == 3
    assert np.all(vector == np.arange(1, n + 1)*2)
    assert np.all(vector_2 == 2)
    assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0))


@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'OMNI does not like missing information')]))
def test_transform_inline_call_resolve_vector_notation(frontend):
    """
    Check that resolving vector notation leaves an inline function call
    untouched, even though Loki has to treat it as an array access.
    """
    fcode = """
subroutine transform_resolve_vector_notation_inline_call(x)
  use some_mod, only: some_func
  implicit none
  integer, parameter :: param1 = 3
  integer, parameter :: param2 = 5
  integer, intent(in) :: x(param1, param2)

  ! should stay like that
  tmp = some_func(ret1(1, 1))

end subroutine transform_resolve_vector_notation_inline_call
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    resolve_vector_notation(routine)

    # Fortran uses the same syntax for inline calls and array accesses, so
    # ``some_func`` is represented as an array whose "subscript" is the call argument
    array_vars = {}
    for var in FindVariables(unique=False).visit(routine.body):
        if isinstance(var, sym.Array):
            array_vars[var.name.lower()] = var

    assert 'some_func' in array_vars
    assert array_vars['some_func'].dimensions == ('ret1(1, 1)',)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('calls_only', (False, True))
def test_transform_explicit_dimensions(tmp_path, frontend, builder, calls_only):
    """
    Test making dimensions of arrays explicit and undoing this,
    thus removing colon notation from array dimensions either for all
    or for arrays within (inline) calls only.

    The reference and the transformed driver/kernel pair are compiled and
    executed, and must produce identical results.
    """
    fcode_driver = """
  SUBROUTINE driver_routine(nlon, nlev, a, b)
    use kernel_explicit_dimensions_mod, only: kernel_routine
    INTEGER, INTENT(IN)    :: nlon, nlev
    INTEGER, INTENT(INOUT) :: a(nlon,nlev)
    INTEGER, INTENT(INOUT)  :: b(nlon,nlev)

    call kernel_routine(nlon, a, b=b, nlev=nlev)

  END SUBROUTINE driver_routine
    """

    fcode_kernel = """
  module kernel_explicit_dimensions_mod
  IMPLICIT NONE
  CONTAINS
  SUBROUTINE kernel_routine(nlon, a, b, nlev)
    INTEGER, INTENT(IN)    :: nlon, nlev
    INTEGER, INTENT(INOUT) :: a(nlon,nlev)
    INTEGER, INTENT(INOUT) :: b(nlon,nlev)

    A = MYADD(A, B=B)
  END SUBROUTINE kernel_routine

  PURE ELEMENTAL FUNCTION MYADD(A, B)
    INTEGER :: MYADD
    INTEGER, INTENT(IN) :: A, B

    MYADD = A + B
  END FUNCTION
  end module kernel_explicit_dimensions_mod
    """

    def init_arguments(nlon, nlev):
        a = 2*np.ones(shape=(nlon,nlev,), order='F', dtype=np.int32)
        b = 3*np.ones(shape=(nlon,nlev,), order='F', dtype=np.int32)
        return a, b

    kernel_module = Module.from_source(fcode_kernel, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, xmods=[tmp_path],
                                     definitions=[kernel_module])
    kernel = kernel_module.subroutines[0]

    # compile and test reference
    refname = f'ref_explicit_dims_{driver.name}_{frontend}'
    reference = jit_compile_lib([kernel_module, driver], path=tmp_path, name=refname, builder=builder)
    ref_function = reference.driver_routine

    nlon = 10
    nlev = 12
    a_ref, b_ref = init_arguments(nlon, nlev)
    ref_function(nlon, nlev, a_ref, b_ref)
    builder.clean()

    # add explicit array dimensions
    add_explicit_array_dimensions(driver)
    add_explicit_array_dimensions(kernel)
    kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0]
    kernel_call_array_args = [arg for arg in kernel_call.arguments if isinstance(arg, sym.Array)]
    assert all(len(arg.dimensions) == 2 for arg in kernel_call_array_args)

    # remove explicit array dimensions (possibly only for calls)
    remove_explicit_array_dimensions(driver, calls_only=calls_only)
    remove_explicit_array_dimensions(kernel, calls_only=calls_only)

    kernel_call = FindNodes(ir.CallStatement).visit(driver.body)[0]
    kernel_call_array_args = [arg for arg in kernel_call.arguments if isinstance(arg, sym.Array)]
    assert all(not arg.dimensions for arg in kernel_call_array_args)
    if calls_only:
        # Only the arguments of the inline call lose their dimensions;
        # the assignment LHS keeps its explicit (:,:) subscript
        assignments = FindNodes(ir.Assignment).visit(kernel.body)
        assert len(assignments) == 1
        assert len(assignments[0].lhs.dimensions) == 2
        parameters = (assignments[0].rhs.parameters[0],)
        parameters += (assignments[0].rhs.kwarguments[0][1],)
        assert not parameters[0].dimensions
        assert not parameters[1].dimensions
    else:
        kernel_arrays = FindVariables().visit(kernel.body)
        assert all(not arr.dimensions for arr in kernel_arrays)

    # compile and test the resulting code
    testname = f'test_explicit_dims_{"calls_only_" if calls_only else ""}{driver.name}_{frontend}'
    test = jit_compile_lib([kernel_module, driver], path=tmp_path, name=testname, builder=builder)
    test_function = test.driver_routine

    a_test, b_test = init_arguments(nlon, nlev)
    test_function(nlon, nlev, a_test, b_test)

    # check whether reference and transformed variant(s) produce the same result
    assert (a_test == a_ref).all()
    assert (b_test == b_ref).all()

    builder.clean()


@pytest.mark.parametrize('frontend', available_frontends())
def test_resolve_vector_dimension(frontend):
    """
    Test vector resolution utility for a single dimension: only array
    sections over ``start:end`` are wrapped in a loop over ``jl``.
    """

    fcode = """
subroutine kernel(start, end, nlon, nlev, z, work, play, sleep, repeat)
  integer, intent(in) :: start, end, nlon, nlev
  real, intent(in) :: z
  real, intent(out) :: work(nlon), play(nlon, nlev), sleep(nlev,nlev), repeat(nlev,nlon)
  integer :: jl
  real :: work_maxval

  work(start:end) = 0.
  work_maxval = maxval(work(start:end))

  play(:,1:nlev) = 42.
  sleep(:, :) = z * z * z
  repeat(:,start:end) = 6.66
end subroutine kernel
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    horizontal = Dimension(name='horizontal', index='jl', lower='start', upper='end')
    resolve_vector_dimension(routine, dimension=horizontal)

    loops = FindNodes(ir.Loop).visit(routine.body)
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(loops) == 2
    assert len(assigns) == 5

    first_body = loops[0].body
    second_body = loops[1].body

    # work(start:end) is resolved into a loop over jl
    assert assigns[0] in first_body
    assert assigns[0].lhs == 'work(jl)'

    # The maxval argument and sections not ranging over start:end stay unwrapped
    for pos in (1, 2, 3):
        assert assigns[pos] not in first_body
        assert assigns[pos] not in second_body
    assert assigns[1].rhs.name.lower() == 'maxval'
    assert 'start:end' in assigns[1].rhs.parameters[0].dimensions
    assert assigns[2].lhs == 'play(:,1:nlev)'
    assert assigns[3].lhs == 'sleep(:,:)'

    # Only the start:end dimension of repeat(:,start:end) is resolved
    assert assigns[4] in second_body
    assert assigns[4].lhs == 'repeat(:,jl)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_resolve_masked_statements(frontend):
    """
    Test that ``WHERE``/``ELSEWHERE`` over the horizontal dimension is
    resolved into a horizontal loop around an ``IF``/``ELSE`` conditional.
    """

    fcode = """
subroutine test_resolve_where(start, end, nlon, nz, q, t)
  INTEGER, INTENT(IN) :: start, end  ! Iteration indices
  INTEGER, INTENT(IN) :: nlon, nz    ! Size of the horizontal and vertical
  REAL, INTENT(INOUT) :: t(nlon,nz)
  REAL, INTENT(INOUT) :: q(nlon,nz)
  INTEGER :: jk

  DO jk = 2, nz
    WHERE (q(start:end, jk) > 1.234)
      q(start:end, jk) = q(start:end, jk-1) + t(start:end, jk)
    ELSEWHERE
      q(start:end, jk) = t(start:end, jk)
    END WHERE
  END DO
end subroutine test_resolve_where
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    horizontal = Dimension(name='horizontal', index='jl', lower='start', upper='end')
    resolve_vector_dimension(routine, dimension=horizontal)

    # The new horizontal loop index must be declared
    assert 'jl' in routine.variables

    # Two loops in total: the original vertical loop and the new
    # horizontal loop nested inside it
    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 2
    vertical_loop, horizontal_loop = loops
    assert horizontal_loop in FindNodes(ir.Loop).visit(vertical_loop.body)
    assert horizontal_loop.variable == 'jl'
    assert horizontal_loop.bounds == 'start:end'
    assert vertical_loop.variable == 'jk'
    assert vertical_loop.bounds == '2:nz'

    # The WHERE mask becomes the condition of a single IF inside the horizontal loop
    conds = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conds) == 1
    cond = conds[0]
    assert cond in FindNodes(ir.Conditional).visit(horizontal_loop)
    assert cond.condition == 'q(jl, jk) > 1.234'

    # The masked assignments land in the IF and ELSE branches respectively
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 2
    then_assign, else_assign = assigns
    assert then_assign in cond.body
    assert then_assign.lhs == 'q(jl, jk)' and then_assign.rhs == 'q(jl, jk - 1) + t(jl, jk)'
    assert else_assign in cond.else_body
    assert else_assign.lhs == 'q(jl, jk)' and else_assign.rhs == 't(jl, jk)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_resolve_masked_inferred_bounds(frontend):
    """
    Test resolving array sections and ``WHERE`` statements whose ranges are
    implicit (``y(:)``, bare ``x``) when ``derive_qualified_ranges`` is set.
    """

    fcode = """
subroutine test_masked_inferred(n, m, x, y, z)
  implicit none
  integer, intent(in) :: n, m
  real(kind=8), intent(inout) :: x(n), y(n), z(m)
  integer :: i

  do i=1,n
    x(i) = i
  end do
  y(:) = 0.0
  z(:) = 0.0

  where( (x > 5.0) )
    x = y
  end where
end subroutine test_masked_inferred
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    dim = Dimension(name='n', index='i', lower='1', upper='n')
    resolve_vector_dimension(routine, dimension=dim, derive_qualified_ranges=True)

    # Only assignments over ``n`` are resolved; ``z`` (sized ``m``) merely
    # receives its qualified range
    expected = (
        ('x(i)', 'i'),
        ('y(i)', '0.0'),
        ('z(1:m)', '0.0'),
        ('x(i)', 'y(i)'),
    )
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == len(expected)
    for assign, (lhs, rhs) in zip(assigns, expected):
        assert assign.lhs == lhs and assign.rhs == rhs

    # The WHERE construct is turned into an IF on the resolved mask
    conds = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conds) == 1
    assert conds[0].condition == 'x(i) > 5.0'
    assert assigns[3] in conds[0].body

    # One loop each around the original loop body, the resolved y(:) and the WHERE
    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 3
    for loop, node in zip(loops, (assigns[0], assigns[1], conds[0])):
        assert node in loop.body
loki-ecmwf-0.3.6/loki/transformations/array_indexing/promote.py0000664000175000017500000003106515167130205025177 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Utilities for promoting the rank of array variables.
"""
from collections import defaultdict
import operator as op

from loki.analyse import dataflow_analysis_attached
from loki.expression import symbols as sym, simplify, symbolic_op
from loki.ir import (
    nodes as ir, FindNodes, Transformer, FindVariables,
    SubstituteExpressions
)
from loki.logging import info
from loki.tools import as_tuple, OrderedSet


__all__ = [
    'promote_variables', 'promote_nonmatching_variables',
    'promotion_dimensions_from_loop_nest',
]


def promote_variables(routine, variable_names, pos, index=None, size=None):
    """
    Promote a list of variables by inserting new array dimensions of given size
    and updating all uses of these variables with a given index expression.

    When providing only :data:`size` or :data:`index`, promotion is restricted
    to updating only variable declarations or their use, respectively, and the
    other is left unchanged.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which the variables should be promoted.
    variable_names : list of str
        The names of variables to be promoted. Matching of variables against
        names is case-insensitive.
    pos : int
        The position of the new array dimension using Python indexing
        convention (i.e., count from 0 and use negative values to count from
        the end). For negative values, ``-1`` places the new dimension(s)
        after all existing dimensions, ``-2`` before the last existing
        dimension, and so on.
    index : :py:class:`pymbolic.primitives.Expression`, optional
        The indexing expression (or a tuple for multi-dimension promotion)
        to use for the promotion dimension(s), e.g., loop variables. Usage of
        variables is only updated if `index` is provided. When the index
        expression is not live at the variable use, ``:`` is used instead.
    size : :py:class:`pymbolic.Expression`, optional
        The size of the dimension (or tuple for multi-dimension promotion) to
        insert at `pos`. When this is provided, the declaration of variables
        is updated accordingly.
    """
    variable_names = {name.lower() for name in variable_names}

    if not variable_names:
        return

    def _insert_pos(ndims):
        """Convert ``pos`` into a non-negative insertion offset for ``ndims`` existing dimensions."""
        if pos < 0:
            # -1 maps to ``ndims`` (append), -2 to ``ndims - 1`` (before the last),
            # etc. Clamp at 0 so an out-of-range negative position inserts at the
            # front instead of wrapping around via negative slice indices.
            return max(ndims + pos + 1, 0)
        return pos

    # Insert new index dimension
    if index is not None:
        index = as_tuple(index)
        index_vars = [OrderedSet(FindVariables().visit(i)) for i in index]

        # Create a copy of the tree and apply promotion in-place
        routine.body = Transformer().visit(routine.body)

        with dataflow_analysis_attached(routine):
            for node, var_list in FindVariables(unique=False, with_ir_node=True).visit(routine.body):
                # All the variables marked for promotion that appear in this IR node
                var_list = [v for v in var_list if v.name.lower() in variable_names]

                if not var_list:
                    continue

                # We use the given index expression in this node if all
                # variables therein are defined, otherwise we use `:`
                node_index = tuple(i if v <= node.live_symbols else sym.RangeIndex((None, None))
                                   for i, v in zip(index, index_vars))

                var_map = {}
                for var in var_list:
                    # Variables without subscript (e.g. scalars) have no dimensions
                    var_dim = getattr(var, 'dimensions', ())
                    var_pos = _insert_pos(len(var_dim))
                    dimensions = as_tuple(var_dim[:var_pos] + node_index + var_dim[var_pos:])
                    var_map[var] = var.clone(dimensions=dimensions)

                # need to apply update immediately because identical variable use
                # in other nodes might yield same hash but different substitution
                SubstituteExpressions(var_map, inplace=True).visit(node)

    # Apply shape promotion
    if size is not None:
        size = as_tuple(size)

        var_list = [var for decl in FindNodes(ir.VariableDeclaration).visit(routine.spec)
                    for var in decl.symbols if var.name.lower() in variable_names]

        # Scalars may not have a shape at all (missing attribute or ``None``)
        var_shapes = [as_tuple(getattr(var, 'shape', None)) for var in var_list]
        new_shapes = []
        for shape in var_shapes:
            var_pos = _insert_pos(len(shape))
            new_shapes.append(shape[:var_pos] + size + shape[var_pos:])

        var_map = {v: v.clone(type=v.type.clone(shape=shape), dimensions=shape)
                   for v, shape in zip(var_list, new_shapes)}
        routine.spec = SubstituteExpressions(var_map).visit(routine.spec)


def promotion_dimensions_from_loop_nest(var_names, loops, promotion_vars_dims, promotion_vars_index):
    """
    Determine promotion dimensions corresponding to the iteration space of a loop nest.

    Parameters
    ----------
    var_names : list of str
        The names of variables to consider for promotion.
    loops : list of :any:`Loop`
        The list of nested loops, sorted from outermost to innermost.
    promotion_vars_dims : dict((str, tuple))
        The mapping of variable names to promotion dimensions. When determining
        promotion dimensions for the variables in :data:`var_names` this dict is
        checked for already existing promotion dimensions and, if not matching,
        the maximum of both is taken for each dimension.
    promotion_vars_index : dict((str, tuple))
        The mapping of variable names to subscript expressions. These expressions
        are later inserted for every variable use. When the indexing expression
        for the loop nest does not match the existing expression in this dict,
        both index lists are merged so that the result contains the indices of
        both (see ``_merge_dims_and_index``).

    Returns
    -------
    (:data:`promotion_vars_dims`, :data:`promotion_vars_index`) : tuple of dict
        The updated mappings :data:`promotion_vars_dims` and :data:`promotion_vars_index`.
        Both dicts are also modified in place.
    """
    # TODO: Would be nice to be able to promote this to the smallest possible dimension
    #       (in a loop var=start,end this is (end-start+1) with subscript index (var-start+1))
    #       but it requires being able to decide whether this yields a constant dimension,
    #       thus we need to stick to the upper bound for the moment as this is constant
    #       in our use cases.
    # Innermost loop first, i.e. the innermost loop index becomes the first promotion dimension
    loop_lengths = [simplify(loop.bounds.stop) for loop in reversed(loops)]
    loop_index = [loop.variable for loop in reversed(loops)]

    def _merge_dims_and_index(dims_a, index_a, dims_b, index_b):
        """
        Helper routine that takes two pairs of promotion dimensions and indices
        (let's call them a and b) and tries to merge them to form the promotion
        configuration that accommodates both.
        """
        # Let's assume we have the same or more promotion dimensions in b than in a
        if len(dims_b) < len(dims_a):
            return _merge_dims_and_index(dims_b, index_b, dims_a, index_a)  # pylint: disable=arguments-out-of-order

        # We identify each dimension by the index expression; therefore, we have
        # to merge them first
        new_index = []
        ptr_a, ptr_b = 0, 0
        while ptr_a < len(index_a) and ptr_b < len(index_b):
            # Let's see if the next index in a can be found somewhere in b
            try:
                a_in_b = index_b.index(index_a[ptr_a], ptr_b)
            except ValueError:
                a_in_b = None

            if a_in_b is None:
                # It's not in there, so just add it to the new index
                # and go to the next
                new_index += [index_a[ptr_a]]
                ptr_a += 1
            else:
                # Found a in b: add it and anything before from b
                new_index += index_b[ptr_b:a_in_b+1]
                ptr_a += 1
                ptr_b = a_in_b + 1

            # Skip any indices we have already dealt with
            while ptr_a < len(index_a) and index_a[ptr_a] in new_index:
                ptr_a += 1
            while ptr_b < len(index_b) and index_b[ptr_b] in new_index:
                ptr_b += 1

        # Add any remaining indices in a and b
        if ptr_a < len(index_a):
            assert ptr_b == len(index_b)
            new_index += index_a[ptr_a:]
        else:
            assert ptr_a == len(index_a)
            new_index += index_b[ptr_b:]

        # With the merged index in place, we need to go through each corresponding
        # dimension from a and b and pick the larger
        new_dims = []
        for idx in new_index:
            # Look for position of that index in a and b
            try:
                ptr_a = index_a.index(idx)
            except ValueError:
                ptr_a = None
            try:
                ptr_b = index_b.index(idx)
            except ValueError:
                ptr_b = None

            if ptr_a is None:
                # exists only in b
                new_dims += [dims_b[ptr_b]]
            elif ptr_b is None:
                # exists only in a
                new_dims += [dims_a[ptr_a]]
            else:
                # exists in both: pick the larger
                if symbolic_op(dims_a[ptr_a], op.lt, dims_b[ptr_b]):
                    new_dims += [dims_b[ptr_b]]
                else:
                    new_dims += [dims_a[ptr_a]]

        # ... and we're done: return the new dimensions and index
        return new_dims, new_index

    for var_name in var_names:
        # Check if we have already marked this variable for promotion: let's make sure the added
        # dimensions are large enough for this loop (nest)
        if var_name not in promotion_vars_dims:
            # Store copies so that variables never share (and accidentally co-mutate) one list
            promotion_vars_dims[var_name] = list(loop_lengths)
            promotion_vars_index[var_name] = list(loop_index)
        else:
            promotion_vars_dims[var_name], promotion_vars_index[var_name] = \
                _merge_dims_and_index(promotion_vars_dims[var_name], promotion_vars_index[var_name],
                                      loop_lengths, loop_index)

    return promotion_vars_dims, promotion_vars_index


def promote_nonmatching_variables(routine, promotion_vars_dims, promotion_vars_index):
    """
    Promote multiple variables with potentially non-matching promotion
    dimensions or index expressions.

    This is a convenience routine for using :meth:`promote_variables` that
    groups variables by indexing expression and promotion dimensions to
    reduce the number of calls to :meth:`promote_variables`.

    Parameters
    ----------
    routine : any:`Subroutine`
        The subroutine to be modified.
    promotion_vars_dims : dict
        The mapping of variable names to promotion dimensions. The variables'
        shapes are expanded where necessary to have at least these dimensions.
        Entries are updated in place to the dimensions that were actually added.
    promotion_vars_index : dict
        The mapping of variable names to subscript expressions to be used
        whenever reading/writing the variable. Entries are updated in place
        alongside :data:`promotion_vars_dims`.
    """
    if not promotion_vars_dims:
        return

    variable_map = routine.variable_map

    # First, let's find out what dimensions we actually need to add
    for var_name in promotion_vars_dims:
        shape = variable_map[var_name].type.shape
        if shape is None:
            continue

        # Eliminate 1:n declared shapes (mostly thanks to OMNI)
        shape = [s.stop if isinstance(s, sym.Range) and s.start == 1 else s for s in shape]

        dims = []
        index = []
        for dim, idx in zip(promotion_vars_dims[var_name], promotion_vars_index[var_name]):
            # Skip dimensions that are already part of the declared shape
            if not any(symbolic_op(dim, op.eq, d) for d in shape):
                dims += [dim]
                index += [idx]
        promotion_vars_dims[var_name] = dims
        promotion_vars_index[var_name] = index

    # Group promotion variables by index and size to reduce number of traversals for promotion
    index_size_var_map = defaultdict(list)
    promoted = []
    for var_name, size in promotion_vars_dims.items():
        var_index = as_tuple(promotion_vars_index[var_name])
        var_size = as_tuple(size)
        if not var_index and not var_size:
            # Nothing left to add: avoid a full body copy and dataflow analysis
            # in promote_variables for a no-op promotion
            continue
        index_size_var_map[(var_index, var_size)] += [var_name]
        promoted += [var_name]

    for (index, size), var_names in index_size_var_map.items():
        promote_variables(routine, var_names, -1, index=index, size=size)

    if promoted:
        info('%s: promoted variable(s): %s', routine.name, ', '.join(promoted))
loki-ecmwf-0.3.6/loki/transformations/array_indexing/array_indices.py0000664000175000017500000004303615167130205026327 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

""" Utilities to change indices and indexing in array expressions. """

from loki.batch import Transformation, ProcedureItem
from loki.expression import symbols as sym, simplify, is_constant
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, SubstituteExpressions
)
from loki.tools import as_tuple
from loki.transformations.inline import inline_constant_parameters


__all__ = [
    'shift_to_zero_indexing', 'invert_array_indices',
    'normalize_range_indexing', 'flatten_arrays',
    'normalize_array_shape_and_access', 'LowerConstantArrayIndices',
]


def shift_to_zero_indexing(routine, ignore=None):
    """
    Rewrite every array subscript in ``routine.body`` for 0-based indexing
    conventions (e.g. for C or Python).

    Scalar subscripts are decremented by one. For range subscripts only the
    lower bound is decremented: the upper bound is kept because Python ranges
    are half-open ``[start, stop)``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine whose array subscripts are shifted (modified in place).
    ignore : list of str, optional
        Variable names; a scalar subscript expression that references any of
        these names is left unshifted.
    """
    skip_names = as_tuple(ignore)

    def _shift(dim):
        if isinstance(dim, sym.RangeIndex):
            # no shift for stop because Python ranges are [start, stop)
            lower = None if dim.start is None else dim.start - sym.Literal(1)
            return sym.RangeIndex((lower, dim.stop, dim.step))
        if skip_names and any(name in skip_names for name in FindVariables().visit(dim)):
            return dim
        return dim - sym.Literal(1)

    array_map = {
        var: var.clone(dimensions=as_tuple([_shift(dim) for dim in var.dimensions]))
        for var in FindVariables(unique=False).visit(routine.body)
        if isinstance(var, sym.Array)
    }
    routine.body = SubstituteExpressions(array_map).visit(routine.body)


def invert_array_indices(routine):
    """
    Reverse the order of array subscripts to switch data/loop accesses from
    column-major to row-major.

    Array accesses in ``routine.body`` are reversed, and so are the declared
    dimensions (and, where known, shapes) in ``routine.variables``, so that
    automatic cast generation sees the inverted layout.

    TODO: Take care of the indexing shift between C and Fortran.
    Basically, we are relying on the CGen to shift the iteration
    indices and dearly hope that nobody uses the index's value.
    """
    # Reverse the subscripts of every array access in the body
    vmap = {
        var: var.clone(dimensions=as_tuple(reversed(var.dimensions)))
        for var in FindVariables(unique=True).visit(routine.body)
        if isinstance(var, sym.Array)
    }
    routine.body = SubstituteExpressions(vmap).visit(routine.body)

    # Reverse declared dimensions (and shapes) for the automatic cast generation
    for var in routine.variables:
        if not isinstance(var, sym.Array):
            continue
        flipped = as_tuple(reversed(var.dimensions))
        if not var.shape:
            vmap[var] = var.clone(dimensions=flipped)
            continue
        flipped_type = var.type.clone(shape=as_tuple(reversed(var.shape)))
        vmap[var] = var.clone(dimensions=flipped, type=flipped_type)
    routine.variables = [vmap.get(v, v) for v in routine.variables]


def normalize_range_indexing(routine):
    """
    Replace the ``(1:size)`` indexing in array sizes that OMNI introduces.

    Every declared dimension and shape entry of ``routine.variables`` of the
    form ``1:n`` (without a step) is collapsed to plain ``n``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine whose declarations are normalized (modified in place).
    """
    def _collapse(dim):
        # Only ``1:n`` collapses; any other range keeps its explicit bounds
        if isinstance(dim, sym.RangeIndex) and dim.lower == 1 and dim.step is None:
            return dim.upper
        return dim

    def _normalize(var):
        if not isinstance(var, sym.Array):
            return var
        dims = as_tuple([_collapse(d) for d in var.dimensions])
        shape = as_tuple([_collapse(s) for s in var.shape])
        return var.clone(dimensions=dims, type=var.type.clone(shape=shape))

    routine.variables = [_normalize(v) for v in routine.variables]


def flatten_arrays(routine, order='F', start_index=1):
    """
    Flatten arrays, converting multi-dimensional arrays to
    one-dimensional arrays.

    Each array access in ``routine.body`` is linearised into a single
    subscript using the declared extents. For ``a(n, m)`` in Fortran order,
    ``a(i, j)`` becomes ``a(i + n*(j - start_index))``. Accesses whose
    subscripts are all ``:`` lose their subscripts. Declarations in
    ``routine.variables`` become one-dimensional, with the product of the
    shape as their extent.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which the arrays should be flattened.
    order : str
        Assume Fortran (F) vs. C memory/array order.
    start_index : int
        Assume array indexing starts with `start_index`.

    Raises
    ------
    ValueError
        If ``order`` is neither ``'F'`` nor ``'C'``.
    TypeError
        If a shape entry needed for the linearisation is still a
        :any:`RangeIndex` (e.g. ``0:n``).
    """
    if order not in ('C', 'F'):
        raise ValueError(f'Unsupported array order "{order}"')
    reverse = order == 'C'

    def _linearise(dims, shape):
        if all(d == sym.RangeIndex((None, None)) for d in dims):
            return None
        # Repeatedly fold the two trailing subscripts into one
        while len(dims) > 1:
            if isinstance(shape[-2], sym.RangeIndex):
                raise TypeError(f'Resolve shapes being of type RangeIndex, e.g., "{shape[-2]}" before flattening!')
            merged = sym.Sum((dims[-2], sym.Product((shape[-2], dims[-1] - start_index))))
            dims = dims[:-2] + (merged,)
            shape = shape[:-1]
        return dims

    array_map = {}
    for var in FindVariables().visit(routine.body):
        if not (isinstance(var, sym.Array) and var.shape):
            continue
        dims, shape = var.dimensions, var.shape
        if reverse:
            dims, shape = dims[::-1], shape[::-1]
        array_map[var] = var.clone(dimensions=_linearise(dims, shape))

    routine.body = SubstituteExpressions(array_map).visit(routine.body)

    def _flat_declaration(var):
        if not isinstance(var, sym.Array):
            return var
        return var.clone(dimensions=as_tuple(sym.Product(var.shape)),
                         type=var.type.clone(shape=as_tuple(sym.Product(var.shape))))

    routine.variables = [_flat_declaration(v) for v in routine.variables]


def normalize_array_shape_and_access(routine):
    """
    Shift all arrays to start counting at "1".

    Every array declared with explicit bounds ``lower:upper``, where ``lower``
    is not 1, is re-declared as ``1:upper-lower+1``. All of its accesses in
    ``routine.body`` are shifted by ``1 - lower``. Afterwards,
    :func:`normalize_range_indexing` collapses the resulting ``1:size``
    ranges to ``size``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine to normalize (modified in place).
    """
    def is_explicit_range_index(dim):
        # True for ``lower:upper`` with both bounds given and ``lower`` != 1;
        # assumed-size/-shape dimensions and ``1:n`` need no shift
        return (isinstance(dim, sym.RangeIndex)
                and not (dim.lower == 1 or dim.lower is None or dim.upper is None))

    def shift(expr, lower):
        # Re-base a subscript given relative to ``lower`` onto 1-based counting
        return simplify(expr - lower + 1)

    def shift_access(access, decl):
        if not is_explicit_range_index(decl):
            return access
        if isinstance(access, sym.RangeIndex):
            # An omitted bound in the access (``a(:)``, ``a(k:)``, ``a(:k)``) refers
            # to the declared bound and therefore stays omitted; the access' own
            # stride is preserved
            start = shift(access.start, decl.lower) if access.start is not None else None
            stop = shift(access.stop, decl.lower) if access.stop is not None else None
            return sym.RangeIndex((start, stop, access.step))
        return shift(access, decl.lower)

    vmap = {}
    for v in FindVariables(unique=False).visit(routine.body):
        if not isinstance(v, sym.Array):
            continue
        # skip if e.g., `array(len)`, passed as `call routine(array)`
        if not v.dimensions:
            continue
        shape = as_tuple(v.shape)
        # without shape information there is no declared lower bound to shift by
        if not shape:
            continue
        new_dims = [shift_access(access, decl) for access, decl in zip(v.dimensions, shape)]
        # keep any subscripts beyond the known shape unchanged
        new_dims += list(v.dimensions[len(shape):])
        vmap[v] = v.clone(dimensions=as_tuple(new_dims))
    routine.body = SubstituteExpressions(vmap).visit(routine.body)

    def to_one_based(dim):
        # ``lower:upper`` -> ``1:upper-lower+1``
        if is_explicit_range_index(dim):
            return sym.RangeIndex((1, simplify(dim.upper - dim.lower + 1)))
        return dim

    vmap = {}
    for v in routine.variables:
        if isinstance(v, sym.Array):
            new_dims = [to_one_based(d) for d in v.dimensions]
            new_shape = [to_one_based(d) for d in v.shape]
            new_type = v.type.clone(shape=as_tuple(new_shape))
            vmap[v] = v.clone(dimensions=as_tuple(new_dims), type=new_type)
    routine.variables = [vmap.get(v, v) for v in routine.variables]
    normalize_range_indexing(routine)


class LowerConstantArrayIndices(Transformation):
    """
    A transformation to pass/lower constant array indices down the call tree.

    For example, the following code:

    .. code-block:: fortran

      subroutine driver(...)
        real, intent(inout) :: var(nlon,nlev,5,nb)
        do ibl=1,10
          call kernel(var(:, :, 1, ibl), var(:, :, 2:5, ibl))
        end do
      end subroutine driver

      subroutine kernel(var1, var2)
        real, intent(inout) :: var1(nlon, nlev)
        real, intent(inout) :: var2(nlon, nlev, 4)
        var1(:, :) = ...
        do jk=1,nlev
          do jl=1,nlon
            var1(jl, jk) = ...
            do jt=1,4
              var2(jl, jk, jt) = ...
            enddo
          enddo
        enddo
      end subroutine kernel

    is transformed to:

    .. code-block:: fortran

      subroutine driver(...)
        real, intent(inout) :: var(nlon,nlev,5,nb)
        do ibl=1,10
          call kernel(var(:, :, :, ibl), var(:, :, :, ibl))
        end do
      end subroutine driver

      subroutine kernel(var1, var2)
        real, intent(inout) :: var1(nlon, nlev, 5)
        real, intent(inout) :: var2(nlon, nlev, 5)
        var1(:, :, 1) = ...
        do jk=1,nlev
          do jl=1,nlon
            var1(jl, jk, 1) = ...
            do jt=1,4
              var2(jl, jk, jt + 2 + -1) = ...
            enddo
          enddo
        enddo
      end subroutine kernel

    Parameters
    ----------
    recurse_to_kernels: bool
        Recurse to kernels, thus lower constant array indices below the driver level for nested
        kernel calls (default: `True`).
    inline_external_only: bool
        Inline only external constant expressions or all of them (default: `True`)
    """

    # This trafo only operates on procedures
    item_filter = (ProcedureItem,)

    def __init__(self, recurse_to_kernels=True, inline_external_only=True):
        self.recurse_to_kernels = recurse_to_kernels
        self.inline_external_only = inline_external_only

    @staticmethod
    def explicit_dimensions(routine):
        """
        Make dimensions of arrays explicit within :any:`Subroutine` ``routine``.
        E.g., convert two-dimensional array ``arr2d`` to ``arr2d(:,:)`` or
        ``arr3d`` to ``arr3d(:,:,:)``.

        Arrays without subscripts whose shape is unknown are left unchanged,
        since their rank cannot be determined.

        Parameters
        ----------
        routine: :any:`Subroutine`
            The subroutine to check
        """
        array_map = {}
        for array in FindVariables(unique=False).visit(routine.body):
            if isinstance(array, sym.Array) and not array.dimensions and array.shape:
                new_dimensions = (sym.RangeIndex((None, None)),) * len(array.shape)
                array_map[array] = array.clone(dimensions=new_dimensions)
        routine.body = SubstituteExpressions(array_map).visit(routine.body)

    @staticmethod
    def is_constant_dim(dim):
        """
        Check whether dimension dim is constant, thus, either a constant
        value or a constant range index.

        Parameters
        ----------
        dim: :py:class:`pymbolic.primitives.Expression`
        """
        if is_constant(dim):
            return True
        # A range counts as constant if start and stop are given and constant (step is not checked)
        if isinstance(dim, sym.RangeIndex)\
                and all(child is not None and is_constant(child) for child in dim.children[:-1]):
            return True
        return False

    def transform_subroutine(self, routine, **kwargs):
        """
        Lower constant array indices of calls to ``targets`` within ``routine``.

        Drivers are always processed; kernels only if ``recurse_to_kernels`` is set.
        Constant parameters are first inlined via ``inline_constant_parameters``.
        """
        role = kwargs['role']
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))
        if role == 'driver' or self.recurse_to_kernels:
            inline_constant_parameters(routine, external_only=self.inline_external_only)
            self.process(routine, targets)

    def process(self, routine, targets):
        """
        Process the driver and possibly kernels.

        NOTE(review): a callee is only analysed at its first call; later calls to the
        same routine merely get their signature updated, so differing constant indices
        across calls to the same routine are not reconciled.
        """
        dispatched_routines = ()
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if str(call.name).lower() not in targets:
                continue
            # skip already dispatched routines but still update the call signature
            if call.routine in dispatched_routines:
                self.update_call_signature(call)
                continue
            # explicit array dimensions for the callee
            self.explicit_dimensions(call.routine)
            dispatched_routines += (call.routine,)
            # create the offset map and apply to call and callee
            offset_map = self.create_offset_map(call)
            self.process_callee(call.routine, offset_map)
            self.update_call_signature(call)

    def update_call_signature(self, call):
        """
        Replace constant indices for call arguments being arrays with ':' and update the call.

        NOTE(review): every array argument is rewritten, independent of whether the
        corresponding dummy argument is an array or a scalar.
        """
        def _lower(arg):
            if not isinstance(arg, sym.Array):
                return arg
            return arg.clone(dimensions=tuple(
                sym.RangeIndex((None, None)) if self.is_constant_dim(d) else d for d in arg.dimensions
            ))

        new_args = [_lower(arg) for arg in call.arguments]
        new_kwargs = [(kw_name, _lower(kw_arg)) for kw_name, kw_arg in call.kwarguments]
        call._update(arguments=as_tuple(new_args), kwarguments=as_tuple(new_kwargs))

    def create_offset_map(self, call):
        """
        Create map/dictionary for arguments with constant array indices.

        For every array dummy argument of the callee that receives an array
        actual argument, map the position of each subscript of the actual
        argument to a tuple ``(original_index, offset, size)``:

        * ``original_index``: position of that dimension in the callee's current
          declaration; ``None`` for a constant scalar index (the callee has no
          such dimension yet); ``-1`` for a non-constant scalar index.
        * ``offset``: the constant index or the lower bound of a constant range,
          ``None`` otherwise.
        * ``size``: the actual argument's declared shape entry for constant
          indices, ``None`` otherwise.

        For, e.g.,

        integer :: arg(len1, len2, len3, len4)
        call kernel(..., arg(:, 2, 4:6, i), ...)

        offset_map['arg'] = {
            0: (0, None, None),  # same index as before, no offset
            1: (None, 2, len2),  # New index, constant index 2, size of the dimension is len2
            2: (1, 4, len3),     # Used to be position 1, offset 4, size of the dimension is len3
            3: (-1, None, None), # disregard as this is neither constant nor passed to callee
        }

        Parameters
        ----------
        call : :any:`CallStatement`
            The call whose arguments are analysed.

        Returns
        -------
        dict
            Map from dummy argument name to the per-subscript map described above.
        """
        offset_map = {}
        for routine_arg, call_arg in call.arg_iter():
            # Only array actual arguments for array dummies carry subscript information
            if not isinstance(routine_arg, sym.Array) or not isinstance(call_arg, sym.Array):
                continue
            offset_map[routine_arg.name] = {}
            current_index = 0
            for i, dim in enumerate(call_arg.dimensions):
                if self.is_constant_dim(dim):
                    if isinstance(dim, sym.RangeIndex):
                        # constant array index is e.g. '1:3' or '5:10'
                        offset_map[routine_arg.name][i] = (current_index, dim.children[0], call_arg.shape[i])
                    else:
                        # constant array index is e.g., '1' or '42'
                        offset_map[routine_arg.name][i] = (None, dim, call_arg.shape[i])
                        current_index -= 1
                else:
                    if not isinstance(dim, sym.RangeIndex):
                        # non constant array index is a variable e.g. 'jl'
                        offset_map[routine_arg.name][i] = (-1, None, None)
                        current_index -= 1
                    else:
                        # non constant array index is ':'
                        offset_map[routine_arg.name][i] = (current_index, None, None)
                current_index += 1
        return offset_map

    def process_callee(self, routine, offset_map):
        """
        Process/adapt the callee according to information in `offset_map`.

        Adapt the variable declarations and usage/indexing.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The called routine; its spec and body are modified in place.
        offset_map : dict
            Per-argument subscript information as created by :meth:`create_offset_map`.
        """
        # Arguments passed without any subscripts (e.g. ``call kernel(var)``) have an
        # empty map and nothing to lower, so they are left untouched.
        # adapt variable declarations, thus adapt the dimension and shape of the corresponding arguments
        vmap = {}
        variable_map = routine.variable_map
        for var_name, dim_map in offset_map.items():
            if not dim_map:
                continue
            var = variable_map[var_name]
            new_dims = ()
            for i in sorted(dim_map):
                original_index, offset, size = dim_map[i]
                if not (original_index is None or 0 <= original_index < len(var.dimensions)):
                    continue
                if offset is not None:
                    # constant index or constant range: the dimension takes the caller's extent
                    new_dims += (size,)
                else:
                    # ':' passed through: keep the callee's declared extent
                    new_dims += (var.shape[original_index],)
            vmap[var] = var.clone(dimensions=new_dims, type=var.type.clone(shape=new_dims))
        routine.spec = SubstituteExpressions(vmap).visit(routine.spec)

        # adapt the variable usage, thus the indexing/dimension;
        # Fortran names are case-insensitive, so uses are matched by lower-case name
        use_maps = {name.lower(): dim_map for name, dim_map in offset_map.items() if dim_map}
        vmap = {}
        for var in FindVariables(unique=False).visit(routine.body):
            dim_map = use_maps.get(var.name.lower())
            if not dim_map or not getattr(var, 'dimensions', None):
                continue
            new_dims = ()
            for i in sorted(dim_map):
                original_index, offset, _ = dim_map[i]
                if not (original_index is None or 0 <= original_index < len(var.dimensions)):
                    continue
                if offset is None:
                    new_dims += (var.dimensions[original_index],)
                elif original_index is None:
                    # constant scalar index becomes an explicit subscript
                    new_dims += (offset,)
                else:
                    # constant range: shift the callee's subscript by the range's lower bound
                    new_dims += (var.dimensions[original_index] + offset - 1,)
            vmap[var] = var.clone(dimensions=new_dims)
        routine.body = SubstituteExpressions(vmap).visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/array_indexing/vector_notation.py0000664000175000017500000003572415167130205026735 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

""" Utilities to manipulate vector notation in array expressions. """

from itertools import count

from loki.expression import symbols as sym, LokiIdentityMapper
from loki.frontend import HAVE_FP
from loki.ir import (
    nodes as ir, FindNodes, FindExpressions, Transformer,
    FindVariables, SubstituteExpressions, FindInlineCalls
)
from loki.tools import as_tuple, dict_override, OrderedSet
from loki.types import SymbolAttributes, BasicType

from loki.transformations.utilities import (
    get_integer_variable, get_loop_bounds
)

if HAVE_FP:
    from fparser.two import Fortran2003


__all__ = [
    'remove_explicit_array_dimensions', 'add_explicit_array_dimensions',
    'resolve_vector_notation', 'resolve_vector_dimension',
    'ResolveVectorNotationTransformer'
]


def remove_explicit_array_dimensions(routine, calls_only=False):
    """
    Drop all-colon subscripts from arrays within :any:`Subroutine` ``routine``.

    An array access whose subscripts are all ``:`` (e.g. ``arr2d(:,:)`` or
    ``arr3d(:,:,:)``) is replaced by the bare array name (``arr2d``,
    ``arr3d``). Accesses with any other subscript, such as ``arr(1,:,:)``,
    are kept.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The subroutine to modify in place
    calls_only: bool
        If `True`, only positional and keyword arguments of subroutine calls
        and inline (function) calls are rewritten; otherwise every array
        access in the body is (default: False)
    """
    def _is_whole_array(expr):
        # ``all`` over an empty subscript list is True, so bare arrays also match
        return isinstance(expr, sym.Array) and all(
            d == sym.RangeIndex((None, None)) for d in expr.dimensions
        )

    def _strip(expr):
        return expr.clone(dimensions=None) if _is_whole_array(expr) else expr

    if not calls_only:
        array_map = {
            var: var.clone(dimensions=None)
            for var in FindVariables(unique=False).visit(routine.body)
            if _is_whole_array(var)
        }
        routine.body = SubstituteExpressions(array_map).visit(routine.body)
        return

    # Both kinds of calls are collected from the unmodified body
    call_nodes = as_tuple(FindNodes(ir.CallStatement).visit(routine.body))
    inline_calls = as_tuple(FindInlineCalls().visit(routine.body))

    inline_call_map = {}
    for call in call_nodes + inline_calls:
        arguments = tuple(_strip(arg) for arg in call.arguments)
        kwarguments = tuple((name, _strip(value)) for name, value in call.kwarguments)
        if isinstance(call, sym.InlineCall):
            # inline calls are expressions: replace them via substitution below
            inline_call_map[call] = call.clone(parameters=arguments, kw_parameters=kwarguments)
        else:
            # call statements are IR nodes: update them directly
            call._update(arguments=arguments, kwarguments=kwarguments)

    if inline_call_map:
        routine.body = SubstituteExpressions(inline_call_map).visit(routine.body)


def add_explicit_array_dimensions(routine):
    """
    Make dimensions of arrays explicit within :any:`Subroutine` ``routine``.
    E.g., convert two-dimensional array ``arr2d`` to ``arr2d(:,:)`` or
    ``arr3d`` to ``arr3d(:,:,:)``.

    Arrays without subscripts whose shape is unknown are left unchanged,
    since the number of ``:`` to insert cannot be determined.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The subroutine to modify in place
    """
    array_map = {}
    for array in FindVariables(unique=False).visit(routine.body):
        if not isinstance(array, sym.Array) or array.dimensions:
            continue
        # Rank unknown: ``len(array.shape)`` would fail on ``None``
        if not array.shape:
            continue
        new_dimensions = (sym.RangeIndex((None, None)),) * len(array.shape)
        array_map[array] = array.clone(dimensions=new_dimensions)
    routine.body = SubstituteExpressions(array_map).visit(routine.body)


def resolve_vector_notation(routine):
    """
    Resolve implicit vector notation by inserting explicit loops.

    A range that coincides with the bounds of an existing loop reuses that
    loop's index variable. Any other range (after deriving explicit bounds
    from array shapes) gets a newly created index variable, which is then
    declared in ``routine``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine to transform in place.
    """
    known_loop_indices = {}
    for loop in FindNodes(ir.Loop).visit(routine.body):
        known_loop_indices[sym.RangeIndex(loop.bounds.children)] = loop.variable

    resolver = ResolveVectorNotationTransformer(
        loop_map=known_loop_indices, scope=routine, inplace=True,
        derive_qualified_ranges=True,
    )
    routine.body = resolver.visit(routine.body)

    # Declare the loop index variables used by the generated loops
    routine.variables += tuple(OrderedSet(resolver.index_vars))


def resolve_vector_dimension(routine, dimension, derive_qualified_ranges=False):
    """
    Resolve vector notation for a given dimension only.

    Unlike the related :meth:`resolve_vector_notation` utility, only ranges
    matching the loop bounds of ``dimension`` are replaced, using the
    iteration index of ``dimension``. All other ranges are left untouched.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to resolve vector notation usage.
    dimension : :any:`Dimension`
        Dimension object that defines the dimension to resolve
    derive_qualified_ranges : bool
        Flag to enable the derivation of (all) range bounds from
        shape information.
    """
    loop_index = get_integer_variable(routine, name=dimension.index)
    loop_bounds = get_loop_bounds(routine, dimension=dimension)

    resolver = ResolveVectorNotationTransformer(
        loop_map={sym.RangeIndex(loop_bounds): loop_index},
        scope=routine, inplace=True,
        derive_qualified_ranges=derive_qualified_ranges,
        map_unknown_ranges=False,
    )
    routine.body = resolver.visit(routine.body)

    # Declare the loop index variables used by the generated loops
    routine.variables += tuple(OrderedSet(resolver.index_vars))


class IterationRangeShapeMapper(LokiIdentityMapper):
    """
    A mapper that derives the fully qualified iteration dimension for
    unbounded :any:`RangeIndex` indices in array expressions.

    Each ``:`` subscript is replaced by the corresponding declared range
    (``lower:upper[:step]``), or by ``1:n`` for a plain extent ``n``. An
    array referenced without subscripts but with a known shape is first
    expanded to ``(:, :, ...)``.
    """

    @staticmethod
    def _shape_to_range(s):
        if isinstance(s, sym.Range):
            bounds = (s.lower, s.upper, s.step)
        else:
            bounds = (sym.IntLiteral(1), s)
        return sym.RangeIndex(bounds)

    def map_array(self, expr, *args, **kwargs):
        """ Replace ``:`` range indices with ``1:shape`` vector indices """
        # No subscripts but a known shape: expand to ``(:, :, ...)`` first
        if not expr.dimensions and expr.shape:
            colons = tuple(sym.RangeIndex((None, None)) for _ in expr.shape)
            expr = expr.clone(dimensions=colons)

        new_dims = []
        for dim, extent in zip(expr.dimensions, as_tuple(expr.shape)):
            is_colon = isinstance(dim, sym.RangeIndex) and dim == ':'
            new_dims.append(self._shape_to_range(extent) if is_colon else dim)

        # Nothing derived without shape information; this also keeps inline
        # calls that were misread as array accesses unchanged
        if not new_dims:
            return expr
        return expr.clone(dimensions=tuple(new_dims))


class IterationRangeIndexMapper(LokiIdentityMapper):
    """
    A mapper that replaces fully qualified :any:`RangeIndex` symbols
    with discrete loop indices and collects the according
    ``index_range_map``.

    A range found in ``loop_map`` is replaced by the index variable known for
    it. Otherwise, if ``map_unknown_ranges`` is set, a new integer index
    variable named ``<basename>_<position>`` is created in ``scope``; if it
    is not set, the range is left in place.

    Parameters
    ----------
    loop_map : dict of :any:`RangeIndex` to :any:`Scalar`
        Map of known loop indices for given ranges
    basename : str
        Base name string for new iteration variables (default: ``'i'``)
    scope : :any:`Subroutine` or :any:`Module`
        Scope in which to create potential new iteration index symbols
    map_unknown_ranges : bool
        Flag to indicate whether range indices not encountered in ``loop_map``
        should be remapped to generic loop indices.

    Attributes
    ----------
    index_range_map : dict
        Maps each index variable used as a replacement to the range it replaced.
        NOTE(review): if one index variable replaces differently-bounded ranges
        (e.g. on the left- and right-hand side), only the last range is kept.
    """

    def __init__(
            self, *args, loop_map=None, basename=None, scope=None,
            map_unknown_ranges=True, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.loop_map = loop_map or {}
        self.basename = basename or 'i'
        self.scope = scope
        self.map_unknown_ranges = map_unknown_ranges

        self.index_range_map = {}

    def map_array(self, expr, *args, **kwargs):
        new_dims = []
        for pos, dim in enumerate(expr.dimensions):
            if not isinstance(dim, sym.RangeIndex):
                new_dims.append(dim)
                continue

            if dim in self.loop_map:
                # Reuse the index variable known for this loop range
                ivar = self.loop_map[dim]
            elif self.map_unknown_ranges:
                # Create a generic index variable for this position
                ivar = sym.Variable(
                    name=f'{self.basename}_{pos}',
                    type=SymbolAttributes(BasicType.INTEGER), scope=self.scope
                )
            else:
                # Unknown range and no new indices wanted: keep the range
                new_dims.append(dim)
                continue

            self.index_range_map[ivar] = dim
            new_dims.append(ivar)

        return expr.clone(dimensions=as_tuple(new_dims))


class ResolveVectorNotationTransformer(Transformer):
    """
    A :any:`Transformer` that resolves implicit vector notation by
    inserting explicit loops.

    Parameters
    ----------
    loop_map : dict of :any:`RangeIndex` to :any:`Variable`
        A dict mapping a fully qualified index range to a known
        variable symbol to use as loop index.
    scope : :any:`Subroutine` or :any:`Module`
        The scope in which to create new loop index variables
    derive_qualified_ranges : bool
        Derive explicit bounds for all unqualified index ranges
        (``:``) before resolving them with loops.
    map_unknown_ranges : bool
        Flag to indicate whether unknown, but fully qualified range
        indices are to be remapped to loops.

    Attributes
    ----------
    index_vars : :any:`OrderedSet`
        Loop index variables used by the generated loops. These still need to
        be declared in ``scope``, as :func:`resolve_vector_notation` does.
    """

    def __init__(
            self, *args, loop_map=None, scope=None,
            derive_qualified_ranges=True, map_unknown_ranges=True,
            **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.scope = scope
        self.loop_map = {} if loop_map is None else loop_map
        self.index_vars = OrderedSet()

        self.map_unknown_ranges = map_unknown_ranges
        self.derive_qualified_ranges = derive_qualified_ranges

    @staticmethod
    def _build_loop_nest(body, index_range_map):
        """
        Wrap ``body`` in one :any:`Loop` per entry of ``index_range_map``.

        The first entry becomes the innermost loop. Returns the outermost loop.
        """
        loop = None
        for ivar, irange in index_range_map.items():
            if isinstance(irange, sym.RangeIndex):
                bounds = sym.LoopRange(irange.children)
            else:
                bounds = sym.LoopRange((sym.Literal(1), irange, sym.Literal(1)))
            loop = ir.Loop(variable=ivar, body=as_tuple(body), bounds=bounds)
            body = loop
        return loop

    def visit_Assignment(self, stmt, **kwargs):  # pylint: disable=unused-argument
        create_loops = kwargs.get('create_loops', True)

        # Leave statements using array reduction intrinsics untouched
        if HAVE_FP:
            rhs_expressions = FindExpressions().visit(stmt.rhs)
            if any(redux_op in rhs_expressions
                   for redux_op in Fortran2003.Intrinsic_Name.array_reduction_names):
                return stmt

        # Replace all unbounded ranges with bounded ranges based on array shape
        if self.derive_qualified_ranges:
            shape_mapper = IterationRangeShapeMapper()
            stmt._update(lhs=shape_mapper(stmt.lhs), rhs=shape_mapper(stmt.rhs))

        # Replace all range indices with loop indices and collect the corresponding mapping
        index_mapper = IterationRangeIndexMapper(
            loop_map=self.loop_map, basename=f'i_{stmt.lhs.basename}', scope=self.scope,
            map_unknown_ranges=self.map_unknown_ranges
        )
        stmt._update(lhs=index_mapper(stmt.lhs), rhs=index_mapper(stmt.rhs))

        # Record all newly create loop index variables,
        # so that we can declare them in the outer context
        index_range_map = index_mapper.index_range_map
        self.index_vars.update(list(index_range_map.keys()))

        # Recursively build new loop nest over all implicit dims
        if create_loops and index_range_map:
            return self._build_loop_nest(stmt, index_range_map)

        # No vector dimensions encountered (or loops created by an enclosing construct)
        return stmt

    def visit_MaskedStatement(self, masked, **kwargs):  # pylint: disable=unused-argument
        # TODO: Currently limited to simple, single-clause WHERE stmts
        assert len(masked.conditions) == 1 and len(masked.bodies) == 1

        # Replace all unbounded ranges with bounded ranges based on array shape
        conditions = masked.conditions
        if self.derive_qualified_ranges:
            conditions = IterationRangeShapeMapper()(conditions)

        index_mapper = IterationRangeIndexMapper(
            loop_map=self.loop_map, scope=self.scope,
            map_unknown_ranges=self.map_unknown_ranges
        )
        conditions = index_mapper(conditions)
        index_range_map = index_mapper.index_range_map

        # Bodies are rewritten in place but must not create their own loops
        with dict_override(kwargs, {'create_loops': False}):
            bodies = self.visit(masked.bodies, **kwargs)
            else_body = self.visit(masked.default, **kwargs)

        if not index_range_map:
            return masked

        # The loop indices introduced here must be declared as well
        self.index_vars.update(list(index_range_map.keys()))

        # Rebuild construct as an IF conditional inside a loop nest over the range bounds
        cond = ir.Conditional(
            condition=conditions[0], body=bodies, else_body=else_body
        )
        return self._build_loop_nest(cond, index_range_map)
loki-ecmwf-0.3.6/loki/transformations/split_read_write.py0000664000175000017500000001435215167130205024047 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from loki.batch import Transformation, ProcedureItem
from loki.expression import Array
from loki.ir import (
    nodes as ir, pragma_regions_attached, is_loki_pragma, FindNodes,
    Transformer, SubstituteExpressions
)
from loki.tools import as_tuple, OrderedSet

__all__ = ['SplitReadWriteTransformation']

class SplitReadWriteWalk(Transformer):
    """
    In-place :any:`Transformer` that turns the read-write assignments of a
    ``!$loki split-read-write`` region into assignments to new temporaries and
    records the matching write-back assignments.

    After visiting a region, :attr:`write_map` maps every visited :any:`Assignment`
    and leaf node either to a tuple holding the write-back :any:`Assignment`
    (``original_array = loki_temp_<n>``) or to ``None``. The caller applies it
    with a plain :any:`Transformer` to build the second, write-only copy of the
    region; ``None`` entries are dropped from that copy.

    Parameters
    ----------
    dimensions : list
       A list of :any:`Dimension` objects corresponding to all :any:`Loop`s in the ``!$loki split-read-write`` region.
    variable_map : dict
       The variable_map of the parent :any:`Subroutine`.
    count : int
       A running count of the newly created temporaries in the parent :any:`Subroutine` so that
       temporaries created by previous ``!$loki split-read-write`` regions are not redefined.
       The first temporary created by this walker is numbered ``count + 1``.
    """

    def __init__(self, dimensions, variable_map, count=-1, **kwargs):
        # Dimensions whose loops may enclose the split assignments
        self.dimensions = dimensions
        # Parent subroutine variable_map, used to resolve size and index symbols
        self.variable_map = variable_map
        # Index of the most recently created temporary
        self.temp_count = count
        # Original array reference -> temporary replacing it
        self.lhs_var_map = {}
        # Temporaries created by this walker, in creation order
        self.tmp_vars = []
        # Visited node -> write-back assignment tuple (or None to drop the node)
        self.write_map = {}

        # This walker always rewrites the region in place
        kwargs['inplace'] = True
        super().__init__(**kwargs)

    def visit_Loop(self, o, **kwargs):
        # Extend the nest of enclosing dimensions with the one(s) driven by this loop
        outer_dims = kwargs.pop('dim_nest', [])
        own_dims = [d for d in self.dimensions if d.index == o.variable]
        return super().visit_Node(o, dim_nest=outer_dims + own_dims, **kwargs)

    def _create_temporary(self, lhs, loop_dims):
        """Build the next ``loki_temp_<n>`` temporary that will hold ``lhs``."""
        temp_shape = []
        temp_dims = []
        for extent in lhs.type.shape:
            matching = [d for d in self.dimensions if extent in d.size_expressions]
            # Keep only extents whose dimension is iterated by an enclosing loop
            if matching and matching[0] in loop_dims:
                temp_shape.append(self.variable_map[matching[0].size])
                temp_dims.append(self.variable_map[matching[0].index])

        self.temp_count += 1
        temp_type = lhs.type.clone(shape=as_tuple(temp_shape), intent=None)
        return lhs.clone(
            name=f'loki_temp_{self.temp_count}', dimensions=as_tuple(temp_dims), type=temp_type
        )

    def visit_Assignment(self, o, **kwargs):
        loop_dims = kwargs.pop('dim_nest', [])

        # Only array assignments whose target also appears on the right-hand side are split
        is_read_write = isinstance(o.lhs, Array) and o.lhs.name in o.rhs
        if not is_read_write:
            self.write_map[o] = None
            return o

        # Earlier temporaries replace their arrays in the right-hand side
        new_rhs = SubstituteExpressions(self.lhs_var_map).visit(o.rhs)

        write_back = None
        if o.lhs not in self.lhs_var_map:
            temp = self._create_temporary(o.lhs, loop_dims)
            self.lhs_var_map[o.lhs] = temp
            self.tmp_vars.append(temp)
            # Only the first occurrence of an array gets a write-back assignment
            write_back = as_tuple(ir.Assignment(lhs=o.lhs, rhs=temp))

        o._update(lhs=self.lhs_var_map[o.lhs], rhs=new_rhs)
        self.write_map[o] = write_back
        return o

    def visit_LeafNode(self, o, **kwargs):
        # Any other leaf node is dropped from the write-back copy of the region
        self.write_map[o] = None
        return super().visit_Node(o, **kwargs)

class SplitReadWriteTransformation(Transformation):
    """
    When accumulating values to multiple components of an array, a compiler cannot rule out
    the possibility that the indices alias the same address. Consider for example the following
    code:

    .. code-block:: fortran

        !$loki split-read-write
        do jlon=1,nproma
           var(jlon, n1) = var(jlon, n1) + 1.
           var(jlon, n2) = var(jlon, n2) + 1.
        enddo
        !$loki end split-read-write

    In the above example, there is no guarantee that ``n1`` and ``n2`` do not in fact point to the same location.
    Therefore the load and store instructions for ``var`` have to be executed in order.

    For cases where the user knows ``n1`` and ``n2`` indeed represent distinct locations, this transformation
    provides a pragma assisted mechanism to split the reads and writes, and therefore make the loads independent
    from the stores. The above code would therefore be transformed to:

    .. code-block:: fortran

        !$loki split-read-write
        do jlon=1,nproma
           loki_temp_0(jlon) = var(jlon, n1) + 1.
           loki_temp_1(jlon) = var(jlon, n2) + 1.
        enddo

        do jlon=1,nproma
           var(jlon, n1) = loki_temp_0(jlon)
           var(jlon, n2) = loki_temp_1(jlon)
        enddo
        !$loki end split-read-write

    Temporaries are numbered consecutively across all ``!$loki split-read-write``
    regions of a routine.

    Parameters
    ----------
    dimensions : list
       A list of :any:`Dimension` objects corresponding to all :any:`Loop`s in the ``!$loki split-read-write`` region.
    """

    item_filter = (ProcedureItem,)

    def __init__(self, dimensions):
        self.dimensions = as_tuple(dimensions)

    def transform_subroutine(self, routine, **kwargs):
        """
        Split the reads and writes of every ``!$loki split-read-write`` region in
        ``routine`` and declare the temporaries this introduces.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The routine to transform in place.
        """

        # cache variable_map for fast lookup later
        variable_map = routine.variable_map
        # Index of the most recently created ``loki_temp_<n>``; -1 means none created yet
        temp_counter = -1
        tmp_vars = []

        # find split read-write pragmas
        with pragma_regions_attached(routine):
            for region in FindNodes(ir.PragmaRegion).visit(routine.body):
                if is_loki_pragma(region.pragma, starts_with='split-read-write'):

                    transformer = SplitReadWriteWalk(self.dimensions, variable_map, count=temp_counter)
                    transformer.visit(region.body)

                    # The walker's temp_count is already the absolute index of the last
                    # temporary it created, so it becomes the new running count as-is.
                    # (Accumulating it onto temp_counter skipped indices from the second
                    # region onwards.)
                    temp_counter = transformer.temp_count
                    tmp_vars += transformer.tmp_vars

                    if transformer.write_map:
                        # Second copy of the region containing only the write-back assignments
                        new_writes = Transformer(transformer.write_map).visit(region.body)
                        region.append(new_writes)

        # add declarations for temporaries
        if tmp_vars:
            tmp_vars = OrderedSet(var.clone(dimensions=var.type.shape) for var in tmp_vars)
            routine.variables += as_tuple(tmp_vars)
loki-ecmwf-0.3.6/loki/transformations/data_offload/0000775000175000017500000000000015167130205022533 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/data_offload/field_offload.py0000664000175000017500000005271515167130205025674 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation
from loki.expression import Array, symbols as sym, parse_expr
from loki.types import BasicType, SymbolAttributes
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, Transformer,
    SubstituteExpressions, pragma_regions_attached, is_loki_pragma, pragmas_attached
)
from loki.logging import warning, error

from loki.transformations.loop_blocking import split_loop_region
from loki.transformations.utilities import find_driver_loops
from loki.transformations.field_api import (
    FieldPointerMap, field_create_device_data, field_wait_for_async_queue
)
from loki.transformations.parallel import remove_field_api_view_updates


__all__ = [
    'FieldOffloadTransformation', 'FieldOffloadBlockedTransformation', 'find_offload_variables',
    'add_field_offload_calls', 'replace_offload_args'
]


class FieldOffloadTransformation(Transformation):
    """

    Transformation to offload arrays owned by Field API fields to the device.

    **This transformation is IFS specific.**

    The transformation assumes that fields are wrapped in derived types specified in
    ``field_group_types`` and will only offload arrays that are members of such derived types.
    In the process this transformation removes calls to Field API ``update_view`` and adds
    declarations for the device pointers to the driver subroutine.

    The transformation acts on ``!$loki data`` regions and offloads all :any:`Array`
    symbols that satisfy the following conditions:

    1. The array is a member of an object that is of type specified in ``field_group_types``.

    2. The array is passed as a parameter to at least one of the kernel targets passed to ``transform_subroutine``.

    Only routines processed with ``role='driver'`` are modified.

    Parameters
    ----------
    devptr_prefix: str, optional
        The prefix of device pointers added by this transformation (defaults to ``'loki_devptr_'``).
    field_group_types: list or tuple of str, optional
        Names of the field group types with members that may be offloaded (defaults to ``['']``).
        Names are matched case-insensitively.
    offload_index: str, optional
        Names of index variable to inject in the outmost dimension of offloaded arrays in the kernel
        calls (defaults to ``'IBL'``).
    """

    def __init__(self, devptr_prefix=None, field_group_types=None, offload_index=None):
        if devptr_prefix is None:
            devptr_prefix = 'loki_devptr_'
        if field_group_types is None:
            field_group_types = ['']
        if offload_index is None:
            offload_index = 'IBL'

        self.deviceptr_prefix = devptr_prefix
        # Stored lower-case so comparisons against type names are case-insensitive
        self.field_group_types = tuple(name.lower() for name in field_group_types)
        self.offload_index = offload_index

    def transform_subroutine(self, routine, **kwargs):
        if kwargs['role'] != 'driver':
            return
        self.process_driver(routine)

    def process_driver(self, driver):
        """Offload the Field API arrays of every ``!$loki data`` region in ``driver``."""

        # Remove the Field-API view-pointer boilerplate
        remove_field_api_view_updates(driver, self.field_group_types)

        with pragma_regions_attached(driver), dataflow_analysis_attached(driver):
            for region in FindNodes(ir.PragmaRegion).visit(driver.body):
                # Skip everything but active `!$loki data` regions
                if not (region.pragma and is_loki_pragma(region.pragma, starts_with='data')):
                    continue

                # Map the region's field arrays to their device pointers
                offload_map = FieldPointerMap(
                    *find_offload_variables(driver, region, self.field_group_types),
                    scope=driver, ptr_prefix=self.deviceptr_prefix
                )

                # Declare the pointers, wrap the region in transfers and swap in device data
                declare_device_ptrs(driver, deviceptrs=offload_map.dataptrs)
                add_field_offload_calls(driver, region, offload_map)
                replace_offload_args(driver, region, offload_map, self.offload_index)


class FieldOffloadBlockedTransformation(Transformation):
    """

    Transformation to perform blocked offload of arrays owned by Field API fields to the device.

    **This transformation is IFS specific.**

    The transformation assumes that fields are wrapped in derived types specified in
    ``field_group_types`` and will only offload arrays that are members of such derived types.
    In the process this transformation removes calls to Field API ``update_view`` and adds
    declarations for the device pointers to the driver subroutine.

    The transformation acts on ``!$loki data`` regions and offloads all :any:`Array`
    symbols that satisfy the following conditions:

    1. The array is a member of an object that is of type specified in ``field_group_types``.

    2. The array is passed as a parameter to at least one of the kernel targets passed to ``transform_subroutine``.

    Parameters
    ----------
    block_size: int
        Number of blocks per chunk (specified in the final dimension of the field).
    devptr_prefix: str, optional
        The prefix of device pointers added by this transformation (defaults to ``'loki_devptr_'``).
    field_group_types: list or tuple of str, optional
        Names of the field group types with members that may be offloaded (defaults to ``['']``).
    offload_index: str, optional
        Names of index variable to inject in the outmost dimension of offloaded arrays in the kernel
        calls (defaults to ``'IBL'``).
    asynchronous: bool, optional
        Perform asynchronous blocked offload (defaults to ``False``)
    num_queues: int, optional
        Number of queues for asynchronous offload should be set to 2 or higher for asynchronous
        offload (defaults to 1).
    """

    def __init__(self, block_size, devptr_prefix=None, field_group_types=None,
                 offload_index=None, asynchronous=False, num_queues=1):
        self.block_size = block_size
        self.deviceptr_prefix = 'loki_devptr_' if devptr_prefix is None else devptr_prefix
        field_group_types = [''] if field_group_types is None else field_group_types
        self.field_group_types = tuple(typename.lower() for typename in field_group_types)
        self.offload_index = 'IBL' if offload_index is None else offload_index

        if not isinstance(asynchronous, bool):
            warning('[Loki] FieldOffloadBlockedTransformation: asynchronous kwarg must be a bool' +
                    ' asynchronous set to False')
            self.asynchronous = False
        else:
            self.asynchronous = asynchronous
        self.num_queues = num_queues

    def transform_subroutine(self, routine, **kwargs):
        role = kwargs['role']
        if role == 'driver':
            self.process_driver(routine)

    def process_driver(self, driver):
        """
        Block the driver loop of every ``!$loki data`` region in ``driver`` and wrap
        the region in blocked Field API transfers (asynchronous over several queues
        when ``asynchronous`` is set and ``num_queues > 1``).
        """

        # Remove the Field-API view-pointer boilerplate
        remove_field_api_view_updates(driver, self.field_group_types)

        # (block_loop, queue) pairs whose Loki pragmas receive an ``async`` clause. They
        # are applied only after the pragma regions have been detached again. The original
        # code annotated just the last region, and it referenced unbound names when the
        # driver had no data region at all.
        async_annotations = []

        with pragma_regions_attached(driver):
            with dataflow_analysis_attached(driver):
                for region in FindNodes(ir.PragmaRegion).visit(driver.body):
                    # Only work on active `!$loki data` regions
                    if not region.pragma or not is_loki_pragma(region.pragma, starts_with='data'):
                        continue

                    offload_variables = find_offload_variables(driver, region, self.field_group_types)
                    offload_map = FieldPointerMap(
                        *offload_variables, scope=driver, ptr_prefix=self.deviceptr_prefix
                    )
                    # inject declarations and offload API calls into driver region
                    declare_device_ptrs(driver, deviceptrs=offload_map.dataptrs)
                    # blocks all loops inside the region and places them inside one
                    splitting_vars, block_loop, region = block_driver_loop(driver, region, self.block_size)

                    if self.asynchronous and self.num_queues > 1:
                        add_device_field_allocations(driver, block_loop, offload_map,
                                                     self.block_size, self.num_queues)
                        queue, offset = add_async_blocking_vars(driver, block_loop, self.num_queues, splitting_vars)
                        add_blocked_field_offload_calls(driver, block_loop, region, offload_map,
                                                        splitting_vars, queue, offset)
                        add_wait_calls(driver, block_loop, queue, self.num_queues)
                        async_annotations.append((block_loop, queue))
                    else:
                        add_blocked_field_offload_calls(driver, block_loop, region, offload_map, splitting_vars)
                    replace_offload_args(driver, region, offload_map, self.offload_index)

        # Annotate each distinct block loop exactly once so pragmas never get a duplicate clause
        annotated = set()
        for block_loop, queue in async_annotations:
            if id(block_loop) in annotated:
                continue
            annotated.add(id(block_loop))
            add_async_queue_to_pragmas(block_loop, queue)


def find_offload_variables(driver, region, field_group_types):
    """
    Collect the array symbols of ``region`` for which Field API offload code can be
    generated, classified by how the region accesses them.

    Only arrays that are components of a derived-type object (i.e. have a parent)
    are returned. The calls inside the region are also inspected, and warnings are
    emitted for non-enriched call targets and for array arguments that are not
    wrapped by a type listed in ``field_group_types``. These checks only warn; they
    do not filter the result.

    Note
    ----
    This method requires Loki's dataflow analysis to be run on the
    :data:`region` via :meth:`dataflow_analysis_attached`.

    Parameters
    ----------
    driver : :any:`Subroutine`
        Routine containing the region; used in warning messages.
    region : :any:`PragmaRegion`
        Code region object for which to determine offload variables
    field_group_types : list or tuple of str, optional
        Names of the field group types with members that may be offloaded (defaults to ``['']``).
        Compared against lower-cased type names.

    Returns
    -------
    (inargs, inoutargs, outargs) : (tuple, tuple, tuple)
        The sets of array symbols split into three tuples according to access type.
    """

    def _member_arrays(symbols):
        # Keep arrays that are components of a parent object
        return tuple(s for s in symbols if isinstance(s, sym.Array) and s.parent)

    # Dataflow-based classification of the region's symbols
    used = region.uses_symbols
    defined = region.defines_symbols
    inargs = _member_arrays(used - defined)
    inoutargs = _member_arrays(used & defined)
    outargs = _member_arrays(defined - used)

    # Diagnostics for the calls enclosed in the region
    for call in FindNodes(ir.CallStatement).visit(region):
        if call.routine is BasicType.DEFERRED:
            warning(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' +
                    f'in {str(call.name).lower()}')
            continue

        array_args = (arg for param, arg in call.arg_iter() if isinstance(param, Array))
        for arg in array_args:
            try:
                parent = arg.parent
                if parent.type.dtype.name.lower() not in field_group_types:
                    warning(f'[Loki] Data offload: The parent object {parent.name} of type ' +
                            f'{parent.type.dtype} is not in the list of field wrapper types')
            except AttributeError:
                # No usable parent: a plain array that is not a Field API member
                warning(f'[Loki] Data offload: Raw array object {arg.name} encountered in' +
                        f' {driver.name} that is not wrapped by a Field API object')

    return inargs, inoutargs, outargs


def declare_device_ptrs(driver, deviceptrs):
    """
    Add a set of data pointer declarations to a given :any:`Subroutine`.

    A warning is logged for every pointer whose name already exists in the
    routine, but all pointers are still appended to ``driver.variables``.

    Parameters
    ----------
    driver : :any:`Subroutine`
        Routine receiving the declarations.
    deviceptrs : tuple
        Pointer symbols to declare.
    """
    clashing = [ptr for ptr in deviceptrs if ptr.name in driver.variable_map]
    for ptr in clashing:
        warning(f'[Loki] Data offload: The routine {driver.name} already has a ' +
                f'variable named {ptr.name}')

    driver.variables += deviceptrs


def add_field_offload_calls(driver, region, offload_map):
    """
    Surround ``region`` in ``driver`` with the Field API host-to-device calls
    (before) and host synchronisation calls (after) from ``offload_map``.

    The driver body is modified in place.
    """
    before = offload_map.host_to_device_calls
    after = offload_map.sync_host_calls
    Transformer({region: before + (region,) + after}, inplace=True).visit(driver.body)


def add_blocked_field_offload_calls(driver, block_loop, region, offload_map, splitting_vars,
                                    queue=None, offset=None):
    """
    Add blocked Field API data transfer calls to a region inside a block loop.

    Forced host-to-device transfers are placed before ``region`` and forced
    host synchronisations after it. Both cover the block range
    ``[splitting_vars.block_start, splitting_vars.block_end]``.

    Parameters
    ----------
    driver : :any:`Subroutine`
        Driver subroutine IR node
    block_loop : :any:`ir.Loop`
        Block loop containing the region to be offloaded.
    region : :any:`PragmaRegion`
        Code region to prepend and append data transfer calls.
    offload_map : :any:`FieldPointerMap`
        FieldPointerMap with variables to be offloaded.
    splitting_vars: :any:`LoopSplittingVariables`
        Loop splitting variables for `block_loop`
    queue : optional
        Queue parameter for Field API data transfer calls
    offset : optional
        Offset parameter for Field API data transfer calls.
    """

    def _block_bounds():
        # A fresh bounds list for each set of transfer calls
        return sym.LiteralList(values=(splitting_vars.block_start, splitting_vars.block_end))

    transfer_kwargs = {'queue': queue, 'offset': offset}
    host_to_device = offload_map.host_to_device_force_calls(blk_bounds=_block_bounds(), **transfer_kwargs)
    device_to_host = offload_map.sync_host_force_calls(blk_bounds=_block_bounds(), **transfer_kwargs)

    with pragmas_attached(driver, ir.Loop):
        wrapped_region = host_to_device + (region,) + device_to_host
        Transformer({region: wrapped_region}, inplace=True).visit(block_loop)


def add_device_field_allocations(driver, block_loop, offload_map, block_size, num_queues):
    """
    Add Field API device data allocation calls for variables in `offload_map`.

    The calls are inserted immediately before ``block_loop``. The block range is
    ``[1, block_size * num_queues]``, and duplicate calls are emitted only once
    (first occurrence wins).

    Parameters
    ----------
    driver : :any:`Subroutine`
        Driver subroutine IR node
    block_loop : :any:`ir.Loop`
        Block loop containing the region to be offloaded.
    offload_map : :any:`FieldPointerMap`
        FieldPointerMap with variables to be offloaded.
    block_size : `int`
        Number of blocks in a chunk.
    num_queues : `int`
        Number of asynchronous queues.
    """
    blk_bounds = sym.LiteralList(values=(sym.IntLiteral(1), block_size * num_queues))

    # Insertion-ordered de-duplication of the generated calls
    unique_calls = {}
    for arg in offload_map.args:
        field_ptr = offload_map.field_ptr_from_view(arg)
        call = field_create_device_data(field_ptr=field_ptr, scope=driver, blk_bounds=blk_bounds)
        unique_calls.setdefault(call, None)
    create_device_data_calls = tuple(unique_calls)

    with pragmas_attached(driver, ir.Loop):
        new_nodes = create_device_data_calls + (block_loop,)
        Transformer({block_loop: new_nodes}, inplace=True).visit(driver.body)



def add_async_blocking_vars(routine, block_loop, num_queues, splitting_vars):
    """
    Add the variables required for asynchronous blocked offloading over multiple queues to
    the routine.

    Three integer variables are declared: ``loki_block_queue``,
    ``loki_block_nqueues`` (a parameter initialised to ``num_queues``) and
    ``loki_block_offset``. At the start of the block loop body, `queue` is assigned
    ``MODULO(block_idx, loki_block_nqueues) + 1`` and `offset` is assigned
    ``(queue-1) * block_size``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        Subroutine IR node containing loop to be blocked.
    block_loop : :any:`ir.Loop`
        Block loop containing the region to be offloaded.
    num_queues : int
        Number of queues; becomes the value of ``loki_block_nqueues``.
    splitting_vars : :any:`LoopSplittingVariables`
        Loop splitting variables for `block_loop`

    Returns
    -------
    queue:
        Variable holding the queue number in the blocked loop.
    offset:
        Variable holding the offset in the blocked loop.
    """

    def _int_var(name, **attrs):
        # Integer variable scoped to ``routine``
        return sym.Variable(name=name, type=SymbolAttributes(BasicType.INTEGER, **attrs), scope=routine)

    queue = _int_var('loki_block_queue')
    nqueues = _int_var('loki_block_nqueues', parameter=True, initial=sym.IntLiteral(num_queues))
    offset = _int_var('loki_block_offset')
    routine.variables += (queue, nqueues, offset)

    # queue = MODULO(block_idx, nqueues) + 1
    modulo_call = sym.InlineCall(
        sym.DeferredTypeSymbol('MODULO', scope=routine),
        parameters=(splitting_vars.block_idx, nqueues)
    )
    set_queue = ir.Assignment(queue, sym.Sum(children=(modulo_call, sym.IntLiteral(1))))
    # offset = (queue-1) * block_size
    set_offset = ir.Assignment(offset, parse_expr(f'({queue}-1)*{splitting_vars.block_size}'))

    block_loop._update(body=(set_queue, set_offset) + block_loop.body)

    return queue, offset


def replace_offload_args(driver, region, offload_map, offload_index):
    """
    Replace instances of offload variables with their device pointers inside the region.

    Each matching array reference is replaced by its device pointer with the
    block index appended as the last dimension. Explicit subscripts are kept.
    For an unsubscripted reference, full slices ``:`` are used for every
    dimension except the last.

    Parameters
    ----------
    driver : :any:`Subroutine`
        Subroutine containing the data region.
    region : :any:`PragmaRegion`
        Data region in which offload variables in the offload map will be replaced.
    offload_map : :any:`FieldPointerMap`
        FieldPointerMap with variables to be offloaded.
    offload_index: str
        Name of index variable to inject in the outmost dimension of offloaded pointers.
        Must exist in ``driver.variable_map``.
    """
    block_index = driver.variable_map[offload_index]
    offload_args = offload_map.args

    def _device_reference(var):
        # Device pointer for ``var`` indexed by the block index
        dataptr = offload_map.dataptr_from_array(var)
        if len(var.dimensions) != 0:
            leading = var.dimensions
        else:
            leading = (sym.RangeIndex((None, None)),) * (len(dataptr.shape) - 1)
        return dataptr.clone(dimensions=leading + (block_index,))

    change_map = {
        var: _device_reference(var)
        for var in FindVariables().visit(region.body)
        if var.name in offload_args
    }
    SubstituteExpressions(change_map, inplace=True).visit(region.body)


def block_driver_loop(driver, region, block_size):
    """
    Block a driver loop inside a code region.

    The first loop returned by :func:`find_driver_loops` on ``driver.body`` is
    split into blocks of ``block_size`` iterations. If more than one driver loop is
    found, a warning is logged and only the first is used.

    Parameters
    ----------
    driver : :any:`Subroutine`
        Subroutine containing the driver loop.
    region : :any:`PragmaRegion`
        Region, containing the driver loop, to be blocked inside the loop.
    block_size : `int`
        Number of blocks in a chunk.

    Returns
    -------
    (splitting_vars, block_loop, region)
        The loop splitting variables, the new outer block loop and the
        (updated) region, as returned by :func:`split_loop_region`.

    Raises
    ------
    RuntimeError
        If ``driver`` contains no driver loop. Previously this case logged an error
        and then failed with an ``UnboundLocalError``.
    """
    with pragmas_attached(driver, ir.Loop):
        driver_loops = find_driver_loops(driver.body, targets=None)
        if not driver_loops:
            msg = f'[Loki] FieldOffloadBlockedTransformation: No driver loops found in {driver.name}'
            error(msg)
            raise RuntimeError(msg)
        if len(driver_loops) > 1:
            warning('[Loki] FieldOffloadBlockedTransformation: Multiple driver loops found in ' +
                    f'{driver.name}, discarding all but first')
        loop = driver_loops[0]

        splitting_vars, _, outer_loop, region = split_loop_region(driver, loop, block_size, region)
    return splitting_vars, outer_loop, region


def add_async_queue_to_pragmas(section, queue):
    """
    Add async pragma content to Loki data and driver-loop pragmas in the section.

    The clause `` async(<queue name>)`` is appended to the content of every
    ``!$loki data`` and ``!$loki driver-loop`` pragma found in ``section``.
    The pragmas are updated in place.

    Parameters
    ----------
    section : :any:`Section`
        IR section in which async content will be added to Loki pragmas.
    queue :
        Variable holding the queue value that will be added to the async pragma content.
    """
    candidates = FindNodes(ir.Pragma).visit(section)
    clause = f' async({queue.name})'
    for pragma in candidates:
        if any(is_loki_pragma(pragma, starts_with=kw) for kw in ('data', 'driver-loop')):
            pragma._update(content=pragma.content + clause)


def add_wait_calls(driver, block_loop, queue, num_queues):
    """
    Add wait calls to synchronize all queues after a block loop.

    A loop ``DO queue = 1, num_queues`` containing the Field API wait call for
    ``queue`` is inserted directly after ``block_loop`` in the driver body.

    Parameters
    ----------
    driver :  :any:`Subroutine`
        Routine containing the block loop.
    block_loop : :any:`ir.Loop`
        Block loop containing the region to be offloaded.
    queue :
        Variable holding the queue value; reused as the wait loop's index.
    num_queues : int
        Total number of async queues.
    """
    wait_body = field_wait_for_async_queue(queue, driver)
    queue_range = sym.LoopRange((sym.IntLiteral(1), sym.IntLiteral(num_queues)))
    wait_loop = ir.Loop(variable=queue, body=wait_body, bounds=queue_range)

    Transformer({block_loop: (block_loop, wait_loop)}, inplace=True).visit(driver.body)
loki-ecmwf-0.3.6/loki/transformations/data_offload/__init__.py0000664000175000017500000000130315167130205024641 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Sub-package providing data offload transformations.
"""

from loki.transformations.data_offload.field_offload import * # noqa
from loki.transformations.data_offload.global_var import * # noqa
from loki.transformations.data_offload.offload import * # noqa
from loki.transformations.data_offload.offload_deepcopy import * # noqa
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/0000775000175000017500000000000015167130205023675 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/data_offload/tests/__init__.py0000664000175000017500000000057015167130205026010 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/test_offload_deepcopy.py0000664000175000017500000010332715167130205030616 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from shutil import rmtree
import pytest
import yaml

from loki.backend import fgen
from loki.batch import Scheduler
from loki.ir import (
    nodes as ir, FindNodes, is_loki_pragma, pragma_regions_attached, get_pragma_parameters,
    pragmas_attached
)
from loki.expression import Variable, RangeIndex, IntLiteral
from loki.frontend import available_frontends
from loki.logging import log_levels
from loki.subroutine import Subroutine
from loki.tools import gettempdir, flatten, as_tuple
from loki.transformations import (
        DataOffloadDeepcopyAnalysis, DataOffloadDeepcopyTransformation, find_driver_loops
)
from loki.types import BasicType, DerivedType, SymbolAttributes, Scope


@pytest.fixture(scope='module', name='deepcopy_code')
def fixture_deepcopy_code():
    """
    Write the Fortran sources used by the deepcopy tests to a temporary directory.

    Each entry of ``fcode`` is written to ``<name>.F90`` inside a fresh
    ``test_offload_deepcopy`` directory under ``gettempdir()``; a pre-existing
    directory of that name is removed first. The directory path is yielded to
    the tests and removed again on teardown (once per module, per the fixture scope).

    The ``driver`` source holds two variants of the same driver loop, selected
    by the ``geometry_present`` preprocessor define: one whose ``!$loki data``
    pragma declares ``present(geometry)`` and one that does not.
    """
    fcode = {
        #----- field_module -----
        'field_module' : (
            """
module field_module
type :: field_3d
end type
end module field_module
            """.strip()
        ),

        #----- type_def_mod -----
        'type_def_mod' : (
            """
module type_def_mod
   use field_module, only : field_3d
   type :: variable_type
      class(field_3d), pointer :: fp => null()
      real, pointer, contiguous :: p(:,:,:) => null()
   end type

   type :: other_variable_type
      class(field_3d), pointer :: f_t0 => null()
      real, pointer, contiguous :: vt0_field(:,:,:) => null()
   end type

   type :: view_prefix_variable_type
      class(field_3d), pointer :: f_t1 => null()
      real, pointer, contiguous :: pt1_field(:,:,:) => null()
   end type

   type :: superfluous_type
      type(variable_type) :: var
   end type

   type :: struct_type
      type(variable_type) :: a
      type(variable_type) :: b
      type(variable_type) :: c
      type(variable_type) :: d
      type(variable_type) :: e
      type(superfluous_type), allocatable :: var_ptr(:)
   end type

   type :: opts_type
      logical :: one_flag
      logical :: another_flag
   end type

   type :: dims_type
      integer :: kst
      integer :: kend
      integer :: kbl
      integer :: ngpblks
      integer :: m
   end type

   type :: geom_dims
      integer :: nproma
   end type

   type :: geom_type
      type(geom_dims) :: dim
      integer, pointer :: metadata(:)
   end type
end module type_def_mod
            """.strip()
        ),

        #----- constants_module -----
        'constants_module' : (
            """
module constants_module
real :: pi
real :: eps
end module constants_module
            """.strip()
        ),

        #----- nested_kernel_write -----
        'nested_kernel_write' : (
            """
module nested_kernel_write_mod
contains
subroutine nested_kernel_write(p)
    !... intent(out) can be dangerous with pointers, so we make this intent(inout)
    real, intent(inout) :: p(:,:,:)

    p = 0.
end subroutine nested_kernel_write
end module nested_kernel_write_mod
            """.strip()
        ),

        #----- nested_kernel_read -----
        'nested_kernel_read' : (
            """
module nested_kernel_read_mod
contains
subroutine nested_kernel_read(p)
    real, intent(in) :: p(:,:,:)
    real, allocatable :: b(:,:,:)

    allocate(b, mold=p)
    b = p
    deallocate(b)
end subroutine nested_kernel_read
end module nested_kernel_read_mod
            """.strip()
        ),

        #----- other_kernel -----
        'other_kernel' : (
            """
module other_kernel_mod
contains
subroutine other_kernel(struct)
   use type_def_mod, only : variable_type
   type(variable_type), intent(inout) :: struct

   struct%p = struct%p + 1.
end subroutine other_kernel
end module other_kernel_mod
            """.strip()
        ),

        #----- kernel -----
        'kernel' : (
            """
module kernel_mod
contains
subroutine kernel(geometry, bnds, struct, variable, another_variable)
   use nested_kernel_write_mod, only: nested_kernel_write
   use nested_kernel_read_mod, only: nested_kernel_read
   use other_kernel_mod, only : other_kernel
   use type_def_mod, only: struct_type, dims_type, geom_type, other_variable_type, &
   &                       view_prefix_variable_type
   implicit none

   type(geom_type), intent(in) :: geometry
   type(dims_type), intent(in) :: bnds
   type(struct_type), intent(inout) :: struct
   type(other_variable_type), intent(inout) :: variable
   type(view_prefix_variable_type), intent(in) :: another_variable


   integer :: jrof, jfld, j
   real, pointer :: tmp(:,:,:) => null()
   integer :: a(geometry%dim%nproma)

   call nested_kernel_write(struct%a%p(:,:,bnds%kbl))
   call nested_kernel_read(struct%b%p(:,:,bnds%kbl))
   call nested_kernel_read(variable%vt0_field(:,:,bnds%kbl))
   call nested_kernel_read(another_variable%pt1_field(:,:,bnds%kbl))

   tmp => struct%c%p !... yes this completely breaks the dataflow analysis
   tmp = 0.

   do jrof = bnds%kst, bnds%kend
     struct%b%p(jrof,:,bnds%kbl) = struct%a%p(jrof,:,bnds%kbl)
   enddo 

   do jfld = 1, size(struct%var_ptr)
     call other_kernel(struct%var_ptr(jfld)%var)
   enddo

   do jrof = bnds%kst, bnds%kend
     struct%d%p(jrof,:,bnds%kbl) = struct%e%p(jrof,:,bnds%kbl)
   enddo

   j = geometry%metadata(1)

end subroutine kernel
end module kernel_mod
            """.strip()
        ),

        #----- driver -----
        # Two variants of the driver loop, selected by the ``geometry_present`` define
        'driver' : (
            """
module driver_mod
use constants_module, only: pi
implicit none
contains
subroutine driver(dims, struct, array_arg, geometry, variable, another_variable)
   use kernel_mod, only : kernel
   use nested_kernel_write_mod, only: nested_kernel_write
   use type_def_mod, only: struct_type, dims_type, geom_type, other_variable_type, &
   &                       view_prefix_variable_type
   use constants_module, only: eps
   implicit none

   type(dims_type), intent(in) :: dims
   type(struct_type), intent(inout) :: struct
   integer, intent(out) :: array_arg(:,:,:)
   type(geom_type), intent(in) :: geometry
   type(other_variable_type), intent(inout) :: variable
   type(view_prefix_variable_type), intent(in) :: another_variable
   type(dims_type) :: local_dims
   integer :: ibl, ij

#ifdef geometry_present
!$loki data private(local_dims) present(geometry) write(struct%c%p)
   do ibl=1,local_dims%ngpblks
     local_dims = dims
     local_dims%kbl = ibl
     ij = 0

     variable%vt0_field(:,:,ibl) = pi + eps

     call kernel(geometry, local_dims, struct, variable, another_variable)
     call nested_kernel_write(struct%e%p(:,local_dims%m,local_dims%kbl))
     call nested_kernel_write(array_arg)
   enddo
!$loki end data
#else
!$loki data private(local_dims) write(struct%c%p)
   do ibl=1,local_dims%ngpblks
     local_dims = dims
     local_dims%kbl = ibl
     ij = 0

     variable%vt0_field(:,:,ibl) = pi + eps

     call kernel(geometry, local_dims, struct, variable, another_variable)
     call nested_kernel_write(struct%e%p(:,local_dims%m,local_dims%kbl))
     call nested_kernel_write(array_arg)
   enddo
!$loki end data
#endif

end subroutine driver
end module driver_mod
            """.strip()
        ),
        #----- simple driver -----
        # A driver loop marked with an explicit ``!$loki driver-loop`` pragma and no kernel calls
        'simple_driver' : (
            """
module simple_driver_mod
implicit none
contains
subroutine simple_driver(ngpblks, variable)
   use type_def_mod, only: other_variable_type
   implicit none

   integer, intent(in) :: ngpblks
   type(other_variable_type), intent(inout) :: variable
   integer :: ibl

!$loki data 
!$loki driver-loop
   do ibl=1,ngpblks

     variable%vt0_field(:,:,ibl) = 0.

   enddo
!$loki end data

end subroutine simple_driver
end module simple_driver_mod
            """.strip()
        )
    }

    # Start from a clean directory so stale sources from earlier runs are not picked up
    workdir = gettempdir()/'test_offload_deepcopy'
    if workdir.exists():
        rmtree(workdir)
    workdir.mkdir()
    for name, code in fcode.items():
        (workdir/f'{name}.F90').write_text(code)

    yield workdir

    # Teardown: remove the generated sources
    rmtree(workdir)


@pytest.fixture(scope='module', name='expected_analysis')
def fixture_expected_analysis():
    """
    Expected result of the deepcopy analysis for the driver loop in ``driver_mod``.

    Keys are the variables accessed in (or below) the driver loop. Each value is
    either an access mode (``'read'``, ``'write'`` or ``'readwrite'``) or, for
    derived-type variables, a nested dict of the same form keyed by member name.
    """
    # Derived-type variables, described member by member
    local_dims = dict(kbl='write', kend='read', kst='read', ngpblks='read', m='read')
    geometry = {'dim': {'nproma': 'read'}, 'metadata': 'read'}
    variable = {'vt0_field': 'write'}
    another_variable = {'pt1_field': 'read'}
    struct = {
        'a': {'p': 'write'},
        'b': {'p': 'readwrite'},
        'c': {'p': 'read'},
        'd': {'p': 'write'},
        'e': {'p': 'readwrite'},
        'var_ptr': {'var': {'p': 'readwrite'}},
    }

    return {
        'local_dims': local_dims,
        'dims': 'read',
        'geometry': geometry,
        'array_arg': 'write',
        'variable': variable,
        'another_variable': another_variable,
        'pi': 'read',
        'eps': 'read',
        'struct': struct,
        'ij': 'write',
    }


@pytest.fixture(scope='function', name='config')
def fixture_config():
    """
    Default configuration dict with basic options.

    Only the ``'default'`` entry is provided; individual tests add their own
    ``'routines'`` entry. A new dict is built on every call.
    """
    disabled_routines = ['*get_host_data*', '*get_device_data*', '*delete_device_data*']
    default_options = dict(
        mode='idem',
        role='kernel',
        expand=True,
        strict=True,
        enable_imports=True,
        field_ptr_suffix='_field',
        field_ptr_map={},
        disable=disabled_routines,
    )
    return {'default': default_options}


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('output_analysis', [True, False])
def test_offload_deepcopy_analysis(frontend, config, deepcopy_code, expected_analysis,
                                   output_analysis, tmp_path, caplog):
    """
    Test the analysis for the data offload deepcopy generation.

    Runs ``DataOffloadDeepcopyAnalysis`` over the fixture sources. It then checks:
    the pointer-association warning is emitted; the analysis stored for the driver
    loop matches ``expected_analysis``; the ``variable_type`` typedef config is
    collected. When ``output_analysis`` is set, it also checks the YAML dump
    written to ``tmp_path``.
    """

    config['routines'] = {
        'driver': {'role': 'driver'},
        'variable_type': {'field_prefix': 'f'}
    }

    scheduler = Scheduler(
        paths=deepcopy_code, config=config, frontend=frontend, xmods=[tmp_path],
        output_dir=tmp_path, preprocess=True
    )

    with caplog.at_level(log_levels['WARNING']):
        transformation = DataOffloadDeepcopyAnalysis(output_analysis=output_analysis)
        scheduler.process(transformation=transformation)

        # check that the warning for pointer associations is produced; search all
        # records (not just the last one) so that unrelated later warnings, or an
        # empty log, produce a clear assertion failure rather than a false result
        messages = [log.message for log in caplog.records]
        assert any(
            '[Loki::DataOffloadDeepcopyAnalysis] Pointer associations found in kernel' in message
            for message in messages
        )

    # The analysis is tied to driver loops
    trafo_data_key = transformation._key
    driver_item = scheduler['driver_mod#driver']
    driver_loop = find_driver_loops(driver_item.ir.body, targets=['kernel', 'nested_kernel_write'])[0]

    # Stringify the analysis for comparison; dict equality ignores key order,
    # so no (recursive) sorting is required
    analysis = driver_item.trafo_data[trafo_data_key]['analysis'][driver_loop]
    assert transformation.stringify_dict(analysis) == expected_analysis

    # check that the typedef config was also collected
    assert driver_item.trafo_data[trafo_data_key]['typedef_configs']['variable_type']['field_prefix'] == 'f'

    if output_analysis:
        with open(tmp_path/'driver_kernel_dataoffload_analysis.yaml', 'r', encoding='utf-8') as file:
            dumped_analysis = yaml.safe_load(file)
        assert dumped_analysis == expected_analysis


def check_array_arg(mode, pragmas):
    """
    Check the correct generation of deepcopy for `array_arg`.

    In ``'offload'`` mode, there must be an ``unstructured-data create`` pragma for
    ``array_arg``, and a later ``update host`` pragma for it. In any other mode,
    no pragma may mention ``array_arg``.

    :param mode: transformation mode, e.g. ``'offload'`` or ``'set_pointers'``
    :param pragmas: sequence of pragma nodes (objects with a ``content`` string)
    """

    if mode == 'offload':
        # Record (pragma, position) pairs so the relative order can be checked
        copy_pragma = [(p, loc) for loc, p in enumerate(pragmas)
                       if 'unstructured-data create' in p.content and '(array_arg)' in p.content]
        host_pragma = [(p, loc) for loc, p in enumerate(pragmas)
                       if 'update host' in p.content and '(array_arg)' in p.content]
        assert copy_pragma
        assert host_pragma
        # The host update must come after the device allocation
        assert host_pragma[0][1] > copy_pragma[0][1]
    else:
        assert not any('array_arg' in p.content for p in pragmas)


def check_variable_type_host(var, calls):
    """
    Check generated host pull-back for `type(variable_type)` variables.

    At least one call must be to ``SGET_HOST_DATA_RDWR`` and must receive both
    ``<var>%p`` and ``<var>%fp`` as arguments.
    """
    pointer_arg = f'{var}%p'
    field_arg = f'{var}%fp'

    def _is_pullback(call):
        if call.name.name.lower() != 'sget_host_data_rdwr':
            return False
        return pointer_arg in call.arguments and field_arg in call.arguments

    assert any(_is_pullback(call) for call in calls)

def check_variable_type_device(var, routine, calls, pragmas, access):
    """
    Check generated copy to device for `type(variable_type)` variables.

    A ``GET_DEVICE_DATA_<ACCESS>`` call that takes ``<var>%p`` must exist. So must an
    ``unstructured-data attach`` pragma that names ``<var>%p``. The first such call
    must precede the first such pragma in ``routine.body.body``.
    """
    pointer_arg = f'{var}%p'
    accessor = f'get_device_data_{access}'

    device_calls = [
        c for c in calls
        if accessor in c.name.name.lower() and pointer_arg in c.arguments
    ]
    attach_pragmas = [
        p for p in pragmas
        if 'unstructured-data attach' in p.content and pointer_arg in p.content
    ]

    # Both the accessor call and the attach pragma must have been generated...
    assert device_calls and attach_pragmas
    # ...and the device pointer is fetched before it is attached
    top_level = routine.body.body
    assert top_level.index(device_calls[0]) < top_level.index(attach_pragmas[0])


def check_variable_type_wipe(var, conds):
    """
    Check generated wipe for `type(variable_type)` variables.

    Exactly one ``associated(<var>%fp)`` conditional must contain both of these:
    an ``exit unstructured-data detach ... finalize`` pragma naming ``<var>%p``,
    and a ``<var>%fp%delete_device_data`` call. The pragma must come first.
    """
    field_arg = f'{var}%fp'
    pointer_arg = f'{var}%p'
    delete_name = f'{var}%fp%delete_device_data'

    guarded = [
        cond for cond in conds
        if cond.condition.name.lower() == 'associated' and field_arg in cond.condition.parameters
    ]

    num_wipes = 0
    for cond in guarded:
        delete_calls = [
            call for call in FindNodes(ir.CallStatement).visit(cond.body)
            if call.name.name.lower() == delete_name
        ]
        detach_pragmas = [
            pragma for pragma in FindNodes(ir.Pragma).visit(cond.body)
            if 'exit unstructured-data detach' in pragma.content and pointer_arg in pragma.content
            and 'finalize' in pragma.content
        ]
        if not (delete_calls and detach_pragmas):
            continue

        # Detach must happen before the device data is deleted
        assert cond.body.index(delete_calls[0]) > cond.body.index(detach_pragmas[0])
        num_wipes += 1

    assert num_wipes == 1


def check_other_variable_type(mode, conds, calls, pragmas, routine):
    """
    Check the generated deepcopy for `type(other_variable_type) :: variable`.

    In every mode, a ``SGET_HOST_DATA_RDWR`` call must pull back
    ``variable%vt0_field``, and that call must also take ``variable%f_t0``.

    In ``'offload'`` mode, the struct copy-in must precede the
    ``associated(variable%f_t0)`` conditionals, and the struct delete must follow
    them. In addition, exactly two FIELD_API ordering checks must match:
    accessor call before attach, and detach before ``delete_device_data``.
    """
    field_obj = 'variable%f_t0'
    field_ptr = 'variable%vt0_field'

    # Restrict to nodes that refer to the ``f_t0`` member
    field_conds = [c for c in conds
                   if c.condition.name.lower() == 'associated' and field_obj in c.condition.parameters]
    field_calls = [call for call in calls if field_obj in call.arguments]

    # Pull-back to host
    assert any(call.name.name.lower() == 'sget_host_data_rdwr' and field_ptr in call.arguments
               for call in field_calls)

    if mode != 'offload':
        return

    # The struct is copied in before, and deleted after, its field conditionals
    copyin = [p for p in pragmas
              if 'unstructured-data in' in p.content and '(variable)' in p.content][0]
    assert routine.body.body.index(copyin) < routine.body.body.index(field_conds[0])

    delete = [p for p in pragmas
              if 'exit unstructured-data delete' in p.content and '(variable)' in p.content][0]
    assert routine.body.body.index(delete) > routine.body.body.index(field_conds[-1])

    num_checks = 0

    # FIELD_API boilerplate for copying to device: accessor call, then attach
    device_calls = [call for call in field_calls
                    if 'get_device_data_wronly' in call.name.name.lower()
                    and field_ptr in call.arguments]
    attach_pragmas = [p for p in pragmas
                      if 'unstructured-data attach' in p.content and field_ptr in p.content]
    if device_calls and attach_pragmas:
        assert routine.body.body.index(device_calls[0]) < routine.body.body.index(attach_pragmas[0])
        num_checks += 1

    # FIELD_API boilerplate for wiping the device: detach, then delete_device_data
    for cond in field_conds:
        delete_calls = [call for call in FindNodes(ir.CallStatement).visit(cond.body)
                        if call.name.name.lower() == f'{field_obj}%delete_device_data']
        detach_pragmas = [p for p in FindNodes(ir.Pragma).visit(cond.body)
                          if 'exit unstructured-data detach' in p.content and field_ptr in p.content
                          and 'finalize' in p.content]
        if delete_calls and detach_pragmas:
            assert cond.body.index(delete_calls[0]) > cond.body.index(detach_pragmas[0])
            num_checks += 1

    assert num_checks == 2


def check_view_prefix_variable_type(mode, conds, calls, pragmas, routine):
    """
    Check the generated deepcopy for `type(view_prefix_variable_type) :: another_variable`.

    In every mode, a ``SGET_HOST_DATA_RDWR`` call must pull back
    ``another_variable%pt1_field``, and that call must also take
    ``another_variable%f_t1``.

    In ``'offload'`` mode, the struct copy-in and delete must bracket the
    ``associated(another_variable%f_t1)`` conditionals. In addition, exactly two
    FIELD_API ordering checks must match: ``SGET_DEVICE_DATA_RDONLY`` before
    attach, and detach before ``delete_device_data``.
    """
    var = 'another_variable'
    field_obj = f'{var}%f_t1'
    field_ptr = f'{var}%pt1_field'

    # Restrict to nodes that refer to the ``f_t1`` member
    field_conds = [c for c in conds
                   if c.condition.name.lower() == 'associated' and field_obj in c.condition.parameters]
    field_calls = [call for call in calls if field_obj in call.arguments]

    # Pull-back to host
    assert any(call.name.name.lower() == 'sget_host_data_rdwr' and field_ptr in call.arguments
               for call in field_calls)

    if mode != 'offload':
        return

    # The struct is copied in before, and deleted after, its field conditionals
    copyin = [p for p in pragmas
              if 'unstructured-data in' in p.content and f'({var})' in p.content][0]
    assert routine.body.body.index(copyin) < routine.body.body.index(field_conds[0])

    delete = [p for p in pragmas
              if 'exit unstructured-data delete' in p.content and f'({var})' in p.content][0]
    assert routine.body.body.index(delete) > routine.body.body.index(field_conds[-1])

    num_checks = 0

    # FIELD_API boilerplate for copying to device: read-only accessor, then attach
    device_calls = [call for call in field_calls
                    if call.name.name.lower() == 'sget_device_data_rdonly'
                    and all(arg in call.arguments for arg in (field_ptr, field_obj))]
    attach_pragmas = [p for p in pragmas
                      if 'unstructured-data attach' in p.content and field_ptr in p.content]
    if device_calls and attach_pragmas:
        assert routine.body.body.index(device_calls[0]) < routine.body.body.index(attach_pragmas[0])
        num_checks += 1

    # FIELD_API boilerplate for wiping the device: detach, then delete_device_data
    for cond in field_conds:
        delete_calls = [call for call in FindNodes(ir.CallStatement).visit(cond.body)
                        if call.name.name.lower() == f'{field_obj}%delete_device_data']
        detach_pragmas = [p for p in FindNodes(ir.Pragma).visit(cond.body)
                          if 'exit unstructured-data detach' in p.content and
                          field_ptr in p.content and
                          'finalize' in p.content]
        if delete_calls and detach_pragmas:
            assert cond.body.index(delete_calls[0]) > cond.body.index(detach_pragmas[0])
            num_checks += 1

    assert num_checks == 2


def check_geometry(conds, pragmas, routine):
    """
    Check the generated deepcopy for `type(geom_type) :: geometry`.

    There must be exactly two conditionals on ``geometry%metadata``, and both must
    be ``associated`` checks. The first copies the member in and must follow the
    struct copy-in. The second deletes the member (with ``finalize``) and must
    precede the struct delete.
    """
    metadata_conds = [c for c in conds if 'geometry%metadata' in c.condition.parameters]

    assert all(c.condition.name.lower() == 'associated' for c in metadata_conds)

    # geometry should only have copy and wipe related instructions
    assert len(metadata_conds) == 2
    copy_cond, wipe_cond = metadata_conds[0], metadata_conds[-1]

    # The struct copy-in precedes the member copy-in
    struct_copyin = [p for p in pragmas
                     if 'unstructured-data in' in p.content and '(geometry)' in p.content][0]
    assert routine.body.body.index(struct_copyin) < routine.body.body.index(copy_cond)

    member_copyin = copy_cond.body[0].content
    assert 'unstructured-data in' in member_copyin
    assert '(geometry%metadata)' in member_copyin

    # The struct delete follows the member delete
    struct_delete = [p for p in pragmas
                     if 'exit unstructured-data delete' in p.content
                     and '(geometry)' in p.content and 'finalize' in p.content][0]
    assert routine.body.body.index(struct_delete) > routine.body.body.index(wipe_cond)

    member_delete = wipe_cond.body[0].content
    assert 'exit unstructured-data delete' in member_delete
    assert '(geometry%metadata)' in member_delete
    assert 'finalize' in member_delete


def check_struct(mode, conds, calls, pragmas, routine):
    """
    Check the generated deepcopy for `type(struct_type)`.

    Only two kinds of node are considered: conditionals with a parameter whose
    name contains ``struct``, and calls that mention ``struct`` in the callee
    name or in an argument. In every mode, the ``var_ptr`` member and the host
    pull-back of members ``a``-``e`` are checked. In ``'offload'`` mode, the
    struct copy-in/delete ordering and the device copy and wipe of each member
    are checked as well.
    """
    members = ('a', 'b', 'c', 'd', 'e')

    # Conditionals on type(struct_type) :: struct (or its members)
    struct_conds = [
        cond for cond in conds
        if any('struct' in name
               for name in flatten([p.name.lower() for p in cond.condition.parameters]))
    ]

    # Calls that take members of type(struct_type) :: struct as arguments
    calls = [call for call in calls
             if 'struct' in call.name.name.lower() or any('struct' in arg for arg in call.arguments)]

    check_struct_var_ptr(mode, struct_conds)

    for member in members:
        check_variable_type_host(f'struct%{member}', calls)

    if mode != 'offload':
        return

    # The struct is copied in before, and deleted after, its conditionals
    copyin = [p for p in pragmas
              if 'unstructured-data in' in p.content and '(struct)' in p.content][0]
    assert routine.body.body.index(copyin) < routine.body.body.index(struct_conds[0])

    delete = [p for p in pragmas
              if 'exit unstructured-data delete' in p.content and '(struct)' in p.content][0]
    assert routine.body.body.index(delete) > routine.body.body.index(struct_conds[-1])

    # Expected GET_DEVICE_DATA access per member; ``c`` is presumably write-only
    # because the driver's ``!$loki data`` pragma declares ``write(struct%c%p)``
    device_access = {'a': 'wronly', 'b': 'rdwr', 'c': 'wronly', 'd': 'wronly', 'e': 'rdwr'}
    for member in members:
        check_variable_type_device(f'struct%{member}', routine, calls, pragmas, device_access[member])

    for member in members:
        check_variable_type_wipe(f'struct%{member}', struct_conds)


def check_struct_var_ptr(mode, conds):
    """
    Check the `var_ptr` member of `type(struct_type) :: struct`.

    The checks look only at ``allocated(struct%var_ptr)`` conditionals. In every
    mode, a loop under those conditionals must contain a ``SGET_HOST_DATA_RDWR``
    call that takes ``struct%var_ptr(j1)%var%p`` and ``struct%var_ptr(j1)%var%fp``.

    In ``'offload'`` mode, exactly two conditionals must match:
    - a copy-in conditional: copy-in pragma, then a loop with
      ``SGET_DEVICE_DATA_RDWR`` and an attach pragma;
    - a wipe conditional: a loop whose ``associated`` check holds a detach
      pragma and ``delete_device_data``, followed by a final delete pragma.
    """

    alloc_conds = [c for c in conds
                   if c.condition.name.lower() == 'allocated' and 'struct%var_ptr' in c.condition.parameters]
    loop_calls = FindNodes(ir.CallStatement).visit(FindNodes(ir.Loop).visit(alloc_conds))

    # Host pull-back FIELD_API calls inside the var_ptr loop
    assert any(fgen(call.name).lower() == 'sget_host_data_rdwr'
               and 'struct%var_ptr(J1)%var%p' in call.arguments
               and 'struct%var_ptr(j1)%var%fp' in call.arguments
               for call in loop_calls)

    if mode != 'offload':
        return

    num_checked = 0
    for cond in alloc_conds:
        cond_calls = FindNodes(ir.CallStatement).visit(cond.body)

        if any('get_device_data' in call.name.name.lower() for call in cond_calls):
            # Copy-in: the array itself is copied in first...
            copyin = cond.body[0]
            assert 'unstructured-data in' in copyin.content
            assert '(struct%var_ptr)' in copyin.content

            # ...then each element gets a GET_DEVICE_DATA call followed by an attach
            loop_body = FindNodes(ir.Loop).visit(cond.body)[0].body
            assert fgen(loop_body[0].name).lower() == 'sget_device_data_rdwr'
            assert loop_body[0].arguments[0] == 'struct%var_ptr(j1)%var%p'
            assert 'unstructured-data attach' in loop_body[1].content
            assert 'struct%var_ptr(J1)%var%p' in loop_body[1].content
            num_checked += 1

        elif any('delete_device_data' in call.name.name.lower() for call in cond_calls):
            # Wipe: the array itself is deleted last
            final_delete = cond.body[-1]
            assert 'exit unstructured-data delete' in final_delete.content
            assert 'finalize' in final_delete.content
            assert '(struct%var_ptr)' in final_delete.content

            # Before that, each element's field object is checked for association...
            loop = FindNodes(ir.Loop).visit(cond.body)[0]
            assoc_cond = FindNodes(ir.Conditional).visit(loop.body)[0]
            assert assoc_cond.condition.name.lower() == 'associated'
            assert fgen(assoc_cond.condition.parameters[0]).lower() == 'struct%var_ptr(j1)%var%fp'

            # ...and inside it a detach precedes the FIELD_API DELETE_DEVICE_DATA call
            assert 'exit unstructured-data detach' in assoc_cond.body[0].content
            assert 'finalize' in assoc_cond.body[0].content
            assert 'struct%var_ptr(J1)%var%p' in assoc_cond.body[0].content
            assert fgen(assoc_cond.body[1].name).lower() == 'struct%var_ptr(j1)%var%fp%delete_device_data'
            num_checked += 1

    assert num_checked == 2


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('present', [True, False])
@pytest.mark.parametrize('mode', ['offload', 'set_pointers'])
def test_offload_deepcopy_transformation(frontend, config, deepcopy_code, present, mode, tmp_path):
    """
    Test the generation of host-device deepcopy.

    ``present`` toggles the ``geometry_present`` preprocessor define. The define
    selects the driver variant whose ``!$loki data`` pragma declares
    ``present(geometry)``; the geometry deepcopy is therefore only checked when
    it is absent.
    """

    config['routines'] = {
        'driver': {'role': 'driver'},
        'variable_type': {
            'field_prefix': 'f',
            'field_ptrs': ['p'],
        },
        'other_variable_type': {
            'field_prefix': 'f_',
            'view_prefix': 'v',
            'field_ptr_map': {
                'vt0_field': 'f_t0'
            }
        },
        'view_prefix_variable_type': {
            'field_prefix': 'f_',
            'view_ptr_prefix': 'p'
        }
    }

    defines = ['geometry_present'] if present else []
    scheduler = Scheduler(
        paths=deepcopy_code, config=config, frontend=frontend, xmods=[tmp_path],
        output_dir=tmp_path, preprocess=True, defines=defines
    )

    scheduler.process(transformation=DataOffloadDeepcopyAnalysis())
    transformation = DataOffloadDeepcopyTransformation(mode=mode)
    scheduler.process(transformation=transformation)

    driver_item = scheduler['driver_mod#driver']
    driver = driver_item.ir
    pragmas = FindNodes(ir.Pragma).visit(driver.body)
    conds = FindNodes(ir.Conditional, greedy=True).visit(driver.body)
    calls = FindNodes(ir.CallStatement).visit(driver.body)
    # filter out target calls, as we only need to check generated boilerplate
    calls = [call for call in calls if call.name.name.lower() not in driver_item.targets]

    # check array_arg
    check_array_arg(mode, pragmas)

    # check other_variable_type
    check_other_variable_type(mode, conds, calls, pragmas, driver)

    # check view_prefix_variable_type
    check_view_prefix_variable_type(mode, conds, calls, pragmas, driver)

    # check struct
    check_struct(mode, conds, calls, pragmas, driver)

    if not present and mode == 'offload':
        check_geometry(conds, pragmas, driver)

    # check FIELD_ACCESS_MODULE host accessor imports
    assert any(_import.module.lower() == 'field_access_module' for _import in driver.imports)
    imported_symbols = driver.imported_symbols
    assert 'sget_host_data_rdwr' in imported_symbols

    if mode == 'offload':
        # check dims copyin and deletion
        copyin_pragmas = [p for p in pragmas if
                          'unstructured-data in' in p.content and '(dims)' in p.content]
        assert copyin_pragmas
        delete_pragmas = [p for p in pragmas if
                          'unstructured-data delete' in p.content and '(dims)' in p.content]
        # previously computed but never asserted, which made this check a no-op
        assert delete_pragmas

        # check FIELD_ACCESS_MODULE device accessor imports
        assert all(f'sget_device_data_{access}' in imported_symbols
                   for access in ['rdwr', 'rdonly', 'wronly'])

        # check data present region
        with pragma_regions_attached(driver):

            assert not any('update host' in p.content and '(ij)' in p.content for p in pragmas)

            region = FindNodes(ir.PragmaRegion).visit(driver.body)[-1]
            assert is_loki_pragma(region.pragma, starts_with='structured-data')

            parameters = get_pragma_parameters(region.pragma)
            present_vars = [v.strip().lower() for v in parameters['present'].split(',')]
            assert all(v in present_vars for v in ['geometry', 'struct', 'variable', 'dims'])


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('depth', [1, 2, 3, 4])
def test_loop_nest_wrapping(frontend, depth):
    """Test the utility to wrap a given body in a loop nest."""

    fcode = """
subroutine kernel()
  implicit none
end subroutine kernel
    """

    def _check_loops(loop, depth, loop_level):
        # Walk inwards from the outermost loop: the loop at level N iterates jN
        # over lbound(var, N):ubound(var, N)
        while True:
            assert isinstance(loop, ir.Loop)
            assert fgen(loop.bounds.lower).lower() == f'lbound(var, {loop_level})'
            assert fgen(loop.bounds.upper).lower() == f'ubound(var, {loop_level})'
            assert loop.variable.name.lower() == f'j{loop_level}'
            if loop_level <= 1:
                break
            loop, loop_level = loop.body[0], loop_level - 1

        # The innermost loop holds the wrapped body, indexed by all loop variables
        indices = ', '.join(f'j{d+1}' for d in range(depth))
        assignment, update = loop.body[0], loop.body[1]
        assert isinstance(assignment, ir.Assignment)
        assert fgen(assignment.lhs).lower() == f"var({indices})"
        assert assignment.rhs == '0'
        assert isinstance(update, ir.Pragma)
        assert update.keyword == 'loki'
        assert update.content.lower() == f"update device( var({indices}) )"

    routine = Subroutine.from_source(fcode, frontend=frontend)
    shape = as_tuple([RangeIndex((None, None))] * depth)
    var = Variable(name='var', type=SymbolAttributes(BasicType.INTEGER, shape=shape), dimensions=shape, scope=routine)
    routine.variables += (var,)

    # Body to be wrapped: an assignment to the (undimensioned) array plus a device update
    body = (
        ir.Assignment(lhs=var.clone(dimensions=None), rhs=IntLiteral(0)), # pylint:disable=no-member
        ir.Pragma(keyword='loki', content='update device( var )'),
    )

    trafo = DataOffloadDeepcopyTransformation(mode='offload')
    loopnest = trafo.wrap_in_loopnest(var.clone(dimensions=None), body, routine) # pylint:disable=no-member

    # check loop variable was added to routine
    declared = routine.variables
    assert all(f'J{d+1}' in declared for d in range(depth))

    # check loops are correctly nested
    _check_loops(loopnest[0], depth, depth)


@pytest.mark.parametrize('rank', [1, 2, 3, 4, 5])
@pytest.mark.parametrize('suff', ['im', 'rb', 'rd', 'lm'])
def test_dummy_field_array_typdef_config(rank, suff):
    """Test the creation of a typedef config for ``FIELD_RANKSUFF_ARRAY`` types."""

    # A free-standing ``field_<rank><suff>_array`` derived type and a variable of that type
    scope = Scope()
    typedef = ir.TypeDef(name=f'field_{rank}{suff}_array', parent=scope) # pylint: disable=unexpected-keyword-arg
    field_array = Variable(
        name='field_array', scope=scope, type=SymbolAttributes(DerivedType(typedef=typedef))
    )

    trafo = DataOffloadDeepcopyTransformation(mode='offload')
    generated = trafo.create_dummy_field_array_typedef_config(field_array)

    # check typedef config was created correctly
    assert generated == {
        'field_prefix': 'F_',
        'field_ptr_suffix': '_FIELD',
        'field_ptr_map': {}
    }


@pytest.mark.parametrize('frontend', available_frontends())
def test_offload_deepcopy_simple_driver(frontend, config, deepcopy_code, tmp_path):
    """
    Test the generation of host-device deepcopy for a simple driver loop.

    The test runs two stages on the same :any:`Scheduler`:

    1. :any:`DataOffloadDeepcopyAnalysis` with ``output_analysis=True``.
       Two copies of the access analysis for the driver loop must match the
       expected read/write map: the one stored in the driver item's
       ``trafo_data``, and the YAML file read back from ``tmp_path``.
    2. :any:`DataOffloadDeepcopyTransformation` in ``'offload'`` mode.
       ``check_other_variable_type`` checks the generated pragmas,
       conditionals and non-target calls in the driver body.
    """

    config['routines'] = {
        'simple_driver': {'role': 'driver'},
        'other_variable_type': {
            'field_prefix': 'f_',
            'view_prefix': 'v',
            'field_ptr_map': {
                'vt0_field': 'f_t0'
            }
        },
    }

    scheduler = Scheduler(
        paths=deepcopy_code, config=config, frontend=frontend, xmods=[tmp_path],
        output_dir=tmp_path, preprocess=True
    )

    ######............ check analysis
    transformation = DataOffloadDeepcopyAnalysis(output_analysis=True)
    scheduler.process(transformation=transformation)

    # The analysis is tied to driver loops
    trafo_data_key = transformation._key
    driver_item = scheduler['simple_driver_mod#simple_driver']
    # Reuse the item fetched above instead of performing a second scheduler lookup
    driver = driver_item.ir
    with pragmas_attached(driver, ir.Loop):
        driver_loop = find_driver_loops(driver.body, targets=[])[0]

    # Stringify the analysis dict so it can be compared against plain strings
    stringified_dict = transformation.stringify_dict(driver_item.trafo_data[trafo_data_key]['analysis'][driver_loop])
    expected_analysis = {
        'ngpblks' : 'read',
        'variable' : {
            'vt0_field' : 'write'
        }
    }
    assert stringified_dict == expected_analysis

    # The YAML file in tmp_path must hold the same analysis
    with open(tmp_path/'driver_simple_driver_dataoffload_analysis.yaml', 'r', encoding='utf-8') as file:
        _dict = yaml.safe_load(file)
    assert _dict == expected_analysis

    ######............ check transformation
    transformation = DataOffloadDeepcopyTransformation(mode='offload')
    scheduler.process(transformation=transformation)

    pragmas = FindNodes(ir.Pragma).visit(driver.body)
    conds = FindNodes(ir.Conditional, greedy=True).visit(driver.body)
    calls = FindNodes(ir.CallStatement).visit(driver.body)
    # filter out target calls, as we only need to check generated boilerplate
    calls = [call for call in calls if call.name.name.lower() not in driver_item.targets]
    check_other_variable_type('offload', conds, calls, pragmas, driver)
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/test_field_offload.py0000664000175000017500000007570115167130205030075 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Sourcefile, Module
import loki.expression.symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes, Pragma, CallStatement
from loki.logging import log_levels
from loki.transformations import FieldOffloadTransformation, FieldOffloadBlockedTransformation
from loki.batch import TransformationError


@pytest.fixture(name="parkind_mod")
def fixture_parkind_mod(tmp_path, frontend):
    """Provide a minimal ``parkind1`` module that defines the ``jprb`` kind parameter."""
    parkind_src = """
    module parkind1
      integer, parameter :: jprb=4
    end module
    """
    parkind = Module.from_source(parkind_src, frontend=frontend, xmods=[tmp_path])
    return parkind


@pytest.fixture(name="field_module")
def fixture_field_module(tmp_path, frontend):
    """
    Provide a stub ``field_module`` with the ``field_2rb``, ``field_3rb`` and
    ``field_4rb`` wrapper types.

    ``field_3rb`` and ``field_4rb`` bind the ``update_view`` procedure.
    """
    field_src = """
    module field_module
      implicit none

      type field_2rb
        real, pointer :: f_ptr(:,:,:)
      end type field_2rb

      type field_3rb
        real, pointer :: f_ptr(:,:,:)
     contains
        procedure :: update_view
      end type field_3rb

      type field_4rb
        real, pointer :: f_ptr(:,:,:)
     contains
        procedure :: update_view
      end type field_4rb

    contains
    subroutine update_view(self, idx)
      class(field_3rb), intent(in)  :: self
      integer, intent(in)           :: idx
    end subroutine
    end module
    """
    field_mod = Module.from_source(field_src, frontend=frontend, xmods=[tmp_path])
    return field_mod


@pytest.fixture(name="state_module")
def fixture_state_module(tmp_path, parkind_mod, field_module, frontend):  # pylint: disable=unused-argument
    """
    Provide ``state_mod`` with the ``state_type`` field-group type.

    ``state_type`` holds raw array pointers ``a``, ``b``, ``c``, ``d``, their
    Field API wrappers ``f_a`` .. ``f_d``, and an ``update_view`` binding.

    ``parkind_mod`` and ``field_module`` are requested only so that those
    modules are created (with ``xmods`` in ``tmp_path``) before this one.
    NOTE(review): presumably they are needed for the OMNI frontend.
    """
    # All wrapper types used in ``state_type`` must be imported. ``f_d`` is
    # declared ``class(field_4rb)``, so ``field_4rb`` has to be on the ``use``
    # list; otherwise the type is undeclared under ``implicit none``.
    fcode = """
    module state_mod
      use parkind1, only: jprb
      use field_module, only: field_2rb, field_3rb, field_4rb
      implicit none

      type state_type
        real(kind=jprb), dimension(10,10), pointer :: a, b, c
        real(kind=jprb), pointer :: d(10,10,10)
        class(field_3rb), pointer :: f_a, f_b, f_c
        class(field_4rb), pointer :: f_d
        contains
        procedure :: update_view => state_update_view
      end type state_type

    contains

      subroutine state_update_view(self, idx)
        class(state_type), intent(in) :: self
        integer, intent(in)           :: idx
      end subroutine
    end module state_mod
"""
    return Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload(frontend, state_module, tmp_path):
    """
    Offload of Field API-backed arrays passed from a ``!$loki data`` region to
    a kernel call.

    The test checks the following:

    - the right number of device transfers and host syncs are generated;
    - the data pragmas are preserved;
    - prefixed device-pointer variables are declared in the driver and passed
      to the kernel.
    """
    driver_src = """
    module driver_mod
      use state_mod, only: state_type
      use parkind1, only: jprb
      use field_module, only: field_2rb, field_3rb
      implicit none

    contains

      subroutine kernel_routine(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a(nlon,nlev)
        real(kind=jprb), intent(inout)  :: b(nlon,nlev)
        real(kind=jprb), intent(out)    :: c(nlon,nlev)
        integer :: i, j

        do j=1, nlon
          do i=1, nlev
            b(i,j) = a(i,j) + 0.1
            c(i,j) = 0.1
          end do
        end do
      end subroutine kernel_routine

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        do i=1,nlev
            call state%update_view(i)
            call kernel_routine(nlon, nlev, state%a, state%b, state%c)
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """
    # Keep the module alive: the routine resolves its imports through it
    driver_mod = Module.from_source(
        driver_src, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver = driver_mod['driver_routine']

    devptr_prefix = 'loki_devptr_prefix_'
    offload_trafo = FieldOffloadTransformation(
        devptr_prefix=devptr_prefix,
        offload_index='i',
        field_group_types=['state_type'],
    )
    driver.apply(offload_trafo, role='driver', targets=['kernel_routine'])

    calls = FindNodes(CallStatement).visit(driver.body)
    kernel_call = next(call for call in calls if call.name == 'kernel_routine')
    lowered_names = [call.name.name.lower() for call in calls]

    # a is intent(in), b and c are written: one rdonly + two rdwr transfers, two host syncs
    assert sum('get_device_data_rdonly' in name for name in lowered_names) == 1
    assert sum('get_device_data_rdwr' in name for name in lowered_names) == 2
    assert sum('sync_host_rdwr' in name for name in lowered_names) == 2

    # The data region pragmas are left untouched
    pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(pragmas) == 2
    assert all(
        pragma.keyword == 'loki' and pragma.content == content
        for pragma, content in zip(pragmas, ['data', 'end data'])
    )

    # Each state component gets a rank-3 device pointer that replaces it in the kernel call
    for component in ('state_a', 'state_b', 'state_c'):
        devptr_name = f'{devptr_prefix}{component}'
        assert devptr_name in driver.variable_map
        devptr = driver.variable_map[devptr_name]
        assert isinstance(devptr, sym.Array)
        assert len(devptr.shape) == 3
        assert devptr.name in (arg.name for arg in kernel_call.arguments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_slices(frontend, parkind_mod, field_module, tmp_path):  # pylint: disable=unused-argument
    """
    Slices of field-backed components passed to a kernel are rewritten to
    index the rank-4 device pointer.

    In the rewritten argument:

    - each sliced dimension stays a ``:`` range;
    - fixed dimensions become the literal ``1``;
    - the last dimension is the offload index ``i``.
    """
    driver_src = """
    module driver_mod
      use parkind1, only: jprb
      use field_module, only: field_4rb
      implicit none

      type state_type
        real(kind=jprb), dimension(10,10,10), pointer :: a, b, c, d
        class(field_4rb), pointer :: f_a, f_b, f_c, f_d
        contains
        procedure :: update_view => state_update_view
      end type state_type

    contains

      subroutine state_update_view(self, idx)
        class(state_type), intent(in) :: self
        integer, intent(in)           :: idx
      end subroutine

      subroutine kernel_routine(nlon, nlev, a, b, c, d)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a(nlon,nlev,nlon)
        real(kind=jprb), intent(inout)  :: b(nlon,nlev)
        real(kind=jprb), intent(out)    :: c(nlon)
        real(kind=jprb), intent(in)     :: d(nlon,nlev,nlon)
        integer :: i, j
      end subroutine kernel_routine

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i
        !$loki data
        do i=1,nlev
            call kernel_routine(nlon, nlev, state%a(:,:,1), state%b(:,1,1), state%c(1,1,1), state%d)
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """
    driver_mod = Sourcefile.from_source(driver_src, frontend=frontend, xmods=[tmp_path])['driver_mod']
    driver = driver_mod['driver_routine']

    devptr_prefix = 'loki_devptr_prefix_'
    offload_trafo = FieldOffloadTransformation(
        devptr_prefix=devptr_prefix,
        offload_index='i',
        field_group_types=['state_type'],
    )
    driver.apply(offload_trafo, role='driver', targets=['kernel_routine'])

    kernel_call = next(
        call for call in FindNodes(CallStatement).visit(driver.body)
        if call.name == 'kernel_routine'
    )

    full_range = sym.RangeIndex((None, None))
    offload_idx = sym.Scalar(name='i')
    # Number of leading ``:`` dimensions + 1 for each component in the driver call
    slice_ranks = {'state_d': 4, 'state_a': 3, 'state_b': 2, 'state_c': 1}

    for component, rank in slice_ranks.items():
        devptr_name = devptr_prefix + component
        assert devptr_name in driver.variable_map
        devptr = driver.variable_map[devptr_name]
        assert isinstance(devptr, sym.Array)
        # Three array dimensions of the state component plus the offload dimension
        assert len(devptr.shape) == 4
        assert devptr.name in (arg.name for arg in kernel_call.arguments)

        kernel_arg = next(arg for arg in kernel_call.arguments if devptr.name in arg.name)
        expected_dims = (
            (full_range,) * (rank - 1)
            + (sym.IntLiteral(1),) * (4 - rank)
            + (offload_idx,)
        )
        assert kernel_arg.dimensions == expected_dims


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_multiple_calls(frontend, state_module, tmp_path):
    """
    Two kernel calls on the same fields inside one ``!$loki data`` region.

    The test checks the following:

    - each field is transferred to the device only once;
    - each written field is synced back only once;
    - both kernel calls use the generated device pointers.
    """
    driver_src = """
    module driver_mod
      use parkind1, only: jprb
      use state_mod, only: state_type
      implicit none

    contains

      subroutine kernel_routine(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a(nlon,nlev)
        real(kind=jprb), intent(inout)  :: b(nlon,nlev)
        real(kind=jprb), intent(out)    :: c(nlon,nlev)
        integer :: i, j

        do j=1, nlon
          do i=1, nlev
            b(i,j) = a(i,j) + 0.1
            c(i,j) = 0.1
          end do
        end do
      end subroutine kernel_routine

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        do i=1,nlev
            call state%update_view(i)

            call kernel_routine(nlon, nlev, state%a, state%b, state%c)

            call kernel_routine(nlon, nlev, state%a, state%b, state%c)
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """

    driver_mod = Module.from_source(
        driver_src, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver = driver_mod['driver_routine']

    devptr_prefix = 'loki_devptr_prefix_'
    offload_trafo = FieldOffloadTransformation(
        devptr_prefix=devptr_prefix,
        offload_index='i',
        field_group_types=['state_type'],
    )
    driver.apply(offload_trafo, role='driver', targets=['kernel_routine'])

    calls = FindNodes(CallStatement).visit(driver.body)
    kernel_calls = [call for call in calls if call.name == 'kernel_routine']
    lowered_names = [call.name.name.lower() for call in calls]

    # Counts are per field (a, b, c), not per kernel call
    expected_counts = {
        'get_device_data_rdonly': 1,
        'get_device_data_rdwr': 2,
        'sync_host_rdwr': 2,
    }
    for fragment, count in expected_counts.items():
        assert sum(fragment in name for name in lowered_names) == count

    # The data region pragmas are left untouched
    pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(pragmas) == 2
    assert all(
        pragma.keyword == 'loki' and pragma.content == content
        for pragma, content in zip(pragmas, ['data', 'end data'])
    )

    # Each component's device pointer appears among the kernel call arguments
    for component in ('state_a', 'state_b', 'state_c'):
        devptr_name = f'{devptr_prefix}{component}'
        assert devptr_name in driver.variable_map
        devptr = driver.variable_map[devptr_name]
        assert isinstance(devptr, sym.Array)
        assert len(devptr.shape) == 3
        assert devptr.name in (
            arg.name for kernel_call in kernel_calls for arg in kernel_call.arguments
        )


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_unknown_kernel(caplog, frontend, state_module, tmp_path):
    """
    Offloading to a target kernel that the driver has not been enriched with
    emits exactly one warning naming the driver and the kernel.
    """
    kernel_src = """
    module another_module
      implicit none
    contains
      subroutine another_kernel(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real, intent(in)     :: a(nlon,nlev)
        real, intent(inout)  :: b(nlon,nlev)
        real, intent(out)    :: c(nlon,nlev)
        integer :: i, j
      end subroutine
    end module
    """

    driver_src = """
    module driver_mod
      use parkind1, only: jprb
      use state_mod, only: state_type
      use another_module, only: another_kernel
      implicit none

    contains

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        do i=1,nlev
            call state%update_view(i)
            call another_kernel(nlon, nlev, state%a, state%b, state%c)
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """

    # The kernel module is parsed but deliberately not passed as a definition.
    # NOTE(review): presumably parsed only so its module info exists in tmp_path (xmods).
    Sourcefile.from_source(kernel_src, frontend=frontend, xmods=[tmp_path])
    driver_mod = Module.from_source(
        driver_src, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver = driver_mod['driver_routine']

    offload_trafo = FieldOffloadTransformation(
        devptr_prefix='loki_devptr_prefix_',
        offload_index='i',
        field_group_types=['state_type'],
    )
    caplog.clear()
    with caplog.at_level(log_levels['WARNING']):
        driver.apply(offload_trafo, role='driver', targets=['another_kernel'])
        records = caplog.records
        assert len(records) == 1
        expected_warning = (
            '[Loki] Data offload: Routine driver_routine has not been enriched '
            'in another_kernel'
        )
        assert expected_warning in records[0].message


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_warnings(caplog, frontend, state_module, tmp_path):
    """
    :any:`FieldOffloadTransformation` emits three warnings, in this order:

    1. a raw local array (``a``) that is not wrapped by a Field API object;
    2. a component of ``state2``, whose type is not in ``field_group_types``;
    3. a device-pointer name that collides with an existing driver variable.
    """
    state2_src = """
    module state_type_mod
      implicit none
      type state_type2
        real, dimension(10,10), pointer :: a, b, c
      contains
        procedure :: update_view => state_update_view
      end type state_type2

    contains

      subroutine state_update_view(self, idx)
        class(state_type2), intent(in) :: self
        integer, intent(in)           :: idx
      end subroutine
    end module
    """

    another_src = """
    module another_module
      implicit none
    contains
      subroutine another_kernel(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real, intent(in)     :: a(nlon,nlev)
        real, intent(inout)  :: b(nlon,nlev)
        real, intent(out)    :: c(nlon,nlev)
        integer :: i, j
      end subroutine
    end module
    """

    driver_src = """
    module driver_mod
      use state_type_mod, only: state_type2
      use parkind1, only: jprb
      use state_mod, only: state_type
      use another_module, only: another_kernel

      implicit none

    contains

      subroutine kernel_routine(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a(nlon,nlev)
        real(kind=jprb), intent(inout)  :: b(nlon,nlev)
        real(kind=jprb), intent(out)    :: c(nlon,nlev)
        integer :: i, j

        do j=1, nlon
          do i=1, nlev
            b(i,j) = a(i,j) + 0.1
            c(i,j) = 0.1
          end do
        end do
      end subroutine kernel_routine

      subroutine driver_routine(nlon, nlev, state, state2)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        type(state_type2), intent(inout) :: state2

        integer                         :: i
        real(kind=jprb)                 :: a(nlon,nlev)
        real, pointer                   :: loki_devptr_prefix_state_b

        !$loki data
        do i=1,nlev
            call state%update_view(i)
            call kernel_routine(nlon, nlev, a, state%b, state2%c)
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """
    for helper_src in (state2_src, another_src):
        Sourcefile.from_source(helper_src, frontend=frontend, xmods=[tmp_path])
    driver_mod = Sourcefile.from_source(
        driver_src, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )['driver_mod']
    driver = driver_mod['driver_routine']

    offload_trafo = FieldOffloadTransformation(
        devptr_prefix='loki_devptr_prefix_',
        offload_index='i',
        field_group_types=['state_type'],
    )
    expected_warnings = (
        '[Loki] Data offload: Raw array object a encountered in'
        ' driver_routine that is not wrapped by a Field API object',
        '[Loki] Data offload: The parent object state2 of type state_type2 is not in the'
        ' list of field wrapper types',
        '[Loki] Data offload: The routine driver_routine already has a'
        ' variable named loki_devptr_prefix_state_b',
    )
    caplog.clear()
    with caplog.at_level(log_levels['WARNING']):
        driver.apply(offload_trafo, role='driver', targets=['kernel_routine'])
        assert len(caplog.records) == len(expected_warnings)
        for record, warning in zip(caplog.records, expected_warnings):
            assert warning in record.message


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_aliasing(frontend, state_module, tmp_path):
    """
    Three column slices of the same field (``state%a(:,1)``, ``state%a(:,2)``,
    ``state%a(:,3)``) passed to one kernel call.

    The slices share a single device pointer ``state_a``. Its data is
    transferred once with ``get_device_data_rdwr`` and synced back once with
    ``sync_host_rdwr``. The pointer is declared once, as the last declaration
    of the driver spec.
    """
    fcode = """
    module driver_mod
      use state_mod, only: state_type
      use parkind1, only: jprb
      implicit none

    contains

      subroutine kernel_routine(nlon, nlev, a1, a2, a3)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a1(nlon)
        real(kind=jprb), intent(inout)  :: a2(nlon)
        real(kind=jprb), intent(out)    :: a3(nlon)
        integer :: i

        do i=1, nlon
          a1(i) = a2(i) + 0.1
          a3(i) = 0.1
        end do
      end subroutine kernel_routine

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        do i=1,nlev
            call state%update_view(i)
            call kernel_routine(nlon, nlev, state%a(:,1), state%a(:,2), state%a(:,3))
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """
    driver_mod = Module.from_source(
        fcode, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver = driver_mod['driver_routine']

    field_offload = FieldOffloadTransformation(
        devptr_prefix='', offload_index='i', field_group_types=['state_type']
    )
    driver.apply(field_offload, role='driver', targets=['kernel_routine'])

    calls = FindNodes(ir.CallStatement).visit(driver.body)
    kernel_call = next(c for c in calls if c.name=='kernel_routine')

    assert 'state_a' in driver.variable_map
    assert driver.variable_map['state_a'].type.shape == (':', ':', ':')

    # All three slices index the same device pointer, with the offload index appended
    assert kernel_call.arguments[:2] == ('nlon', 'nlev')
    assert kernel_call.arguments[2] == 'state_a(:,1,i)'
    assert kernel_call.arguments[3] == 'state_a(:,2,i)'
    assert kernel_call.arguments[4] == 'state_a(:,3,i)'

    assert len(calls) == 3
    assert calls[0].name == 'state%f_a%get_device_data_rdwr'
    assert calls[0].arguments == ('state_a',)
    assert calls[1] == kernel_call
    assert calls[2].name == 'state%f_a%sync_host_rdwr'
    assert calls[2].arguments == ()

    decls = FindNodes(ir.VariableDeclaration).visit(driver.spec)
    # The conditional must be parenthesised. Written as
    # ``assert len(decls) == 5 if frontend == OMNI else 4`` it parses as
    # ``assert (len(decls) == 5) if frontend == OMNI else 4``. For non-OMNI
    # frontends that asserts the constant 4, which always passes.
    # NOTE(review): the extra OMNI declaration presumably comes from OMNI
    # splitting ``nlon, nlev`` into separate declarations; confirm.
    assert len(decls) == (5 if frontend == OMNI else 4)
    assert decls[-1].symbols == ('state_a(:,:,:)',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_driver_compute(frontend, state_module, tmp_path):
    """
    Field offload when the driver loop computes directly on state components
    instead of calling a kernel.

    After the transformation:

    - the ``update_view`` call is replaced by device transfer and host sync
      calls;
    - the assignments index the device pointers with the offload index
      ``ibl``.
    """
    driver_src = """
    module driver_mod
      use state_mod, only: state_type
      use parkind1, only: jprb
      implicit none

    contains

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i, ibl

        !$loki data
        do ibl=1,nlev
          call state%update_view(ibl)
          do i=1, nlon
            state%a(i, 1) = state%b(i, 1) + 0.1
            state%a(i, 2) = state%a(i, 1)
          end do

        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """
    driver_mod = Module.from_source(
        driver_src, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver = driver_mod['driver_routine']

    # Before: the only call is the view update
    original_calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(original_calls) == 1
    assert original_calls[0].name == 'state%update_view'

    offload_trafo = FieldOffloadTransformation(
        devptr_prefix='', offload_index='ibl', field_group_types=['state_type']
    )
    driver.apply(offload_trafo, role='driver', targets=['kernel_routine'])

    # After: b is read-only, a is read-write and synced back
    expected_calls = (
        ('state%f_b%get_device_data_rdonly', ('state_b',)),
        ('state%f_a%get_device_data_rdwr', ('state_a',)),
        ('state%f_a%sync_host_rdwr', ()),
    )
    new_calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(new_calls) == len(expected_calls)
    for call, (call_name, call_args) in zip(new_calls, expected_calls):
        assert call.name == call_name
        assert call.arguments == call_args

    expected_assigns = (
        ('state_a(i,1,ibl)', 'state_b(i,1,ibl) + 0.1'),
        ('state_a(i,2,ibl)', 'state_a(i,1,ibl)'),
    )
    assigns = FindNodes(ir.Assignment).visit(driver.body)
    assert len(assigns) == len(expected_assigns)
    for assign, (lhs, rhs) in zip(assigns, expected_assigns):
        assert assign.lhs == lhs
        assert assign.rhs == rhs


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_blocked(frontend, state_module, tmp_path):
    """
    Synchronous blocked offload with :any:`FieldOffloadBlockedTransformation`.

    The test checks the following:

    - all three fields are pushed with ``get_device_data_force``;
    - the two fields with write intent are pulled back with
      ``sync_host_force``;
    - the ``data`` and ``driver-loop`` pragmas remain unchanged.
    """
    driver_src = """
    module driver_mod
      use state_mod, only: state_type
      use parkind1, only: jprb
      use field_module, only: field_2rb, field_3rb
      implicit none

    contains

      subroutine kernel_routine(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a(nlon,nlev)
        real(kind=jprb), intent(inout)  :: b(nlon,nlev)
        real(kind=jprb), intent(out)    :: c(nlon,nlev)
        integer :: i, j

        do j=1, nlon
          do i=1, nlev
            b(i,j) = a(i,j) + 0.1
            c(i,j) = 0.1
          end do
        end do
      end subroutine kernel_routine

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        !$loki driver-loop
        do i=1,nlev
            call state%update_view(i)
            call kernel_routine(nlon, nlev, state%a, state%b, state%c)
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """
    driver_mod = Module.from_source(
        driver_src, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver = driver_mod['driver_routine']

    devptr_prefix = 'loki_devptr_prefix_'
    blocked_trafo = FieldOffloadBlockedTransformation(
        devptr_prefix=devptr_prefix,
        offload_index='i',
        field_group_types=['state_type'],
        block_size=100,
    )
    driver.apply(blocked_trafo, role='driver', targets=['kernel_routine'])

    calls = FindNodes(CallStatement).visit(driver.body)
    kernel_call = next(call for call in calls if call.name == 'kernel_routine')
    lowered_names = [call.name.name.lower() for call in calls]

    assert sum('get_device_data_force' in name for name in lowered_names) == 3
    assert sum('sync_host_force' in name for name in lowered_names) == 2

    pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(pragmas) == 3
    assert all(
        pragma.keyword == 'loki' and pragma.content == content
        for pragma, content in zip(pragmas, ['data', 'driver-loop', 'end data'])
    )

    for component in ('state_a', 'state_b', 'state_c'):
        devptr_name = f'{devptr_prefix}{component}'
        assert devptr_name in driver.variable_map
        devptr = driver.variable_map[devptr_name]
        assert isinstance(devptr, sym.Array)
        assert len(devptr.shape) == 3
        assert devptr.name in (arg.name for arg in kernel_call.arguments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_blocked_async(frontend, state_module, tmp_path):
    """
    Asynchronous blocked offload with three queues.

    The test checks the following:

    - the forced transfers and syncs match the synchronous case (3 and 2);
    - ``async(loki_block_queue)`` is appended to the ``data`` and
      ``driver-loop`` pragmas;
    - ``end data`` is left unchanged.
    """
    driver_src = """
    module driver_mod
      use state_mod, only: state_type
      use parkind1, only: jprb
      use field_module, only: field_2rb, field_3rb
      implicit none

    contains

      subroutine kernel_routine(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a(nlon,nlev)
        real(kind=jprb), intent(inout)  :: b(nlon,nlev)
        real(kind=jprb), intent(out)    :: c(nlon,nlev)
        integer :: i, j

        do j=1, nlon
          do i=1, nlev
            b(i,j) = a(i,j) + 0.1
            c(i,j) = 0.1
          end do
        end do
      end subroutine kernel_routine

      subroutine driver_routine(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        !$loki driver-loop
        do i=1,nlev
            call state%update_view(i)
            call kernel_routine(nlon, nlev, state%a, state%b, state%c)
        end do
        !$loki end data

      end subroutine driver_routine
    end module driver_mod
    """
    driver_mod = Module.from_source(
        driver_src, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver = driver_mod['driver_routine']

    devptr_prefix = 'loki_devptr_prefix_'
    async_trafo = FieldOffloadBlockedTransformation(
        devptr_prefix=devptr_prefix,
        offload_index='i',
        field_group_types=['state_type'],
        block_size=100,
        asynchronous=True,
        num_queues=3,
    )
    driver.apply(async_trafo, role='driver', targets=['kernel_routine'])

    calls = FindNodes(CallStatement).visit(driver.body)
    kernel_call = next(call for call in calls if call.name == 'kernel_routine')
    lowered_names = [call.name.name.lower() for call in calls]

    assert sum('get_device_data_force' in name for name in lowered_names) == 3
    assert sum('sync_host_force' in name for name in lowered_names) == 2

    expected_pragmas = [
        'data async(loki_block_queue)',
        'driver-loop async(loki_block_queue)',
        'end data',
    ]
    pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(pragmas) == 3
    assert all(
        pragma.keyword == 'loki' and pragma.content == content
        for pragma, content in zip(pragmas, expected_pragmas)
    )

    for component in ('state_a', 'state_b', 'state_c'):
        devptr_name = f'{devptr_prefix}{component}'
        assert devptr_name in driver.variable_map
        devptr = driver.variable_map[devptr_name]
        assert isinstance(devptr, sym.Array)
        assert len(devptr.shape) == 3
        assert devptr.name in (arg.name for arg in kernel_call.arguments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_blocked_warnings(frontend, state_module, tmp_path, caplog):
    """
    Test diagnostics emitted by :any:`FieldOffloadBlockedTransformation`.

    Checks that a non-bool ``asynchronous`` argument produces a warning, that
    two ``!$loki driver-loop`` loops inside one data region produce a warning,
    and that a data region without any driver loop raises a
    :any:`TransformationError` whose diagnostic names the offending routine.
    """
    fcode = """
    module driver_mod
      use state_mod, only: state_type
      use parkind1, only: jprb
      use field_module, only: field_2rb, field_3rb
      implicit none

    contains

      subroutine kernel_routine(nlon, nlev, a, b, c)
        integer, intent(in)             :: nlon, nlev
        real(kind=jprb), intent(in)     :: a(nlon,nlev)
        real(kind=jprb), intent(inout)  :: b(nlon,nlev)
        real(kind=jprb), intent(out)    :: c(nlon,nlev)
        integer :: i, j

        do j=1, nlon
          do i=1, nlev
            b(i,j) = a(i,j) + 0.1
            c(i,j) = 0.1
          end do
        end do
      end subroutine kernel_routine

      subroutine driver_multiple_loops(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        !$loki driver-loop
        do i=1,nlev
            call state%update_view(i)
            call kernel_routine(nlon, nlev, state%a, state%b, state%c)
        end do
        !$loki driver-loop
        do i=1,nlev
            call state%update_view(i)
            call kernel_routine(nlon, nlev, state%a, state%b, state%c)
        end do
        !$loki end data

      end subroutine driver_multiple_loops

      subroutine driver_no_driver_loop(nlon, nlev, state)
        integer, intent(in)             :: nlon, nlev
        type(state_type), intent(inout) :: state
        integer                         :: i

        !$loki data
        do i=1,nlev
            call state%update_view(i)
        end do
        !$loki end data

      end subroutine driver_no_driver_loop

    end module driver_mod
    """
    driver_mod = Module.from_source(
        fcode, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    driver_multiple = driver_mod['driver_multiple_loops']
    driver_no_loop = driver_mod['driver_no_driver_loop']
    deviceptr_prefix = 'loki_devptr_prefix_'

    # verify that warnings are raised properly
    with caplog.at_level(log_levels['WARNING']):
        offload_trafo = FieldOffloadBlockedTransformation(devptr_prefix=deviceptr_prefix,
                                                          offload_index='i',
                                                          field_group_types=['state_type'],
                                                          block_size=100,
                                                          asynchronous=1,
                                                          num_queues=3)
        assert any('[Loki] FieldOffloadBlockedTransformation: asynchronous kwarg must be a bool' +
                   ' asynchronous set to False' in r.message for r in caplog.records)

        caplog.clear()
        driver_multiple.apply(offload_trafo, role='driver', targets=['kernel_routine'])
        assert any('[Loki] FieldOffloadBlockedTransformation: Multiple driver loops found in ' +
                   'driver_multiple_loops, discarding all but first' in r.message for r in caplog.records)

    caplog.clear()
    with caplog.at_level(log_levels['ERROR']):
        with pytest.raises(TransformationError) as excinfo:
            driver_no_loop.apply(offload_trafo, role='driver', targets=['kernel_routine'])

    # This check must live outside ``pytest.raises``: any statement placed after the
    # raising call inside that context manager is never executed.
    # The diagnostic may be logged, carried by the exception, or both, so search all of them.
    messages = [r.message for r in caplog.records] + [str(excinfo.value)]
    assert any('No driver loops found in driver_no_driver_loop' in msg for msg in messages)
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/test_global_var.py0000664000175000017500000006213615167130205027426 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki import Scheduler, FindInlineCalls
from loki.frontend import available_frontends, OMNI
from loki.ir import FindNodes, Pragma, CallStatement, Import

from loki.transformations import (
    GlobalVariableAnalysis, GlobalVarOffloadTransformation,
    GlobalVarHoistTransformation, PragmaModelTransformation
)


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory holding this test module; used to locate the bundled Fortran sources."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(name='config')
def fixture_config():
    """
    Scheduler configuration containing only a ``default`` section.

    Items default to the ``kernel`` role in ``idem`` mode, with ``expand``,
    ``strict`` and ``enable_imports`` all switched on. A fresh dict is
    returned on every use so tests may mutate it freely.
    """
    defaults = dict(
        mode='idem',
        role='kernel',
        expand=True,
        strict=True,
        enable_imports=True,
    )
    return {'default': defaults}


@pytest.fixture(name='global_variable_analysis_code')
def fixture_global_variable_analysis_code(tmp_path):
    """
    Write a small Fortran project that uses module-level global variables.

    The project has four sources:

    * a header module with parameters (``nval``, ``nfld``) and globals (``n``, ``iarr``, ``rarr``)
    * a data module with an allocatable global ``rdata``, a derived type and a global ``tt``
    * a kernel module whose routines ``kernel_a``/``kernel_b`` use those globals
    * a ``driver`` subroutine calling both kernels between ``!$loki update_device``
      and ``!$loki update_host`` markers

    Each source is stripped and written to ``<name>.F90`` in ``tmp_path``, which is returned.
    """
    header_source = """
module global_var_analysis_header_mod
    implicit none

    integer, parameter :: nval = 5
    integer, parameter :: nfld = 3

    integer :: n

    integer :: iarr(nfld)
    real :: rarr(nval, nfld)
end module global_var_analysis_header_mod
"""

    data_source = """
module global_var_analysis_data_mod
    implicit none

    real, allocatable :: rdata(:,:,:)

    type some_type
        real :: val
        real, allocatable :: vals(:,:)
    end type some_type

    type(some_type) :: tt

contains
    subroutine some_routine(i)
        integer, intent(inout) :: i
        i = i + 1
    end subroutine some_routine
end module global_var_analysis_data_mod
"""

    kernel_source = """
module global_var_analysis_kernel_mod
    use global_var_analysis_header_mod, only: rarr
    use global_var_analysis_data_mod, only: some_routine, some_type

    implicit none

contains
    subroutine kernel_a(arg, tt)
        use global_var_analysis_header_mod, only: iarr, nval, nfld, n

        real, intent(inout) :: arg(:,:)
        type(some_type), intent(in) :: tt
        real :: tmp(n)
        integer :: i, j

        do i=1,nfld
            if (iarr(i) > 0) then
                do j=1,nval
                    arg(j,i) = rarr(j, i) + tt%val
                    call some_routine(arg(j,i))
                enddo
            endif
        enddo
    end subroutine kernel_a

    subroutine kernel_b(arg)
        use global_var_analysis_header_mod, only: iarr, nfld
        use global_var_analysis_data_mod, only: rdata, tt

        real, intent(inout) :: arg(:,:)
        integer :: i

        do i=1,nfld
            if (iarr(i) .ne. 0) then
                rdata(:,:,i) = arg(:,:) + rdata(:,:,i)
            else
                arg(:,:) = tt%vals(:,:)
            endif
        enddo
    end subroutine kernel_b
end module global_var_analysis_kernel_mod
"""

    driver_source = """
subroutine driver(arg)
    use global_var_analysis_kernel_mod, only: kernel_a, kernel_b
    use global_var_analysis_data_mod, only: tt
    implicit none

    real, intent(inout) :: arg(:,:)

    !$loki update_device

    call kernel_a(arg, tt)

    call kernel_b(arg)

    !$loki update_host
end subroutine driver
"""

    sources = {
        'global_var_analysis_header_mod': header_source,
        'global_var_analysis_data_mod': data_source,
        'global_var_analysis_kernel_mod': kernel_source,
        'driver': driver_source,
    }
    for stem in sources:
        target = tmp_path / f'{stem}.F90'
        target.write_text(sources[stem].strip())
    return tmp_path


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('key', (None, 'foobar'))
def test_global_variable_analysis(frontend, key, config, global_variable_analysis_code):
    """
    Test the per-item ``trafo_data`` recorded by :any:`GlobalVariableAnalysis`.

    Module items must list the globals they declare (with an empty offload set),
    and routine items must list the global symbols they define and use; for
    ``driver`` these include the symbols of the kernels it calls. The analysis
    must store its results under the custom ``key`` when one is given.
    """
    config['routines'] = {'driver': {'role': 'driver'}}

    scheduler = Scheduler(
        paths=(global_variable_analysis_code,), config=config, seed_routines='driver',
        frontend=frontend, xmods=(global_variable_analysis_code,)
    )
    scheduler.process(GlobalVariableAnalysis(key=key))
    trafo_key = GlobalVariableAnalysis._key if key is None else key

    header_mod = 'global_var_analysis_header_mod'
    data_mod = 'global_var_analysis_data_mod'

    # Under OMNI, parameter dimensions appear as literal values and the
    # parameters themselves are not reported as used symbols
    omni = frontend == OMNI
    nfld_dim, nval_dim = ('3', '5') if omni else ('nfld', 'nval')
    nfld_data = set() if omni else {('nfld', header_mod)}
    nval_data = set() if omni else {('nval', header_mod)}

    iarr = f'iarr({nfld_dim})'
    rarr = f'rarr({nval_dim}, {nfld_dim})'
    rdata = 'rdata(:, :, :)'

    expected = {
        header_mod: {'declares': {iarr, rarr, 'n'}, 'offload': {}},
        data_mod: {'declares': {rdata, 'tt'}, 'offload': {}},
        f'{data_mod}#some_routine': {'defines_symbols': set(), 'uses_symbols': set()},
        'global_var_analysis_kernel_mod#kernel_a': {
            'defines_symbols': set(),
            'uses_symbols': nval_data | nfld_data | {
                (iarr, header_mod), ('n', header_mod), (rarr, header_mod)
            },
        },
        'global_var_analysis_kernel_mod#kernel_b': {
            'defines_symbols': {(rdata, data_mod)},
            'uses_symbols': nfld_data | {
                (rdata, data_mod), ('tt%vals', data_mod), (iarr, header_mod)
            },
        },
        '#driver': {
            'defines_symbols': {(rdata, data_mod)},
            'uses_symbols': nval_data | nfld_data | {
                (rdata, data_mod), ('n', header_mod),
                ('tt', data_mod), ('tt%vals', data_mod),
                (iarr, header_mod), (rarr, header_mod),
            },
        },
    }

    def _normalise(entries):
        # Stringify every symbol (and every element of symbol tuples) for comparison
        return sorted(
            tuple(str(part) for part in entry) if isinstance(entry, tuple) else str(entry)
            for entry in entries
        )

    # The derived type is a scheduler item, but the analysis records nothing worth checking for it
    type_item = f'{data_mod}#some_type'
    assert set(scheduler.items) == set(expected) | {type_item}
    for item in scheduler.items:
        if item == type_item:
            continue
        for data_name, data_value in item.trafo_data[trafo_key].items():
            assert _normalise(data_value) == sorted(expected[item.name][data_name])


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('key', (None, 'foobar'))
def test_global_variable_offload(frontend, key, config, global_variable_analysis_code):
    """
    Test the offload instructions generated for globals used below ``driver``.

    After analysis, offload and OpenACC pragma translation:

    * the header/data module items record which globals must be offloaded,
    * the driver imports those globals and brackets the kernel calls with
      host-to-device and device-to-host ``acc`` pragmas,
    * the declaring modules gain ``acc`` pragmas naming the offloaded globals.
    """
    config['routines'] = {'driver': {'role': 'driver'}}

    # Under OMNI, parameter dimensions appear as their literal values
    omni = frontend == OMNI
    nfld_dim = '3' if omni else 'nfld'
    nval_dim = '5' if omni else 'nval'

    scheduler = Scheduler(
        paths=(global_variable_analysis_code,), config=config, seed_routines='driver',
        frontend=frontend, xmods=(global_variable_analysis_code,)
    )
    scheduler.process(GlobalVariableAnalysis(key=key))
    scheduler.process(GlobalVarOffloadTransformation(key=key))
    scheduler.process(PragmaModelTransformation(directive='openacc'))
    driver = scheduler['#driver'].ir

    trafo_key = GlobalVariableAnalysis._key if key is None else key

    header_mod = 'global_var_analysis_header_mod'
    data_mod = 'global_var_analysis_data_mod'
    header_vars = {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'}
    expected_trafo_data = {
        header_mod: {'declares': header_vars, 'offload': header_vars},
        data_mod: {'declares': {'rdata(:, :, :)', 'tt'}, 'offload': {'rdata(:, :, :)', 'tt%vals'}},
    }

    # Verify module offload sets
    for item in (scheduler[header_mod], scheduler[data_mod]):
        for data_name, data_value in item.trafo_data[trafo_key].items():
            actual = sorted(
                tuple(str(part) for part in entry) if isinstance(entry, tuple) else str(entry)
                for entry in data_value
            )
            assert actual == sorted(expected_trafo_data[item.name][data_name])

    # Verify imports have been added to the driver; they are prepended, so only
    # the leading imports need checking
    expected_imports = {header_mod: {'iarr', 'rarr', 'n'}, data_mod: {'rdata'}}
    for import_ in driver.imports[:len(expected_imports)]:
        imported = {var.name.lower() for var in import_.symbols}
        assert imported == expected_imports[import_.module.lower()]

    def _parse_pragma(pragma):
        # Split "<command>(<var>, <var>, ...)" into the command and the variable set
        command, variables = pragma.content.lower().split('(')
        return command.strip(), set(variables.strip()[:-1].strip().split(', '))

    expected_h2d_pragmas = {
        'update device': {'iarr', 'rdata', 'rarr', 'n'},
        'enter data copyin': {'tt%vals'}
    }
    expected_d2h_pragmas = {'update self': {'rdata'}}

    acc_pragmas = [p for p in FindNodes(Pragma).visit(driver.ir) if p.keyword.lower() == 'acc']
    assert len(acc_pragmas) == len(expected_h2d_pragmas) + len(expected_d2h_pragmas)
    n_h2d = len(expected_h2d_pragmas)
    # Host-to-device pragmas come first, followed by the device-to-host ones
    for pragma_group, expected in (
        (acc_pragmas[:n_h2d], expected_h2d_pragmas),
        (acc_pragmas[n_h2d:], expected_d2h_pragmas),
    ):
        for pragma in pragma_group:
            command, variables = _parse_pragma(pragma)
            assert command in expected
            assert variables == expected[command]

    # Verify declarations have been added to the declaring modules
    expected_declarations = {header_mod: {'iarr', 'rarr', 'n'}, data_mod: {'rdata', 'tt'}}
    modules = {name: scheduler[name].ir for name in expected_declarations}
    for name, module in modules.items():
        declared = set()
        for pragma in FindNodes(Pragma).visit(module.spec):
            if pragma.keyword.lower() != 'acc':
                continue
            declared.update(
                v.strip() for v in pragma.content.lower().split('(')[-1].strip()[:-1].split(', ')
            )
        assert declared == expected_declarations[name]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('directive', ['openacc', 'omp-gpu'])
def test_transformation_global_var_import(here, config, frontend, directive, tmp_path):
    """
    Test the generation of offload instructions of global variable imports.

    The driver must import the non-parameter globals of ``moduleb`` and
    ``modulec``, keep its existing pragmas, and gain update pragmas in the
    requested programming model. ``modulea`` (parameters only) must get no
    device declaration, while ``moduleb``/``modulec`` each get exactly one.
    """
    config['routines'] = {
        'driver': {'role': 'driver'}
    }

    scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend, xmods=[tmp_path])
    scheduler.process(transformation=GlobalVariableAnalysis())
    scheduler.process(transformation=GlobalVarOffloadTransformation())
    scheduler.process(PragmaModelTransformation(directive=directive))

    driver, module_a, module_b, module_c = (
        scheduler[name].ir for name in ('#driver', 'modulea', 'moduleb', 'modulec')
    )

    # check that global variables have been added to driver symbol table
    imports = FindNodes(Import).visit(driver.spec)
    assert len(imports) == 2
    first, second = imports
    assert first.module != second.module
    assert first.symbols != second.symbols
    for imp in imports:
        assert len(imp.symbols) == 2
        assert imp.module.lower() in ('moduleb', 'modulec')
        assert {s.name for s in imp.symbols} in ({'var2', 'var3'}, {'var4', 'var5'})

    def _check_pragma(pragma, keyword, fragment, *names):
        # Pragma has the given keyword and its content mentions the fragment and all names
        assert pragma.keyword.lower() == keyword
        assert fragment in pragma.content
        for name in names:
            assert name in pragma.content

    def _check_declaration(module, keyword, clause, *names):
        # Module spec holds exactly one device-declaration pragma naming all given globals
        decls = FindNodes(Pragma).visit(module.spec)
        assert len(decls) == 1
        _check_pragma(decls[0], keyword, clause, *names)

    # check that existing acc pragmas have not been stripped and update device/update self added correctly
    pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(pragmas) == 5

    if directive == 'openacc':
        _check_pragma(pragmas[0], 'acc', 'update device', 'var2', 'var3')
        _check_pragma(pragmas[1], 'loki', 'omp-update-global-vars in(', 'var2', 'var3')

        assert pragmas[2].keyword.lower() == 'acc'
        assert pragmas[2].content == 'serial'
        assert pragmas[3].keyword.lower() == 'acc'
        assert pragmas[3].content == 'end serial'

        _check_pragma(pragmas[4], 'acc', 'update self', 'var4', 'var5')
        declaration = ('acc', 'declare create')
    elif directive == 'omp-gpu':
        _check_pragma(pragmas[0], 'omp', 'update to', 'var2', 'var3')
        _check_pragma(pragmas[1], 'omp', 'target enter data map(to:', 'var2', 'var3')
        _check_pragma(pragmas[4], 'omp', 'target update from', 'var4', 'var5')
        declaration = ('omp', 'declare target')
    else:
        return

    # check that no declarations have been added for parameters
    assert not FindNodes(Pragma).visit(module_a.spec)

    # check for device-side declarations where appropriate
    _check_declaration(module_b, *declaration, 'var2', 'var3')
    _check_declaration(module_c, *declaration, 'var4', 'var5')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transformation_global_var_import_derived_type(here, config, frontend, tmp_path):
    """
    Test the generation of offload instructions of derived-type global variable imports.

    The driver must import ``p`` and ``p0``, copy in ``p0%x/y/z`` and ``p%n``,
    create ``p%x/y/z`` on the device, and copy those members back out after the
    kernel. The declaring module must get a single ``declare create`` pragma
    listing exactly the globals the kernel uses.
    """

    config['default']['enable_imports'] = True
    config['routines'] = {
        'driver_derived_type': {'role': 'driver'}
    }

    scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend, xmods=[tmp_path])
    scheduler.process(transformation=GlobalVariableAnalysis())
    scheduler.process(transformation=GlobalVarOffloadTransformation())
    scheduler.process(PragmaModelTransformation(directive='openacc'))

    driver = scheduler['#driver_derived_type'].ir
    module = scheduler['module_derived_type'].ir

    # check that global variables have been added to driver symbol table
    imports = FindNodes(Import).visit(driver.spec)
    assert len(imports) == 1
    assert len(imports[0].symbols) == 2
    assert imports[0].module.lower() == 'module_derived_type'
    assert set(s.name for s in imports[0].symbols) == {'p', 'p0'}

    # check that existing acc pragmas have not been stripped and update device/update self added correctly
    pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(pragmas) == 5
    assert all(p.keyword.lower() == 'acc' for p in pragmas)

    assert 'enter data copyin' in pragmas[0].content
    assert 'p0%x' in pragmas[0].content
    assert 'p0%y' in pragmas[0].content
    assert 'p0%z' in pragmas[0].content
    assert 'p%n' in pragmas[0].content

    assert 'enter data create' in pragmas[1].content
    assert 'p%x' in pragmas[1].content
    assert 'p%y' in pragmas[1].content
    assert 'p%z' in pragmas[1].content

    assert pragmas[2].content == 'serial'
    assert pragmas[3].content == 'end serial'

    assert 'exit data copyout' in pragmas[4].content
    assert 'p%x' in pragmas[4].content
    assert 'p%y' in pragmas[4].content
    assert 'p%z' in pragmas[4].content

    # check for device-side declarations
    pragmas = FindNodes(Pragma).visit(module.spec)
    assert len(pragmas) == 1
    assert pragmas[0].keyword == 'acc'
    assert 'declare create' in pragmas[0].content
    # Parse the declared variable list so each name is matched exactly: a plain
    # substring test for 'p' would also be satisfied by 'p0' or 'p_array'.
    declared = {
        v.strip().lower()
        for v in pragmas[0].content.split('(')[-1].strip()[:-1].split(',')
    }
    # Note: g is not offloaded because it is not used by the kernel (albeit imported)
    assert declared == {'p', 'p0', 'p_array'}


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_parameters', (False, True))
@pytest.mark.parametrize('ignore_modules', (None, ('moduleb',)))
def test_transformation_global_var_hoist(here, config, frontend, hoist_parameters, ignore_modules, tmp_path):
    """
    Test hoisting of global variable imports.

    Globals imported by ``kernel1``..``kernel3`` (and by ``some_func``, called
    from ``kernel0``) are hoisted to imports in ``driver`` and threaded down as
    arguments through ``kernel0``. Modules listed in ``ignore_modules`` stay as
    imports in the kernels. The parameters of ``modulea`` are only hoisted when
    ``hoist_parameters`` is set (and never under OMNI, see below).
    """
    config['default']['enable_imports'] = True
    config['routines'] = {
        'driver': {'role': 'driver'}
    }

    scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend, xmods=[tmp_path])
    scheduler.process(transformation=GlobalVariableAnalysis())
    scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters,
        ignore_modules=ignore_modules))
    scheduler.process(PragmaModelTransformation())

    driver = scheduler['#driver'].ir
    kernel0 = scheduler['#kernel0'].ir
    kernel_map = {key: scheduler[f'#{key}'].ir for key in ['kernel1', 'kernel2', 'kernel3']}
    some_func = scheduler['func_mod#some_func'].ir

    # symbols within each module
    expected_symbols = {'modulea': ['var0', 'var1'], 'moduleb': ['var2', 'var3'],
            'modulec': ['var4', 'var5']}
    # expected intent of those variables (if hoisted); 'tmp' is a kernel-local, so no intent
    var_intent_map = {'var0': 'in', 'var1': 'in', 'var2': 'in',
            'var3': 'in', 'var4': 'inout', 'var5': 'inout', 'tmp': None}
    # DRIVER
    imports = FindNodes(Import).visit(driver.spec)
    import_names = [_import.module.lower() for _import in imports]
    # check driver imports: modulec is always hoisted, moduleb unless it is ignored
    expected_driver_modules = ['modulec']
    expected_driver_modules += ['moduleb'] if ignore_modules is None else []
    # OMNI handles parameters differently, ModuleA only contains parameters
    if frontend != OMNI:
        expected_driver_modules += ['modulea'] if hoist_parameters else []
    assert len(imports) == len(expected_driver_modules)
    assert sorted(expected_driver_modules) == sorted(import_names)
    for _import in imports:
        assert sorted([sym.name for sym in _import.symbols]) == expected_symbols[_import.module.lower()]
    # check driver call: every hoisted symbol is passed to kernel0, in sorted order
    driver_calls = FindNodes(CallStatement).visit(driver.body)
    expected_args = []
    for module in expected_driver_modules:
        expected_args.extend(expected_symbols[module])
    assert [arg.name for arg in driver_calls[0].arguments] == sorted(expected_args)

    # modules each kernel imported from before the transformation (see kernels.F90)
    originally = {'kernel1': ['modulea'], 'kernel2': ['moduleb'],
            'kernel3': ['moduleb', 'modulec']}
    # KERNEL0: receives all hoisted symbols as arguments, plus its own local 'a'
    expected_vars = expected_args.copy()
    expected_vars.append('a')
    assert [arg.name for arg in kernel0.arguments] == sorted(expected_args)
    assert [arg.name for arg in kernel0.variables] == sorted(expected_vars)
    for var in kernel0.arguments:
        assert kernel0.variable_map[var.name.lower()].type.intent == var_intent_map[var.name.lower()]
        assert var.scope == kernel0
    # some_func uses var2 from moduleb, so it only gains an argument when moduleb is hoisted
    kernel0_inline_calls = FindInlineCalls().visit(kernel0.body)
    for inline_call in kernel0_inline_calls:
        if ignore_modules is None:
            assert len(inline_call.arguments) == 1
            assert [arg.name for arg in inline_call.arguments] == ['var2']
            assert [arg.name for arg in some_func.arguments] == ['var2']
        else:
            assert len(inline_call.arguments) == 0
            assert len(some_func.arguments) == 0
    kernel0_calls = FindNodes(CallStatement).visit(kernel0.body)
    # KERNEL1 & KERNEL2 & KERNEL3
    for call in kernel0_calls:
        # Build the expected arguments (hoisted symbols) and remaining imports for this kernel
        expected_args = []
        expected_imports = []
        kernel_expected_symbols = []
        for module in originally[call.routine.name]:
            # always, since at least 'some_func' is imported
            if call.routine.name == 'kernel1' and module == 'modulea':
                expected_imports.append(module)
                kernel_expected_symbols.append('some_func')
            if module in expected_driver_modules:
                expected_args.extend(expected_symbols[module])
            else:
                # modulea was already appended above for kernel1 (it still provides some_func)
                if module != 'modulea':
                    expected_imports.append(module)
                kernel_expected_symbols.extend(expected_symbols[module])
        assert len(expected_args) == len(call.arguments)
        assert [arg.name for arg in call.arguments] == expected_args
        assert [arg.name for arg in kernel_map[call.routine.name].arguments] == expected_args
        for var in kernel_map[call.routine.name].variables:
            var_intent = kernel_map[call.routine.name].variable_map[var.name.lower()].type.intent
            assert var.scope == kernel_map[call.routine.name]
            assert var_intent == var_intent_map[var.name.lower()]
        # kernel1 and kernel2 additionally declare the local 'tmp' ahead of the new arguments
        if call.routine.name in ['kernel1', 'kernel2']:
            expected_args = ['tmp'] + expected_args
        assert [arg.name for arg in kernel_map[call.routine.name].variables] == expected_args
        kernel_imports = FindNodes(Import).visit(call.routine.spec)
        assert sorted([_import.module.lower() for _import in kernel_imports]) == sorted(expected_imports)
        # flattened, lower-cased names of all symbols the kernel still imports
        imported_symbols = []
        for _import in kernel_imports:
            imported_symbols.extend([sym.name.lower() for sym in _import.symbols])
        assert sorted(imported_symbols) == sorted(kernel_expected_symbols)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_parameters', (False, True))
def test_transformation_global_var_derived_type_hoist(here, config, frontend, hoist_parameters, tmp_path):
    """
    Test hoisting of derived-type global variable imports.

    The derived-type globals used by the kernel (``p``, ``p0``, ``p_array``)
    must be imported in the driver and passed to the kernel as arguments with
    matching type information, while the unused ``g`` remains a kernel import.
    """
    config['default']['enable_imports'] = True
    config['routines'] = {'driver_derived_type': {'role': 'driver'}}

    scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend, xmods=[tmp_path])
    scheduler.process(transformation=GlobalVariableAnalysis())
    scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters))
    scheduler.process(PragmaModelTransformation())

    driver = scheduler['#driver_derived_type'].ir
    kernel = scheduler['#kernel_derived_type'].ir
    hoisted = ['p', 'p0', 'p_array']

    # DRIVER: one import of the hoisted globals, one kernel call
    driver_imports = FindNodes(Import).visit(driver.spec)
    assert len(driver_imports) == 1
    assert driver_imports[0].module.lower() == 'module_derived_type'
    assert sorted(s.name.lower() for s in driver_imports[0].symbols) == hoisted
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 1

    # KERNEL: hoisted globals become dummy arguments; only 'g' is still imported
    assert [a.name for a in calls[0].arguments] == hoisted
    assert [a.name for a in kernel.arguments] == hoisted
    kernel_imports = FindNodes(Import).visit(kernel.spec)
    assert len(kernel_imports) == 1
    assert [s.name.lower() for s in kernel_imports[0].symbols] == ['g']
    assert sorted(v.name for v in kernel.variables) == ['i', 'j', 'p', 'p0', 'p_array']

    assert kernel.variable_map['p_array'].type.allocatable
    expected_intents = {'p_array': 'inout', 'p': 'inout', 'p0': 'in'}
    for name, intent in expected_intents.items():
        var_type = kernel.variable_map[name].type
        assert var_type.intent == intent
        assert var_type.dtype.name == 'point'
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/0000775000175000017500000000000015167130205025360 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/0000775000175000017500000000000015167130205031502 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/driver.F900000664000175000017500000000021715167130205033255 0ustar  alastairalastairsubroutine driver()
implicit none

!$loki update_device
!$acc serial
call kernel0()
!$acc end serial
!$loki update_host

end subroutine driver
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/moduleA.F900000664000175000017500000000024215167130205033346 0ustar  alastairalastairmodule moduleA
   real, parameter :: var0 = 0.
   real, parameter :: var1 = 0.
contains
   real function some_func()
   end function some_func
end module moduleA
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/moduleB.F900000664000175000017500000000010215167130205033342 0ustar  alastairalastairmodule moduleB
   real :: var2
   real :: var3
end module moduleB
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/functions.F900000664000175000017500000000026515167130205033775 0ustar  alastairalastairmodule func_mod
implicit none
contains

  real function some_func()
    use moduleB, only: var2
    implicit none
    some_func = var2
  end function some_func

end module func_mod
././@LongLink0000644000000000000000000000015600000000000011605 Lustar  rootrootloki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/module_derived_type.F90loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/module_derived0000664000175000017500000000055315167130205034417 0ustar  alastairalastairmodule module_derived_type

   type point
      integer :: n
      real, allocatable :: x(:)
      real, allocatable :: y(:)
      real, allocatable :: z(:)
   end type point

   type grid
      type(point), allocatable :: p(:)
   end type grid

   type(point) :: p, p0
   type(point), allocatable :: p_array(:)
   type(grid) :: g

end module module_derived_type
././@LongLink0000644000000000000000000000015600000000000011605 Lustar  rootrootloki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/driver_derived_type.F90loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/driver_derived0000664000175000017500000000026515167130205034425 0ustar  alastairalastairsubroutine driver_derived_type()
implicit none

!$loki update_device
!$acc serial
call kernel_derived_type()
!$acc end serial
!$loki update_host

end subroutine driver_derived_type
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/kernels.F900000664000175000017500000000107515167130205033430 0ustar  alastairalastairsubroutine kernel0()
use func_mod, only: some_func
implicit none
  real a
  call kernel1()
  call kernel2()
  call kernel3()
  a = some_func()
end subroutine kernel0

subroutine kernel1()
use moduleA, only: var0,var1,some_func
implicit none
real :: tmp

tmp = var0 + var1 + some_func()

end subroutine kernel1

subroutine kernel2()
use moduleB, only: var2,var3
implicit none
real :: tmp

tmp = var2 + var3

end subroutine kernel2

subroutine kernel3()
use moduleB, only: var2,var3
use moduleC, only: var4,var5
implicit none

var4 = var2
var5 = var3

end subroutine kernel3
././@LongLink0000644000000000000000000000015600000000000011605 Lustar  rootrootloki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/kernel_derived_type.F90loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/kernel_derived0000664000175000017500000000047315167130205034413 0ustar  alastairalastairsubroutine kernel_derived_type()
use module_derived_type, only: p,p0,g,p_array
implicit none
integer :: i,j

do i=1,p%n
  p%x(i) = p0%x(i)
  p%y(i) = p0%y(i)
  p%z(i) = p0%z(i)
  do j=1,p%n
    p_array(i)%x(j) = 1.
    p_array(i)%y(j) = 2.
    p_array(i)%z(j) = 3.
  enddo
enddo

end subroutine kernel_derived_type
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/sources/projGlobalVarImports/moduleC.F900000664000175000017500000000010215167130205033343 0ustar  alastairalastairmodule moduleC
   real :: var4
   real :: var5
end module moduleC
loki-ecmwf-0.3.6/loki/transformations/data_offload/tests/test_offload.py0000664000175000017500000002142515167130205026724 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Sourcefile
from loki.frontend import available_frontends
from loki.logging import log_levels
from loki.ir import (
    FindNodes, Pragma, PragmaRegion, Loop, CallStatement,
    pragma_regions_attached, get_pragma_parameters
)

from loki.transformations import DataOffloadTransformation, PragmaModelTransformation


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('assume_deviceptr', [True, False])
@pytest.mark.parametrize('present_on_device', [True, False])
@pytest.mark.parametrize('asynchronous', [True, False])
def test_data_offload_region_openacc(caplog, frontend, assume_deviceptr, present_on_device,
                                     asynchronous):
    """
    Test the creation of a simple device data offload region
    (`!$acc data`) from a `!$loki data` region with a single
    kernel call.

    Also checks that requesting ``assume_deviceptr`` without
    ``present_on_device`` logs an error and raises ``RuntimeError``.
    """

    fcode_driver = f"""
  SUBROUTINE driver_routine(nlon, nlev, a, b, c)
    INTEGER, INTENT(IN)   :: nlon, nlev
    REAL, INTENT(INOUT)   :: a(nlon,nlev)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,nlev)

    !$loki data {'async(1)' if asynchronous else ''}
    call kernel_routine(nlon, nlev, a, b, c)
    !$loki end data

  END SUBROUTINE driver_routine
"""
    fcode_kernel = """
  SUBROUTINE kernel_routine(nlon, nlev, a, b, c)
    INTEGER, INTENT(IN)   :: nlon, nlev
    REAL, INTENT(IN)      :: a(nlon,nlev)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(OUT)     :: c(nlon,nlev)
    INTEGER :: i, j

    do j=1, nlon
      do i=1, nlev
        b(i,j) = a(i,j) + 0.1
        c(i,j) = 0.1
      end do
    end do
  END SUBROUTINE kernel_routine
"""
    driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine']
    kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine']
    driver.enrich(kernel)

    if assume_deviceptr and not present_on_device:
        # The constructor must log an error and raise. The log checks have to
        # live outside ``pytest.raises``: statements following the raising call
        # inside that block are never executed.
        caplog.clear()
        with caplog.at_level(log_levels['ERROR']):
            with pytest.raises(RuntimeError):
                DataOffloadTransformation(assume_deviceptr=assume_deviceptr, present_on_device=present_on_device)
        assert len(caplog.records) == 1
        message = caplog.records[0].message
        assert "[Loki] Data offload: Can't assume device pointer arrays without arrays being marked" in message
        assert "present on device." in message
        return

    trafos = ()
    trafos += (DataOffloadTransformation(assume_deviceptr=assume_deviceptr,
                                                   present_on_device=present_on_device),)
    trafos += (PragmaModelTransformation(directive='openacc'),)
    for trafo in trafos:
        driver.apply(trafo, role='driver', targets=['kernel_routine'])
    pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(pragmas) == 2
    assert all(p.keyword == 'acc' for p in pragmas)
    if assume_deviceptr:
        assert 'deviceptr' in pragmas[0].content
        params = get_pragma_parameters(pragmas[0], only_loki_pragmas=False)
        assert all(var in params['deviceptr'] for var in ('a', 'b', 'c'))
    elif present_on_device:
        assert 'present' in pragmas[0].content
        params = get_pragma_parameters(pragmas[0], only_loki_pragmas=False)
        assert all(var in params['present'] for var in ('a', 'b', 'c'))
    else:
        transformed = driver.to_fortran()
        assert 'copyin( a )' in transformed
        assert 'copy( b )' in transformed
        assert 'copyout( c )' in transformed
        if asynchronous:
            assert 'async( 1 )' in transformed
    if asynchronous:
        assert 'async' in pragmas[0].content
        async_param = get_pragma_parameters(pragmas[0], only_loki_pragmas=False)['async']
        assert async_param == '1', 'async parameter should be 1.'


@pytest.mark.parametrize('frontend', available_frontends())
def test_data_offload_region_complex_remove_openmp(frontend):
    """
    Test the creation of a data offload region (OpenACC) with
    driver-side loops and CPU-style OpenMP pragmas that are removed
    because ``remove_openmp=True`` is requested.
    """

    fcode_driver = """
  SUBROUTINE driver_routine(nlon, nlev, a, b, c, flag)
    INTEGER, INTENT(IN)   :: nlon, nlev
    REAL, INTENT(INOUT)   :: a(nlon,nlev)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,nlev)
    logical, intent(in) :: flag
    INTEGER :: j

    !$loki data
    call my_custom_timer()

    if(flag)then
       !$omp parallel do private(j)
       do j=1, nlev
         call kernel_routine(nlon, j, a(:,j), b(:,j), c(:,j))
       end do
       !$omp end parallel do
    else
       !$omp parallel do private(j)
       do j=1, nlev
          a(:,j) = 0.
          b(:,j) = 0.
          c(:,j) = 0.
       end do
       !$omp end parallel do
    endif
    call my_custom_timer()

    !$loki end data
  END SUBROUTINE driver_routine
"""
    # Note: the kernel loops over the local ``i``; looping over the
    # INTENT(IN) dummy ``j`` (as before) is invalid Fortran.
    fcode_kernel = """
  SUBROUTINE kernel_routine(nlon, j, a, b, c)
    INTEGER, INTENT(IN)   :: nlon, j
    REAL, INTENT(IN)      :: a(nlon)
    REAL, INTENT(INOUT)   :: b(nlon)
    REAL, INTENT(INOUT)   :: c(nlon)
    INTEGER :: i

    do i=1, nlon
      b(i) = a(i) + 0.1
      c(i) = 0.1
    end do
  END SUBROUTINE kernel_routine
"""
    driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine']
    kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine']
    driver.enrich(kernel)

    trafos = ()
    trafos += (DataOffloadTransformation(remove_openmp=True),)
    trafos += (PragmaModelTransformation(directive='openacc'),)
    for trafo in trafos:
        driver.apply(trafo, role='driver', targets=['kernel_routine'])

    assert len(FindNodes(Pragma).visit(driver.body)) == 2
    assert all(p.keyword == 'acc' for p in FindNodes(Pragma).visit(driver.body))

    with pragma_regions_attached(driver):
        # Ensure that loops in the region are preserved
        regions = FindNodes(PragmaRegion).visit(driver.body)
        assert len(regions) == 1
        assert len(FindNodes(Loop).visit(regions[0])) == 2

        # Ensure all active and inactive calls are there
        calls = FindNodes(CallStatement).visit(regions[0])
        assert len(calls) == 3
        assert calls[0].name == 'my_custom_timer'
        assert calls[1].name == 'kernel_routine'
        assert calls[2].name == 'my_custom_timer'

        # Ensure OpenMP loop pragma is taken out
        assert len(FindNodes(Pragma).visit(regions[0])) == 0

    transformed = driver.to_fortran()
    assert 'copyin( a )' in transformed
    assert 'copy( b, c )' in transformed
    assert '!$omp' not in transformed


@pytest.mark.parametrize('frontend', available_frontends())
def test_data_offload_region_multiple(frontend):
    """
    Check that a single ``!$loki data`` region enclosing several kernel
    calls becomes one OpenACC data region whose clauses merge the
    argument intents of all enclosed calls.
    """

    fcode_driver = """
  SUBROUTINE driver_routine(nlon, nlev, a, b, c, d)
    INTEGER, INTENT(IN)   :: nlon, nlev
    REAL, INTENT(INOUT)   :: a(nlon,nlev)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(INOUT)   :: c(nlon,nlev)
    REAL, INTENT(INOUT)   :: d(nlon,nlev)

    !$loki data
    call kernel_routine(nlon, nlev, a, b, c)

    call kernel_routine(nlon, nlev, d, b, a)
    !$loki end data

  END SUBROUTINE driver_routine
"""
    fcode_kernel = """
  SUBROUTINE kernel_routine(nlon, nlev, a, b, c)
    INTEGER, INTENT(IN)   :: nlon, nlev
    REAL, INTENT(IN)      :: a(nlon,nlev)
    REAL, INTENT(INOUT)   :: b(nlon,nlev)
    REAL, INTENT(OUT)     :: c(nlon,nlev)
    INTEGER :: i, j

    do j=1, nlon
      do i=1, nlev
        b(i,j) = a(i,j) + 0.1
        c(i,j) = 0.1
      end do
    end do
  END SUBROUTINE kernel_routine
"""
    driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine']
    kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine']
    driver.enrich(kernel)

    # Generic offload pragmas first, then lowering to OpenACC
    pipeline = (
        DataOffloadTransformation(),
        PragmaModelTransformation(directive='openacc'),
    )
    for transformation in pipeline:
        driver.apply(transformation, role='driver', targets=['kernel_routine'])

    acc_pragmas = FindNodes(Pragma).visit(driver.body)
    assert len(acc_pragmas) == 2
    assert all(pragma.keyword == 'acc' for pragma in acc_pragmas)

    # "a" is only read by the first call but written by the second (as "c"),
    # so the union of both accesses must make it a "copy" variable
    fortran_out = driver.to_fortran()
    for clause in ('copyin( d )', 'copy( b, a )', 'copyout( c )'):
        assert clause in fortran_out
loki-ecmwf-0.3.6/loki/transformations/data_offload/offload.py0000664000175000017500000002241515167130205024523 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.expression import Array
from loki.ir import (
    FindNodes, PragmaRegion, CallStatement, Pragma, Transformer,
    pragma_regions_attached, get_pragma_parameters
)
from loki.logging import warning, error
from loki.tools import as_tuple
from loki.types import BasicType


__all__ = ['DataOffloadTransformation']


class DataOffloadTransformation(Transformation):
    """
    Utility transformation to insert data offload regions for GPU devices
    based on marked ``!$loki data`` regions.

    In driver routines, every active ``!$loki data`` region (one that contains
    at least one call to a routine in ``targets``) is replaced by a generic
    ``!$loki structured-data`` region, or by a ``!$loki device-ptr`` region if
    ``assume_deviceptr`` is set. The new region lists the array arguments of
    the targeted calls, classified by their declared intent in the callee.
    These generic pragmas can then be lowered to a concrete offload model,
    e.g. OpenACC via :any:`PragmaModelTransformation`.

    Parameters
    ----------
    remove_openmp : bool
        Remove any existing OpenMP pragmas inside the marked region.
        Default: ``False``.
    present_on_device : bool
        Assume arrays are already offloaded and present on device.
        Default: ``False``.
    assume_deviceptr : bool
        Mark all offloaded arrays as true device-pointers if data offload
        is being managed outside of structured OpenACC data regions.
        Requires ``present_on_device``. Default: ``False``.

    Raises
    ------
    RuntimeError
        If ``assume_deviceptr`` is requested without ``present_on_device``.
    """

    def __init__(self, **kwargs):
        # We need to record if we actually added any, so
        # that down-stream processing can use that info
        self.has_data_regions = False
        self.remove_openmp = kwargs.get('remove_openmp', False)
        self.assume_deviceptr = kwargs.get('assume_deviceptr', False)
        self.present_on_device = kwargs.get('present_on_device', False)

        if self.assume_deviceptr and not self.present_on_device:
            msg = ("[Loki] Data offload: Can't assume device pointer arrays without arrays being marked "
                   "present on device.")
            error(msg)
            raise RuntimeError(msg)

    def transform_subroutine(self, routine, **kwargs):
        """
        Apply the transformation to a `Subroutine` object.

        Only routines with ``role == 'driver'`` are modified. OpenMP pragmas
        inside active data regions are stripped only if ``remove_openmp``
        was requested.

        Parameters
        ----------
        routine : `Subroutine`
            Subroutine to apply this transformation to.
        role : string
            Role of the `routine` in the scheduler call tree.
            This transformation will only apply at the ``'driver'`` level.
        targets : list or string
            List of subroutines that are to be considered as part of
            the transformation call tree.
        """
        role = kwargs.get('role')
        # Normalise targets to lower-case names for case-insensitive matching
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))

        if role == 'driver':
            if self.remove_openmp:
                self.remove_openmp_pragmas(routine, targets)
            self.insert_data_offload_pragmas(routine, targets)

    @staticmethod
    def _find_target_calls(region, targets):
        """Return the calls inside ``region`` whose lower-cased name is in ``targets``."""
        return [c for c in FindNodes(CallStatement).visit(region) if str(c.name).lower() in targets]

    @classmethod
    def _is_active_loki_data_region(cls, region, targets):
        """
        Utility to decide if a ``PragmaRegion`` is of type ``!$loki data``
        and has active target routines.
        """
        if region.pragma.keyword.lower() != 'loki':
            return False

        # Compare the directive name exactly: a substring test would also
        # accept e.g. ``!$loki structured-data`` regions
        tokens = (region.pragma.content or '').lower().split()
        if not tokens or tokens[0] != 'data':
            return False

        return len(cls._find_target_calls(region, targets)) > 0

    @staticmethod
    def _collect_offload_args(routine, calls):
        """
        Classify array arguments of ``calls`` by the intent of the matching
        dummy argument and return ``(inargs, inoutargs, outargs)`` as tuples
        of unique, lower-cased variable names.

        A variable that appears with more than one intent across the calls
        is listed as ``inout`` only.
        """
        args_by_intent = {'in': (), 'inout': (), 'out': ()}

        for call in calls:
            if call.routine is BasicType.DEFERRED:
                warning(f'[Loki] Data offload: Routine {str(call.name).lower()} has not been enriched ' +
                        f'in {routine.name}')
                continue

            for param, arg in call.arg_iter():
                if not isinstance(param, Array):
                    continue
                # Dummy arguments without a declared intent are not classified
                intent = str(param.type.intent or '').lower()
                if intent not in args_by_intent:
                    continue
                # Non-variable expressions (e.g. literals) have no storage to offload
                arg_name = getattr(arg, 'name', None)
                if arg_name is None:
                    continue
                args_by_intent[intent] += (str(arg_name).lower(),)

        inargs = args_by_intent['in']
        inoutargs = args_by_intent['inout']
        outargs = args_by_intent['out']

        # Sanitize data access categories to avoid double-counting variables
        inoutargs += tuple(v for v in inargs if v in outargs)
        inargs = tuple(v for v in inargs if v not in inoutargs)
        outargs = tuple(v for v in outargs if v not in inoutargs)

        # Filter for duplicates while preserving first-occurrence order
        return (
            tuple(dict.fromkeys(inargs)),
            tuple(dict.fromkeys(inoutargs)),
            tuple(dict.fromkeys(outargs)),
        )

    def insert_data_offload_pragmas(self, routine, targets):
        """
        Find active ``!$loki data`` pragma regions and replace their opening
        and closing pragmas with ``!$loki structured-data`` (or
        ``!$loki device-ptr``) pragmas.

        Without ``present_on_device`` the array arguments are listed in
        ``in(...)``, ``inout(...)`` and ``out(...)`` clauses. With
        ``present_on_device`` they are all listed in a single ``present(...)``
        clause, or in a ``vars(...)`` clause of a ``device-ptr`` pragma when
        ``assume_deviceptr`` is set. An ``async`` parameter on the original
        ``!$loki data`` pragma is carried over.

        Parameters
        ----------
        routine : `Subroutine`
            Subroutine to apply this transformation to.
        targets : list or string
            List of subroutines that are to be considered as part of
            the transformation call tree.
        """
        pragma_map = {}
        with pragma_regions_attached(routine):
            for region in FindNodes(PragmaRegion).visit(routine.body):
                # Only work on active `!$loki data` regions
                if not self._is_active_loki_data_region(region, targets):
                    continue

                calls = self._find_target_calls(region, targets)
                inargs, inoutargs, outargs = self._collect_offload_args(routine, calls)

                # Now generate the pre- and post pragmas
                if self.present_on_device:
                    offload_args = inargs + outargs + inoutargs
                    if self.assume_deviceptr:
                        deviceptr = f' vars({", ".join(offload_args)})' if offload_args else ''
                        pragma = Pragma(keyword='loki', content=f'device-ptr{deviceptr}')
                        pragma_post = Pragma(keyword='loki', content='end device-ptr')
                    else:
                        present = f' present({", ".join(offload_args)})' if offload_args else ''
                        pragma = Pragma(keyword='loki', content=f'structured-data {present}')
                        pragma_post = Pragma(keyword='loki', content='end structured-data')
                else:
                    copyin = f'in({", ".join(inargs)})' if inargs else ''
                    copy = f'inout({", ".join(inoutargs)})' if inoutargs else ''
                    copyout = f'out({", ".join(outargs)})' if outargs else ''
                    pragma = Pragma(keyword='loki', content=f'structured-data {copyin} {copy} {copyout}')
                    pragma_post = Pragma(keyword='loki', content='end structured-data')

                # Add async if present
                async_parameter = get_pragma_parameters(region.pragma).get('async', '')
                if async_parameter:
                    pragma = Pragma(keyword='loki', content=pragma.content+f' async({async_parameter})')

                # Add pragmas to map
                pragma_map[region.pragma] = pragma
                pragma_map[region.pragma_post] = pragma_post

                # Record that we actually created a new region
                self.has_data_regions = True

        routine.body = Transformer(pragma_map).visit(routine.body)

    def remove_openmp_pragmas(self, routine, targets):
        """
        Remove any existing OpenMP pragmas inside active ``!$loki data``
        regions, which will have been intended for OpenMP threading rather
        than offload. OpenMP pragmas outside those regions are kept.

        Parameters
        ----------
        routine : `Subroutine`
            Subroutine to apply this transformation to.
        targets : list or string
            List of subroutines that are to be considered as part of
            the transformation call tree.
        """
        pragma_map = {}
        with pragma_regions_attached(routine):
            for region in FindNodes(PragmaRegion).visit(routine.body):
                # Only work on active `!$loki data` regions
                if not self._is_active_loki_data_region(region, targets):
                    continue

                # Standalone OpenMP pragmas within this region
                for p in FindNodes(Pragma).visit(region):
                    if p.keyword.lower() == 'omp':
                        pragma_map[p] = None
                # OpenMP pragma pairs that were attached as nested regions
                for r in FindNodes(PragmaRegion).visit(region):
                    if r.pragma.keyword.lower() == 'omp':
                        pragma_map[r.pragma] = None
                        pragma_map[r.pragma_post] = None

        routine.body = Transformer(pragma_map).visit(routine.body)
loki-ecmwf-0.3.6/loki/transformations/data_offload/global_var.py0000664000175000017500000007441015167130205025223 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import defaultdict
from itertools import chain

from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation, ProcedureItem, ModuleItem
from loki.expression import Scalar, Array
from loki.ir import (
    FindNodes, CallStatement, Pragma, Import, Comment,
    Transformer, get_pragma_parameters,
    FindInlineCalls, SubstituteExpressions
)
from loki.logging import warning
from loki.tools import (
    as_tuple, flatten, CaseInsensitiveDict, CaseInsensitiveDefaultDict,
    OrderedSet
)
from loki.types import DerivedType


__all__ = [
    'GlobalVariableAnalysis', 'GlobalVarOffloadTransformation',
    'GlobalVarHoistTransformation',
]


class GlobalVariableAnalysis(Transformation):
    """
    Transformation pass to analyse the declaration and use of (global) module variables.

    This analysis is a requirement before applying :any:`GlobalVarOffloadTransformation`.

    Results are stored in :any:`Item.trafo_data` of every :any:`ProcedureItem`
    and :any:`ModuleItem` under the provided :data:`key`
    (default: ``'GlobalVariableAnalysis'``).

    For procedures, the Loki dataflow analysis is used to compile the imported
    variables that are used and/or defined (i.e., read and/or written). These
    are stored under the keys ``'uses_symbols'`` and ``'defines_symbols'``.
    Each entry pairs the module-side variable symbol with the lower-cased name
    of the declaring module. The corresponding sets of :any:`ProcedureItem`
    successors are merged in as well.

    For modules (:any:`ModuleItem`), the non-parameter variables declared in
    the module are stored under the key ``'declares'``. An initially empty set
    is stored under ``'offload'`` for the subset of variables that need
    offloading to device.

    In every case the full variable symbols are stored, so that transformations
    consuming the analysis data have access to type information.

    The generated trafo_data has the following schema::

        ModuleItem: {
            'declares': OrderedSet(Variable, Variable, ...),
            'offload': OrderedSet(Variable, ...)
        }

        ProcedureItem: {
            'uses_symbols': OrderedSet( (Variable, 'module_name'), ...),
            'defines_symbols': OrderedSet((Variable, 'module_name'), ...)
        }

    Parameters
    ----------
    key : str, optional
        Specify a different identifier under which trafo_data is stored
    """

    #: Default identifier for trafo_data entry
    _key = 'GlobalVariableAnalysis'

    #: Traversal from the leaves upwards, i.e., modules with global variables are
    #: processed first, then kernels using them before the driver.
    reverse_traversal = True

    #: Process procedures and modules with global variable declarations.
    item_filter = (ProcedureItem, ModuleItem)

    def __init__(self, key=None):
        # An empty/None key keeps the class-level default identifier
        if key:
            self._key = key

    def transform_module(self, module, **kwargs):
        """
        Record all non-parameter variables declared in ``module``.

        Raises
        ------
        RuntimeError
            If no ``item`` is given to store the analysis data in.
        """
        if 'item' not in kwargs:
            raise RuntimeError('Cannot apply GlobalVariableAnalysis without item to store analysis data')

        declared = OrderedSet(v for v in module.variables if not v.type.parameter)
        kwargs['item'].trafo_data[self._key] = {
            'declares': declared,
            'offload': OrderedSet()
        }

    def transform_subroutine(self, routine, **kwargs):
        """
        Record the imported module variables read and written in ``routine``,
        merged with the corresponding data of procedure successors.

        Raises
        ------
        RuntimeError
            If ``item`` or ``sub_sgraph`` is missing from the keyword arguments.
        """
        if 'item' not in kwargs:
            raise RuntimeError('Cannot apply GlobalVariableAnalysis without item to store analysis data')
        if 'sub_sgraph' not in kwargs:
            raise RuntimeError(('Cannot apply GlobalVariableAnalysis without information'
                ' about successors to store offload analysis data'))

        item = kwargs['item']
        sub_sgraph = kwargs['sub_sgraph']
        successors = () if sub_sgraph is None else sub_sgraph.successors(item)

        # Imports visible in this routine, including those of enclosing scopes
        import_map = CaseInsensitiveDict()
        current = routine
        while current:
            import_map.update(current.import_map)
            current = current.parent

        def _is_imported(var):
            # Either the symbol itself or, for derived-type members, its root
            # parent was brought into scope by an import
            return var.name in import_map or (var.parent and var.parents[0].name in import_map)

        with dataflow_analysis_attached(routine):
            used = OrderedSet(v for v in routine.body.uses_symbols if _is_imported(v))
            used |= OrderedSet(v for v in routine.spec.uses_symbols if _is_imported(v))
            defined = OrderedSet(v for v in routine.body.defines_symbols if _is_imported(v))

            # Keep only data symbols, dropping imported types and procedures
            data_kinds = (Scalar, Array)
            used = OrderedSet(v for v in used if isinstance(v, data_kinds))
            defined = OrderedSet(v for v in defined if isinstance(v, data_kinds))

            def _to_module_symbol(var):
                if not var.parent:
                    owner = var.type.module
                    return (owner.variable_map[var.name], owner.name.lower())

                # Derived-type member: rebuild the access chain on top of the
                # module-side declaration of the root variable
                root = var.parents[0]
                owner = root.type.module
                module_var = owner.variable_map[root.name]
                dimensions = getattr(module_var, 'dimensions', None)
                for child in chain(var.parents[1:], (var,)):
                    module_var = child.clone(
                        name=f'{module_var.name}%{child.name}',
                        parent=module_var,
                        scope=module_var.scope
                    )
                return (module_var.clone(dimensions=dimensions), owner.name.lower())

            item.trafo_data[self._key] = {}
            analysis = item.trafo_data[self._key]
            analysis['uses_symbols'] = OrderedSet(_to_module_symbol(v) for v in used)
            analysis['defines_symbols'] = OrderedSet(_to_module_symbol(v) for v in defined)

        # Amend analysis data with data from successors
        # Note: This is a temporary workaround for the incomplete list of successor items
        # provided by the current scheduler implementation
        for successor in successors:
            if not isinstance(successor, ProcedureItem):
                continue
            successor_data = successor.trafo_data[self._key]
            analysis['uses_symbols'] |= successor_data['uses_symbols']
            analysis['defines_symbols'] |= successor_data['defines_symbols']


class GlobalVarOffloadTransformation(Transformation):
    """
    Transformation to insert offload directives for module variables used in device routines

    Currently, only OpenACC data offloading is supported.

    This requires a prior analysis pass with :any:`GlobalVariableAnalysis` to collect
    the relevant global variable use information.

    The offload directives are inserted by replacing ``!$loki update_device`` and
    ``!$loki update_host`` pragmas in the driver's source code. Importantly, no offload
    directives are added if these pragmas have not been added to the original source code!

    For global variables, the device-side declarations are added in :meth:`transform_module`.
    For driver procedures, the data offload and pull-back directives are added in
    the utility method :meth:`process_driver`, which is invoked by :meth:`transform_subroutine`.

    For example, the following code:

    .. code-block:: fortran

        module moduleB
           real :: var2
           real :: var3
        end module moduleB

        module moduleC
           real :: var4
           real :: var5
        end module moduleC

        subroutine driver()
        implicit none

        !$loki update_device
        !$acc serial
        call kernel()
        !$acc end serial
        !$loki update_host

        end subroutine driver

        subroutine kernel()
        use moduleB, only: var2,var3
        use moduleC, only: var4,var5
        implicit none
        !$acc routine seq

        var4 = var2
        var5 = var3

        end subroutine kernel

    is transformed to:

    .. code-block:: fortran

        module moduleB
           real :: var2
           real :: var3
          !$acc declare create(var2)
          !$acc declare create(var3)
        end module moduleB

        module moduleC
           real :: var4
           real :: var5
          !$acc declare create(var4)
          !$acc declare create(var5)
        end module moduleC

        subroutine driver()
        implicit none

        !$acc update device( var2,var3 )
        !$acc serial
        call kernel()
        !$acc end serial
        !$acc update self( var4,var5 )

        end subroutine driver

    Nested Fortran derived-types and arrays of derived-types are not currently supported.
    If such an import is encountered, only the device-side declaration will be added to the
    relevant module file, and the offload instructions will have to manually be added afterwards.

    Parameters
    ----------
    key : str, optional
        Overwrite the key that is used to store analysis results in ``trafo_data``
        (default: the key used by :any:`GlobalVariableAnalysis`).
    """

    # Include module variable imports in the underlying graph
    # connectivity for traversal with the Scheduler
    item_filter = (ProcedureItem, ModuleItem)

    def __init__(self, key=None):
        self._key = key or GlobalVariableAnalysis._key

    def transform_module(self, module, **kwargs):
        """
        Add device-side declarations for imported variables

        The variables collected under ``'offload'`` in the module item's ``trafo_data``
        (filled by :meth:`process_kernel`) receive a ``!$loki create device(...)`` pragma,
        unless they are parameters or are already listed in an existing
        ``!$acc declare create(...)`` directive of the module.

        Raises
        ------
        RuntimeError
            If no ``item`` is passed, or if the offload analysis lists a variable
            that is not declared in :data:`module`.
        """
        if 'item' not in kwargs:
            raise RuntimeError('Cannot apply GlobalVarOffloadTransformation without trafo_data in item')

        item = kwargs['item']

        # Check for already declared offloads
        acc_pragmas = [pragma for pragma in FindNodes(Pragma).visit(module.spec) if pragma.keyword.lower() == 'acc']
        acc_pragma_parameters = get_pragma_parameters(acc_pragmas, starts_with='declare', only_loki_pragmas=False)
        # A ``create`` clause holds a comma-separated list, e.g. ``create(var2, var3)``:
        # strip whitespace and split on commas to obtain the individual variable names
        declared_variables = OrderedSet(flatten([
            v.replace(' ','').lower().split(',')
            for v in as_tuple(acc_pragma_parameters.get('create'))
        ]))

        # Build list of symbols to be offloaded (discard variables being parameter)
        offload_variables = OrderedSet(
            var.parents[0] if var.parent else var
            for var in item.trafo_data[self._key].get('offload', ()) if not var.type.parameter
        )

        if (invalid_vars := offload_variables - OrderedSet(module.variables)):
            raise RuntimeError(f'Invalid variables in offload analysis: {", ".join(v.name for v in invalid_vars)}')

        # Add ACC declare pragma for offload variables that are not yet declared
        offload_variables = offload_variables - declared_variables
        if offload_variables:
            module.spec.append(
                Pragma(keyword='loki', content=f'create device({", ".join(v.name for v in offload_variables)})')
            )

    def transform_subroutine(self, routine, **kwargs):
        """
        Add data offload and pull-back directives to drivers and propagate
        offload requirements from kernels to the modules of their global variables.

        Drivers are handled by :meth:`process_driver`, kernels by :meth:`process_kernel`;
        other roles are left untouched.
        """
        role = kwargs.get('role')
        item = kwargs['item']
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = sub_sgraph.successors(item) if sub_sgraph is not None else ()

        if role == 'driver':
            self.process_driver(routine, successors)
        elif role == 'kernel':
            self.process_kernel(item, successors)

    def process_kernel(self, item, successors):
        """
        Propagate offload requirement to the items of the global variables

        Every non-parameter global variable used or defined by the kernel is added to
        the ``'offload'`` set of the matching :any:`ModuleItem` successor, which is later
        consumed by :meth:`transform_module`.
        """
        successors_map = CaseInsensitiveDict(
            (item.name, item) for item in successors if isinstance(item, ModuleItem)
        )
        for var, module in chain(
            item.trafo_data[self._key]['uses_symbols'],
            item.trafo_data[self._key]['defines_symbols']
        ):
            if var.type.parameter:
                continue
            if successor := successors_map.get(module):
                # Create the container on demand, in case the analysis did not initialise it
                module_data = successor.trafo_data.setdefault(self._key, {})
                module_data.setdefault('offload', OrderedSet()).add(var)

    def process_driver(self, routine, successors):
        """
        Add data offload and pullback directives

        List of variables that requires offloading is obtained from the analysis data
        stored for each successor in :data:`successors`.

        ``!$loki update_device`` pragmas in the driver body are replaced by host-to-device
        transfer pragmas for all variables read in the kernels, ``!$loki update_host``
        pragmas by device-to-host transfers for all variables written in the kernels.
        Missing imports for the transferred variables are added to the driver.
        """
        # Empty lists for update directives
        update_device = ()
        update_host = ()

        # Combine analysis data across successor items
        defines_symbols = OrderedSet()
        uses_symbols = OrderedSet()
        for item in successors:
            defines_symbols |= item.trafo_data.get(self._key, {}).get('defines_symbols', OrderedSet())
            uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_symbols', OrderedSet())

        # Discard variables being parameter (compile-time constants need no data transfer)
        uses_symbols = OrderedSet(
            (var, module) for var, module in uses_symbols if not var.type.parameter
        )

        # Filter out arrays of derived types and nested derived types
        # For these, automatic offloading is currently not supported
        exclude_symbols = OrderedSet()
        for var_, module in chain(defines_symbols, uses_symbols):
            var = var_.parents[0] if var_.parent else var_
            if not isinstance(var.type.dtype, DerivedType):
                continue
            if isinstance(var, Array):
                exclude_symbols.add(var)
                warning((
                    '[Loki::GlobalVarOffloadTransformation] '
                    f'Automatic offloading of derived type arrays not implemented: {var} in {routine.name}'
                ))
            if any(isinstance(v.type.dtype, DerivedType) for v in var.type.dtype.typedef.variables):
                exclude_symbols.add(var)
                warning((
                    '[Loki::GlobalVarOffloadTransformation] '
                    f'Automatic offloading of nested derived types not implemented: {var} in {routine.name}'
                ))

        uses_symbols = OrderedSet(
            (var, module) for var, module in uses_symbols
            if var not in exclude_symbols and not (var.parent and var.parents[0] in exclude_symbols)
        )
        defines_symbols = OrderedSet(
            (var, module) for var, module in defines_symbols
            if var not in exclude_symbols and not (var.parent and var.parents[0] in exclude_symbols)
        )

        # All variables that are used in a kernel need a host-to-device transfer
        if uses_symbols:
            update_variables = OrderedSet(
                v for v, _ in uses_symbols
                if not (v.parent or isinstance(v.type.dtype, DerivedType))
            )
            copyin_variables = OrderedSet(v for v, _ in uses_symbols if v.parent)
            if update_variables:
                update_device += (
                    Pragma(keyword='loki', content=f'update device({", ".join(v.name for v in update_variables)})'),
                    # this shouldn't be necessary but is currently necessary because of a bug in OpenMP
                    Pragma(keyword='loki',
                        content=f'omp-update-global-vars in({", ".join(v.name for v in update_variables)})'),
                )
            if copyin_variables:
                content = f'unstructured-data in({", ".join(v.name for v in copyin_variables)})'
                update_device += (
                    Pragma(keyword='loki', content=content),
                )

        # All variables that are written in a kernel need a device-to-host transfer
        if defines_symbols:
            update_variables = OrderedSet(v for v, _ in defines_symbols if not v.parent)
            copyout_variables = OrderedSet(v for v, _ in defines_symbols if v.parent)
            # ``uses_symbols`` holds (variable, module) pairs: compare against the variables
            # only, so that components already copied in are not additionally created
            used_variables = OrderedSet(v for v, _ in uses_symbols)
            create_variables = OrderedSet(
                v for v in copyout_variables
                if v not in used_variables and v.type.allocatable
            )
            if update_variables:
                update_host += (
                    Pragma(keyword='loki', content=f'update host({", ".join(v.name for v in update_variables)})'),
                )
            if copyout_variables:
                update_host += (
                    Pragma(keyword='loki',
                        content=f'exit unstructured-data out({", ".join(v.name for v in copyout_variables)})'),
                )
            if create_variables:
                update_device += (
                    Pragma(keyword='loki',
                        content=f'unstructured-data create({", ".join(v.name for v in create_variables)})'),
                )

        # Replace Loki pragmas with acc data/update pragmas
        pragma_map = {}
        for pragma in FindNodes(Pragma).visit(routine.body):
            if pragma.keyword.lower() == 'loki':
                # A bare ``!$loki`` pragma may carry no content at all
                pragma_content = pragma.content or ''
                if 'update_device' in pragma_content:
                    pragma_map[pragma] = update_device or None
                if 'update_host' in pragma_content:
                    pragma_map[pragma] = update_host or None

        routine.body = Transformer(pragma_map).visit(routine.body)

        # Add imports for offload variables
        offload_map = defaultdict(OrderedSet)
        for var, module in chain(uses_symbols, defines_symbols):
            offload_map[module].add(var.parents[0] if var.parent else var)

        # Collect imports visible in the routine and all enclosing scopes
        import_map = CaseInsensitiveDict()
        scope = routine
        while scope:
            import_map.update(scope.import_map)
            scope = scope.parent

        missing_imports_map = defaultdict(OrderedSet)
        for module, variables in offload_map.items():
            missing_imports_map[module] |= OrderedSet(var for var in variables if var.name not in import_map)

        if missing_imports_map:
            routine.spec.prepend(Comment(text=(
                '![Loki::GlobalVarOffloadTransformation] ---------------------------------------'
            )))
            for module, variables in missing_imports_map.items():
                symbols = tuple(var.clone(dimensions=None, scope=routine) for var in variables)
                routine.spec.prepend(Import(module=module, symbols=symbols))

            routine.spec.prepend(Comment(text=(
                '![Loki::GlobalVarOffloadTransformation] '
                '-------- Added global variable imports for offload directives -----------'
            )))


class GlobalVarHoistTransformation(Transformation):
    """
    Transformation to hoist module variables used in device routines

    This requires a prior analysis pass with :any:`GlobalVariableAnalysis` to collect
    the relevant global variable use information.

    Modules to be ignored can be specified. Further, it is possible to
    configure whether parameters/compile time constants are hoisted as well
    or not.

    .. note::
      Hoisted variables that could theoretically be ``intent(out)``
      are nevertheless declared as ``intent(inout)``.

    For example, the following code:

    .. code-block:: fortran

        module moduleB
           real :: var2
           real :: var3
        end module moduleB

        module moduleC
           real :: var4
           real :: var5
        end module moduleC

        subroutine driver()
        implicit none

        call kernel()

        end subroutine driver

        subroutine kernel()
        use moduleB, only: var2,var3
        use moduleC, only: var4,var5
        implicit none

        var4 = var2
        var5 = var3

        end subroutine kernel

    is transformed to:

    .. code-block:: fortran

        module moduleB
           real :: var2
           real :: var3
        end module moduleB

        module moduleC
           real :: var4
           real :: var5
        end module moduleC

        subroutine driver()
        use moduleB, only: var2,var3
        use moduleC, only: var4,var5
        implicit none

        call kernel(var2, var3, var4, var5)

        end subroutine driver

        subroutine kernel(var2, var3, var4, var5)
        implicit none
        real, intent(in) :: var2
        real, intent(in) :: var3
        real, intent(inout) :: var4
        real, intent(inout) :: var5

        var4 = var2
        var5 = var3

        end subroutine kernel

    Parameters
    ----------
    hoist_parameters : bool, optional
        Whether or not to hoist module variables being parameter/compile
        time constants (default: `False`).
    ignore_modules : (list, tuple) of str
        Modules to be ignored (default: `None`, thus no module to be ignored).
    key : str, optional
        Overwrite the key that is used to store analysis results in ``trafo_data``.
    """
    item_filter = ProcedureItem

    def __init__(self, hoist_parameters=False, ignore_modules=None, key=None):
        self._key = key or GlobalVariableAnalysis._key
        self.hoist_parameters = hoist_parameters
        self.ignore_modules = [module.lower() for module in as_tuple(ignore_modules)]

    def transform_subroutine(self, routine, **kwargs):
        """
        Hoist module variables.

        Drivers are handled by :meth:`process_driver`, kernels by :meth:`process_kernel`;
        other roles are left untouched.
        """
        role = kwargs.get('role')
        item = kwargs.get('item', None)
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = sub_sgraph.successors(item) if sub_sgraph is not None else ()

        if role == 'driver':
            self.process_driver(routine, successors)
        elif role == 'kernel':
            self.process_kernel(routine, successors, item)

    def process_driver(self, routine, successors):
        """
        Hoist module variables for driver routines.

        This includes: appending the corresponding variables
        to calls within the driver and adding the relevant
        imports.
        """
        # get symbols per routine (successors)
        defines_symbols, uses_symbols = self._get_symbols(successors)

        # append symbols to calls (arguments)
        self._append_call_arguments(routine, uses_symbols, defines_symbols)

        # combine/collect symbols disregarding routine
        all_defines_symbols = OrderedSet.union(*defines_symbols.values(), OrderedSet())
        all_uses_symbols = OrderedSet.union(*uses_symbols.values(), OrderedSet())
        # add imports for symbols hoisted
        symbol_map = defaultdict(OrderedSet)
        for var, module in chain(all_uses_symbols, all_defines_symbols):
            # filter modules that are supposed to be ignored
            if module.lower() in self.ignore_modules:
                continue
            symbol_map[module].add(var.parents[0] if var.parent else var)
        # Collect imports visible in the routine and all enclosing scopes
        import_map = CaseInsensitiveDict()
        scope = routine
        while scope:
            import_map.update(scope.import_map)
            scope = scope.parent
        missing_imports_map = defaultdict(OrderedSet)
        for module, variables in symbol_map.items():
            missing_imports_map[module] |= OrderedSet(var for var in variables if var.name not in import_map)
        if missing_imports_map:
            routine.spec.prepend(Comment(text=(
                '![Loki::GlobalVarHoistTransformation] ---------------------------------------'
            )))
            for module, variables in missing_imports_map.items():
                symbols = tuple(var.clone(dimensions=None, scope=routine) for var in variables)
                routine.spec.prepend(Import(module=module, symbols=symbols))

            routine.spec.prepend(Comment(text=(
                '![Loki::GlobalVarHoistTransformation] '
                '-------- Added global variable imports for offload directives -----------'
            )))

    def process_kernel(self, routine, successors, item):
        """
        Hoist module variables for kernel routines.

        This includes: appending the corresponding variables
        to the routine arguments as well as to calls within the kernel
        and removing the imports that became unused.

        Only imports from modules that actually provided hoisted variables are
        rewritten; all other imports (including symbol-less ones such as C-style
        ``#include`` headers) are left untouched.
        """
        # get symbols per routine (successors)
        defines_symbols, uses_symbols = self._get_symbols(successors)

        # append symbols to routine (arguments)
        self._append_routine_arguments(routine, item)

        # append symbols to calls (arguments)
        self._append_call_arguments(routine, uses_symbols, defines_symbols)

        # get symbols for this routine/kernel
        kernel_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', OrderedSet())
        kernel_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', OrderedSet())
        # remove imports for symbols hoisted
        symbol_map = defaultdict(OrderedSet)
        for var, module in chain(kernel_uses_symbols, kernel_defines_symbols):
            # filter modules that are supposed to be ignored
            if module.lower() in self.ignore_modules:
                continue
            symbol_map[module].add(var.parents[0] if var.parent else var)
        import_map = CaseInsensitiveDict(
            (s.name, imprt) for imprt in routine.all_imports[::-1] for s in imprt.symbols
        )
        redundant_imports_map = defaultdict(OrderedSet)
        for module, variables in symbol_map.items():
            redundant = [var.parents[0] if var.parent else var for var in variables]
            redundant_imports_map[module] |= OrderedSet(
                var.clone(dimensions=None) for var in redundant if var.name in import_map
            )
        import_map = {}
        imports = FindNodes(Import).visit(routine.spec)
        for _import in imports:
            redundant = redundant_imports_map.get(_import.module.lower())
            if not redundant:
                # No hoisted variable stems from this module: keep the import as is,
                # otherwise symbol-less imports would be dropped as "empty"
                continue
            new_symbols = tuple(
                var.clone(dimensions=None, scope=routine)
                for var in OrderedSet(_import.symbols) - redundant
            )
            if new_symbols:
                import_map[_import] = _import.clone(symbols=new_symbols)
            else:
                import_map[_import] = None
        routine.spec = Transformer(import_map).visit(routine.spec)

    def _get_symbols(self, successors):
        """
        Get module variables/symbols (grouped by routine/successor).

        Returns
        -------
        tuple of (CaseInsensitiveDict, CaseInsensitiveDict)
            ``defines_symbols`` and ``uses_symbols``, each mapping the local name of
            every :any:`ProcedureItem` successor to its ``(variable, module)`` pairs.
        """
        defines_symbols = CaseInsensitiveDict()
        uses_symbols = CaseInsensitiveDict()
        for item in successors:
            if not isinstance(item, ProcedureItem):
                continue
            defines_symbols[item.local_name] = item.trafo_data.get(self._key, {}).get('defines_symbols', OrderedSet())
            uses_symbols[item.local_name] = item.trafo_data.get(self._key, {}).get('uses_symbols', OrderedSet())
            # remove parameters if hoist_parameters is False
            # NOTE(review): ``^=`` operates on the set stored in ``item.trafo_data``; if
            # OrderedSet implements it in place, the item's analysis data is updated too
            if not self.hoist_parameters:
                parameters = {(var, module) for var, module in uses_symbols[item.local_name] if var.type.parameter}
                uses_symbols[item.local_name] ^= parameters
        return defines_symbols, uses_symbols

    def _append_call_arguments(self, routine, uses_symbols, defines_symbols):
        """
        Helper to append variables to the call(s) (arguments).

        New arguments are appended in alphabetical order to both regular
        call statements and inline calls targeting one of the successor routines.
        """
        symbol_map = CaseInsensitiveDefaultDict(OrderedSet)
        for key, _ in uses_symbols.items():
            all_symbols = uses_symbols[key]|defines_symbols[key]
            for var, module in all_symbols:
                # filter modules that are supposed to be ignored
                if module.lower() in self.ignore_modules:
                    continue
                symbol_map[key].add(var.parents[0] if var.parent else var)
        call_map = {}
        calls = FindNodes(CallStatement).visit(routine.body)
        for call in calls:
            if call.routine.name in uses_symbols:
                arguments = call.arguments
                new_args = sorted(
                    [var.clone(dimensions=None) for var in symbol_map[call.routine.name]],
                    key=lambda symbol: symbol.name
                )
                call_map[call] = call.clone(arguments=arguments + tuple(new_args))
        if call_map:
            routine.body = Transformer(call_map).visit(routine.body)
        inline_calls = FindInlineCalls().visit(routine.body)
        inline_call_map = {}
        for call in inline_calls:
            if call.routine.name in uses_symbols:
                arguments = call.parameters
                new_args = sorted([var.clone(dimensions=None) for var in symbol_map[call.routine.name]],
                        key=lambda symbol: symbol.name)
                inline_call_map[call] = call.clone(parameters=arguments + tuple(new_args))
        if inline_call_map:
            routine.body = SubstituteExpressions(inline_call_map).visit(routine.body)

    def _append_routine_arguments(self, routine, item):
        """
        Helper to append variables to the routine (arguments).

        Variables defined in the routine become ``intent(inout)``, all others
        ``intent(in)``; ``parameter`` and ``initial`` attributes are dropped.
        New arguments are appended in alphabetical order.
        """
        all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', OrderedSet())
        all_defines_vars = [var.parents[0] if var.parent else var for var, _ in all_defines_symbols]
        all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', OrderedSet())
        # remove parameters if hoist_parameters is False
        # NOTE(review): ``^=`` operates on the set stored in ``item.trafo_data``; process_kernel
        # reads that set afterwards to decide which imports to remove
        if not self.hoist_parameters:
            parameters = {(var, module) for var, module in all_uses_symbols if var.type.parameter}
            all_uses_symbols ^= parameters
        all_symbols = all_uses_symbols|all_defines_symbols
        new_arguments = []
        for var, module in all_symbols:
            # filter modules that are supposed to be ignored
            if module.lower() in self.ignore_modules:
                continue
            new_arguments.append(var.parents[0] if var.parent else var)
        new_arguments = OrderedSet(new_arguments) # remove duplicates
        new_arguments = [
            arg.clone(scope=routine, type=arg.type.clone(
                intent='inout' if arg in all_defines_vars else 'in',
                parameter=False, initial=None
            )) for arg in new_arguments
        ]
        routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name))
loki-ecmwf-0.3.6/loki/transformations/data_offload/offload_deepcopy.py0000664000175000017500000010617315167130205026417 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import defaultdict
from pathlib import Path

import re
try:
    import yaml
    HAVE_YAML = True
except ImportError:
    HAVE_YAML = False

from loki.batch import Transformation, TypeDefItem, ProcedureItem
from loki.ir import (
        nodes as ir, FindNodes, SubstituteExpressions, Transformer,
        pragma_regions_attached, get_pragma_parameters, SubstitutePragmaStrings,
        is_loki_pragma, pragmas_attached
)
from loki.expression import symbols as sym
from loki.analyse.analyse_dataflow import DataflowAnalysisAttacher, DataflowAnalysisDetacher
from loki.transformations.utilities import find_driver_loops, get_integer_variable
from loki.logging import warning
from loki.tools import as_tuple, OrderedSet
from loki.types import BasicType, DerivedType
from loki.transformations.field_api import (
        FieldAPITransferType, field_get_device_data, field_get_host_data, field_delete_device_data,
        FieldAPIAccessorType
)

__all__ = ['DataOffloadDeepcopyAnalysis', 'DataOffloadDeepcopyTransformation']


def strip_nested_dimensions(expr):
    """
    Return a copy of ``expr`` whose dimensions are removed at every level
    of its derived-type parent chain (the parents are stripped recursively).
    """
    # Recurse only when a parent exists; a missing parent is passed through unchanged
    stripped_parent = expr.parent and strip_nested_dimensions(expr.parent)
    return expr.clone(dimensions=None, parent=stripped_parent)


def get_sanitised_arg_map(arg_map):
    """
    Build a dimension-free mapping from dummy arguments to actual arguments.

    Literal arguments are dropped, a logical negation is unwrapped to its operand,
    and dimensions are stripped from both the dummy and (at every derived-type
    nesting level) the actual argument.
    """
    sanitised = {}
    for dummy, actual in arg_map.items():
        # Literals carry no data that could be offloaded
        if isinstance(actual, sym._Literal):
            continue
        # ``.not. flag`` still refers to the variable ``flag``
        expr = actual.child if isinstance(actual, sym.LogicalNot) else actual
        sanitised[dummy.clone(dimensions=None)] = strip_nested_dimensions(expr)
    return sanitised


def map_derived_type_arguments(arg_map, analysis):
    """
    Re-express analysis keys in terms of the caller's arguments.

    For every key of ``analysis`` the root variable (the first parent of a
    derived-type component, or the key itself) is looked up in ``arg_map`` and
    substituted by the corresponding argument. Keys whose root has no (truthy)
    mapping are dropped.
    """
    mapped = {}
    for dummy_expr, access in analysis.items():
        root = dummy_expr.parents[0] if dummy_expr.parents else dummy_expr
        actual = arg_map.get(root, None)
        if not actual:
            continue
        renamed = SubstituteExpressions({root: actual}).visit(dummy_expr)
        mapped[renamed] = access
    return mapped


def create_nested_dict(k, v, variable_map):
    """
    Create nested dict from derived-type expression.

    ``a%b%c`` with value ``v`` becomes ``{a: {b: {c: v}}}``. Each key is the
    dimension-free declaration looked up in ``variable_map``, or, for nested
    levels, in the variable map of the enclosing derived type's typedef.
    """
    head, sep, tail = k.name.partition('%')
    parent = variable_map[head].clone(dimensions=None)
    if not sep:
        return {parent: v}

    # Resolve the next component inside the parent's type definition
    member_map = parent.type.dtype.typedef.variable_map
    member_name, member_sep, member_rest = tail.partition('%')
    member = member_map[member_name]
    if member_sep:
        member = member.get_derived_type_member(member_rest)
    return {parent: create_nested_dict(member, v, member_map)}


def merge_nested_dict(ref_dict, temp_dict, force=False):
    """
    Merge ``temp_dict`` into ``ref_dict`` in place and return ``ref_dict``.

    New keys are copied over. Where both sides hold a dict, they are merged
    recursively. On any other conflict the existing value is kept, unless
    ``force`` is set, in which case ``temp_dict`` wins.
    """
    for name, incoming in temp_dict.items():
        if name not in ref_dict:
            ref_dict[name] = incoming
            continue

        existing = ref_dict[name]
        if isinstance(incoming, dict) and isinstance(existing, dict):
            ref_dict[name] = merge_nested_dict(existing, incoming, force=force)
        elif force:
            ref_dict[name] = incoming

    return ref_dict


class DeepcopyDataflowAnalysisAttacher(DataflowAnalysisAttacher):
    """
    Dummy argument intents in Fortran also have implications on memory status, and `INTENT(OUT)`
    is therefore fundamentally unsafe for allocatables and pointers. Therefore in order to discern
    write-only accesses to arguments, we have to bypass the intent. This is achieved here by importing
    the dataflow analysis of the child :any:`Subroutine` and ignoring the intents altogether.
    """

    def visit_CallStatement(self, o, **kwargs):
        """
        Attach dataflow information to a call using the callee's deepcopy analysis.

        The callee's access modes (looked up via ``kwargs['successor_map']``) are
        remapped onto the actual arguments: entries containing ``'read'`` become
        used symbols, entries containing ``'write'`` become defined symbols. Symbols
        appearing in the dimensions of array arguments are also counted as used.
        Calls without a known successor fall back to the generic node handling.

        Raises
        ------
        RuntimeError
            If the call has not been enriched with its target routine.
        """
        successor_map = kwargs['successor_map']

        if not o.routine:
            msg = f'[Loki::DataOffloadDeepcopyAnalysis] Cannot apply transformation without enriching calls: {o}.'
            raise RuntimeError(msg)

        child = successor_map.get(o, None)
        if not child:
            return self.visit_Node(o, **kwargs)

        # Translate the callee's analysis from dummy-argument names to the caller's arguments
        sanitised_args = get_sanitised_arg_map(o.arg_map)
        remapped = map_derived_type_arguments(
            sanitised_args, child.trafo_data['DataOffloadDeepcopyAnalysis']['analysis']
        )

        defines = OrderedSet()
        # Symbols used to index array arguments are read at the call site
        uses = OrderedSet(
            symbol
            for actual in o.arg_map.values() if isinstance(actual, sym.Array)
            for symbol in self._symbols_from_expr(actual.dimensions)
        )
        for expr, mode in remapped.items():
            if 'read' in mode:
                uses.add(expr)
            if 'write' in mode:
                defines.add(expr)

        return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

class DataOffloadDeepcopyAnalysis(Transformation):
    """
    A transformation pass to analyse the usage of subroutine arguments in a call-tree.

    The resulting analysis is a nested dict, of nesting depth equal to the longest
    derived-type expression, containing the access mode of all the arguments used
    in a call-tree. For example, the following assignments:

    .. code-block:: fortran

       a%b%c = a%b%c + 1
       d = e

    would yield the following analysis:

    .. code-block:: python

       {
          a: {
            b: {
               c: 'readwrite'
            }
          },
          d: 'write',
          e: 'read'
       }

    The analysis is stored in the :any:`Item.trafo_data` of the :any:`Item` corresponding to the driver layer
    :any:`Subroutine`. It should be noted that the analysis is stored per driver-layer loop. The driver's
    :any:`Item.trafo_data` also contains :any:`Scheduler` config entries corresponding to the derived-types
    used throughout the call-tree in a :data:`typedef_configs` dict.

    Parameters
    ----------
    output_analysis : bool
       If enabled, the analysis is written to disk as yaml files. For kernels, the files are named
       ``<routine name in lowercase>_dataoffload_analysis.yaml``. For drivers, the files are named
       ``driver_<target-name>_dataoffload_analysis.yaml``, where "target-name" is the name of the first
       call in a given driver loop that resolves to a successor routine (or the driver's own name if
       no such call exists).
    """

    _key = 'DataOffloadDeepcopyAnalysis'

    reverse_traversal = True
    """Traversal from the leaves upwards"""

    item_filter = (ProcedureItem, TypeDefItem)
    # Modules (correctly) placed in the ignore list contain type definitions and must
    # therefore be processed.
    process_ignored_items = True

    def __init__(self, output_analysis=False):
        self.output_analysis = output_analysis

    def transform_subroutine(self, routine, **kwargs):
        """
        Dispatch to :meth:`process_driver` or :meth:`process_kernel` depending on ``role``.

        Expects the scheduler-provided keyword arguments ``item``, ``role``, ``targets``
        and ``sub_sgraph``; the remaining keyword arguments are forwarded.

        Raises
        ------
        RuntimeError
            If no scheduler ``item`` is passed.
        """

        if not (item := kwargs.pop('item', None)):
            msg = f'[Loki::DataOffloadDeepcopyAnalysis] Cannot apply transformation without item: {routine}.'
            raise RuntimeError(msg)

        role = kwargs.pop('role')
        targets = kwargs.pop('targets')
        sgraph = kwargs.pop('sub_sgraph')
        successors = sgraph.successors(item=item)

        if role == 'driver':
            self.process_driver(routine, item, successors, targets, **kwargs)
        if role == 'kernel':
            self.process_kernel(routine, item, successors, **kwargs)

    def stringify_dict(self, _dict):
        """
        Stringify expression keys of a nested dict.

        Each key is replaced by its lower-cased ``name``; nested dicts are processed recursively.
        """

        stringified_dict = {}
        for k, v in _dict.items():
            if isinstance(v, dict):
                stringified_dict[k.name.lower()] = self.stringify_dict(v)
            else:
                stringified_dict[k.name.lower()] = v

        return stringified_dict

    def process_driver(self, routine, item, successors, targets, **kwargs):
        """
        Analyse every driver loop of ``routine`` independently.

        On return, ``item.trafo_data[self._key]['analysis']`` maps each driver :any:`Loop`
        to its nested (layered) analysis dict. If ``output_analysis`` is enabled, each loop's
        analysis is also dumped to ``kwargs['build_args']['output_dir']``.
        """

        item.trafo_data[self._key] = defaultdict(dict)

        with pragmas_attached(routine, ir.Loop):
            driver_loops = find_driver_loops(routine.body, targets)
        loop_analyses = {}

        # Loop-invariant: the symbols used to resolve derived-type parents when nesting
        symbol_map = routine.symbol_map | routine.all_imported_symbol_map

        for loop in driver_loops:

            # We can't simply map successor.ir: successor here because we may call a routine twice with different
            # arguments
            successor_map = {}
            calls = FindNodes(ir.CallStatement).visit(loop.body)
            for call in calls:
                if (successor := [s for s in successors if call.routine == s.ir]):
                    successor_map[call] = successor[0]

            # The analysis is accumulated on item.trafo_data, so to ensure each driver loop has an independent
            # analysis we reset it here.
            item.trafo_data[self._key]['analysis'] = {}
            self.process_body(routine.name, item, successors, successor_map, loop)

            layered_dict = {}
            for k, v in item.trafo_data[self._key]['analysis'].items():
                _temp_dict = create_nested_dict(k, v, symbol_map)
                layered_dict = merge_nested_dict(layered_dict, _temp_dict)

            loop_analyses[loop] = layered_dict

            if self.output_analysis:
                if HAVE_YAML:
                    str_layered_dict = self.stringify_dict(layered_dict)
                    base_dir = Path(kwargs['build_args']['output_dir'])
                    if successor_map:
                        target_routine_name = list(successor_map.keys())[0].name
                    else:
                        target_routine_name = routine.name
                    with open(base_dir/f'driver_{target_routine_name}_dataoffload_analysis.yaml', 'w') as f:
                        yaml.dump(str_layered_dict, f)
                else:
                    warning('[Loki::DataOffloadDeepcopyAnalysis] cannot output analysis because yaml is not available.')

        # Store only the per-loop analyses on item.trafo_data. The flat scratch analysis of the
        # last processed loop is discarded so that it does not mix with the loop-keyed entries.
        item.trafo_data[self._key]['analysis'] = {loop: loop_analyses[loop] for loop in driver_loops}

    def process_kernel(self, routine, item, successors, **kwargs):
        """
        Analyse the body of a kernel routine.

        The flat analysis (expression -> access mode) is left in
        ``item.trafo_data[self._key]['analysis']``, where callers pick it up via
        :any:`DeepcopyDataflowAnalysisAttacher`. If ``output_analysis`` is enabled, a
        nested version is also dumped to ``kwargs['build_args']['output_dir']``.
        """

        item.trafo_data[self._key] = defaultdict(dict)

        # We can't simply map successor.ir: successor here because we may call a routine twice with different
        # arguments
        successor_map = {}
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if (successor := [s for s in successors if call.routine == s.ir]):
                successor_map[call] = successor[0]

        self.process_body(routine.name, item, successors, successor_map, routine)

        if self.output_analysis:
            if HAVE_YAML:
                layered_dict = {}
                for k, v in item.trafo_data[self._key]['analysis'].items():
                    _temp_dict = create_nested_dict(k, v, routine.symbol_map)
                    layered_dict = merge_nested_dict(layered_dict, _temp_dict)

                base_dir = Path(kwargs['build_args']['output_dir'])
                with open(base_dir/f'{routine.name.lower()}_dataoffload_analysis.yaml', 'w') as file:
                    str_layered_dict = self.stringify_dict(layered_dict)
                    yaml.dump(str_layered_dict, file)
            else:
                warning('[Loki::DataOffloadDeepcopyAnalysis] cannot output analysis because yaml is not available.')

    def process_body(self, routine_name, item, successors, successor_map, scope_node):
        """
        Run the dataflow analysis on ``scope_node`` and record access modes in the item's analysis.

        ``scope_node`` is either a :any:`Subroutine` (has ``spec``/``body``, and only its dummy
        arguments are recorded) or a driver :any:`Loop` (no ``spec``, all symbols are recorded).
        Access modes are ``'read'``, ``'write'`` or ``'readwrite'``; array dimensions are
        stripped from the recorded expressions.
        """
        # gather typedef configs from successors
        self.gather_typedef_configs(successors, item.trafo_data[self._key]['typedef_configs'])

        analysis = item.trafo_data[self._key]['analysis']
        dummies = getattr(scope_node, '_dummies', [])
        has_spec = hasattr(scope_node, 'spec')

        # Pointer indirection completely breaks the dataflow analysis, as the target
        # simply appears as if its being "read", regardless of how the pointer is used.
        # Since resolving pointer association is (super) hard, we just warn the user
        # here to double check the dataflow and provide overrides if necessary.
        pointers = any(a.ptr for a in FindNodes(ir.Assignment).visit(scope_node.body))
        if pointers:
            warning(f'[Loki::DataOffloadDeepcopyAnalysis] Pointer associations found in {routine_name}')

        # We make do here (lazily) without a context manager, as this override of the
        # DataflowAnalysisAttacher is not meant for use outside of the current module.
        dataflow_analysis = DeepcopyDataflowAnalysisAttacher(include_literal_kinds=False)
        if has_spec:
            dataflow_analysis.visit(scope_node.spec, successor_map=successor_map)
            dataflow_analysis.visit(scope_node.body, successor_map=successor_map)
            body_node = scope_node.body
            # gather used symbols in the specification (e.g. dummies used as array bounds)
            spec_uses_symbols = scope_node.spec.uses_symbols
        else:
            dataflow_analysis.visit(scope_node, successor_map=successor_map)
            body_node = scope_node
            spec_uses_symbols = OrderedSet()

        # Dummy arguments used in the specification are read
        if has_spec:
            for v in spec_uses_symbols:
                if v.name_parts[0].lower() in dummies:
                    analysis[v.clone(dimensions=None)] = 'read'

        # gather used and defined symbols in body
        uses_symbols = body_node.uses_symbols
        defines_symbols = body_node.defines_symbols

        for v in uses_symbols:
            if v.name_parts[0].lower() in dummies or not has_spec:
                analysis[v.clone(dimensions=None)] = 'read'

        # Anything defined that is also read (in the spec or the body) is readwrite
        all_uses_symbols = spec_uses_symbols | uses_symbols
        for v in defines_symbols:
            if v.name_parts[0].lower() in dummies or not has_spec:
                if v in all_uses_symbols:
                    analysis[v.clone(dimensions=None)] = 'readwrite'
                else:
                    analysis[v.clone(dimensions=None)] = 'write'

        if has_spec:
            DataflowAnalysisDetacher().visit(scope_node.spec)
            DataflowAnalysisDetacher().visit(scope_node.body)
        else:
            DataflowAnalysisDetacher().visit(scope_node)

    def gather_typedef_configs(self, successors, typedef_configs):
        """Gather typedef configs from children (only :any:`TypeDefItem` successors with stored data)."""

        for child in successors:
            if isinstance(child, TypeDefItem) and child.trafo_data.get(self._key, None):
                typedef_configs.update(child.trafo_data[self._key]['typedef_configs'])

    def transform_module(self, module, **kwargs): # pylint: disable=unused-argument
        """Cache the current type definition config for later reuse."""

        item = kwargs['item']
        successors = kwargs['sub_sgraph'].successors(item=item)
        item.trafo_data[self._key] = defaultdict(dict)

        item.trafo_data[self._key]['typedef_configs'][item.ir.name.lower()] = item.config
        self.gather_typedef_configs(successors, item.trafo_data[self._key]['typedef_configs'])


class DataOffloadDeepcopyTransformation(Transformation):
    """
    A transformation that generates a deepcopy of all the arguments to a
    GPU kernel. It relies on the analysis gathered by the
    :any:`DataOffloadDeepcopyAnalysis` transformation, which must therefore
    be run before this. Please note that the analysis and deepcopy are per
    driver-loop, which must be wrapped in a `!$loki data` :any:`PragmaRegion`.

    An underlying assumption of the transformation is that expressions used as
    lvalues and rvalues are of type :any:`BasicType`, i.e. the data
    encompassed by a derived-type variable ``a`` with components ``b`` and ``c`` is
    only ever accessed or modified via fully qualified derived-type expressions
    ``a%b`` or ``a%c``. The only accepted exception to this are memory status checks
    such as ``ubound``, ``lbound``, ``size`` etc.

    The encompassing `!$loki data` :any:`PragmaRegion` can be used to
    pass hints to the transformation. Consider the following example:

    .. code-block:: fortran

       !$loki data present(a) write(b)
       do ibl=1,nblks
          call kernel(a, b, ...)
       enddo
       !$loki end data

    Marking ``a`` as ``present`` instructs the transformation to skip the deepcopy
    generation for it and simply place it in a ``!$loki structured-data present``
    clause. Marking ``b`` as ``write`` means the contents of the analysis are
    overridden and the generated deepcopy for ``b`` assumes write-only access. Other
    hints that can be passed to the deepcopy generation are:

     - read: Assume read-only access for the specified variables.
     - readwrite: Assume read-write access for the specified variables.
     - private: Skip the deepcopy for the specified variables and leave them out
                of the ``structured-data present`` clause.
     - device_resident: Don't copy the specified variables back to host and
                        leave the device allocation intact.
     - temporary: Wipe the device allocation of the specified variables but
                  don't copy them back to host.

    The transformation supports two modes:

     - offload: Generate device-host deepcopy for the arguments passed to the
                encompassed call-tree.
     - set_pointers: Generate the FIELD_API boiler-plate to set host pointers
                     for any argument representing a field.

    Parameters
    ----------
    mode : str
       Transformation mode, must be either "offload" or "set_pointers".
    """

    _key = 'DataOffloadDeepcopyAnalysis'
    field_array_match_pattern = re.compile('^field_[0-9][a-z][a-z]_array')

    def __init__(self, mode):
        self.mode = mode

    def transform_subroutine(self, routine, **kwargs):
        """
        Generate the deepcopy for the `!$loki data` regions of a driver routine.

        Only routines with ``role == 'driver'`` are modified.

        Raises
        ------
        RuntimeError
            If not applied via the :any:`Scheduler` (no ``item``), or if the item
            carries no :any:`DataOffloadDeepcopyAnalysis` results.
        """

        if not (item := kwargs.get('item', None)):
            msg = '[Loki::DataOffloadDeepcopyTransformation] can only be applied by the Scheduler.'
            raise RuntimeError(msg)

        # Use .get so a missing analysis entry raises the intended error rather than a KeyError
        if not item.trafo_data.get(self._key, None):
            raise RuntimeError(f'[Loki::DataOffloadDeepcopyTransformation] item missing analysis: {item.name}.')

        role = kwargs['role']
        targets = kwargs['targets']

        if role == 'driver':
            self.process_driver(routine, item.trafo_data[self._key]['analysis'],
                                item.trafo_data[self._key]['typedef_configs'], targets)

    @staticmethod
    def update_with_manual_overrides(parameters, analysis, variable_map):
        """
        Update analysis with manual overrides specified in !loki data pragma.

        The ``write``, ``read`` and ``readwrite`` pragma parameters hold comma-separated
        (possibly ``%``-qualified) variable names whose access mode is forced to the
        parameter's name. Empty entries (e.g. from a trailing comma) are ignored.
        """

        override_map = {}
        for key in ['write', 'read', 'readwrite']:
            _vars = parameters.get(key, None)
            if _vars:
                _vars = [v.strip() for v in _vars.split(',') if v.strip()]
                override_map.update({var: key for var in _vars})

        for v, override in override_map.items():
            name_parts = v.split('%', maxsplit=1)
            var = variable_map[name_parts[0]]
            if len(name_parts) > 1:
                var = var.get_derived_type_member(name_parts[1])
            temp_dict = create_nested_dict(var, override, variable_map)
            analysis = merge_nested_dict(analysis, temp_dict, force=True)

        return analysis

    @staticmethod
    def get_pragma_vars(parameters, category):
        """
        Return the stripped, non-empty variable names listed for ``category`` in the pragma.

        Returns an empty list if the category is absent or was given without a value.
        """
        value = parameters.get(category, None) or ''
        return [v.strip() for v in value.split(',') if v.strip()]

    def insert_deepcopy_instructions(self, region, mode, copy, host, wipe, present_vars):
        """
        Build the pragma map that replaces the `!$loki data` region's pragmas.

        In ``offload`` mode the copy instructions and a ``structured-data present`` pragma
        replace the opening pragma, and the closing ``end structured-data`` pragma plus the
        host/wipe instructions replace the closing one. In any other mode, only the host
        instructions are kept, with conditionals/loops that contain no ``get_host_data_rdwr``
        call and all pragmas removed.
        """

        if mode == 'offload':
            # wrap in acc data present pragma
            content = f"structured-data present({', '.join(present_vars)})"
            acc_data_pragma = ir.Pragma(keyword='loki', content=content)
            acc_data_pragma_post = ir.Pragma(keyword='loki', content='end structured-data')

            pragma_map = {region.pragma: (copy, acc_data_pragma)}
            pragma_map.update({region.pragma_post: (acc_data_pragma_post, host, wipe)})
        else:
            # We remove all offload instructions first and non F-API related boiler plate
            vmap = {}

            conds = FindNodes((ir.Conditional, ir.Loop), greedy=True).visit(host)
            for cond in conds:
                calls = FindNodes(ir.CallStatement).visit(cond.body)
                get_host_call = any('get_host_data_rdwr' in v.name.name.lower() for v in calls)

                if not get_host_call:
                    vmap[cond] = None

            host_pragmas = FindNodes(ir.Pragma).visit(host)
            vmap.update({p: None for p in host_pragmas})
            host = Transformer(vmap).visit(host)

            # Now we insert the updated "host" body in the driver layer
            pragma_map = {region.pragma: host, region.pragma_post: None}

        return pragma_map

    def process_driver(self, routine, analyses, typedef_configs, targets):
        """
        Generate the deepcopy for every `!$loki data` region of a driver routine.

        ``analyses`` maps driver loops to their nested analysis. Required FIELD_API
        imports are prepended to the routine's spec.
        """

        pragma_map = {}
        imports = defaultdict(tuple)
        symbol_map = routine.symbol_map | routine.all_imported_symbol_map
        with pragma_regions_attached(routine):
            for region in FindNodes(ir.PragmaRegion).visit(routine.body):

                # Only work on active `!$loki data` regions
                if not is_loki_pragma(region.pragma, starts_with='data'):
                    continue

                parameters = get_pragma_parameters(region.pragma, starts_with='data')
                with pragmas_attached(routine, ir.Loop):
                    driver_loops = find_driver_loops(region.body, targets)

                # skip the deepcopy for variables previously marked as present/private
                present = self.get_pragma_vars(parameters, 'present')
                private = self.get_pragma_vars(parameters, 'private')

                # temporary variables are not copied back to host and are wiped from device memory
                temporary = self.get_pragma_vars(parameters, 'temporary')

                # device_resident variables are left on device (i.e. neither copied back to host nor deleted)
                device_resident = self.get_pragma_vars(parameters, 'device_resident')

                copy, host, wipe = (), (), ()
                present_vars = ()
                for loop in driver_loops:

                    # filter out root-level scalars from the analysis as these are thread-private by default
                    # and should not in any case be copied back to host
                    analysis = {k: v for k, v in analyses[loop].items()
                                if not (isinstance(k.type.dtype, BasicType) and isinstance(k, sym.Scalar))}

                    # update analysis with manual overrides
                    analysis = self.update_with_manual_overrides(parameters, analysis, routine.symbol_map)

                    # recursively traverse analysis and generate deepcopy
                    _copy, _host, _wipe, _imports = self.generate_deepcopy(routine, analysis=analysis,
                                                                           present=present, private=private,
                                                                           temporary=temporary,
                                                                           device_resident=device_resident,
                                                                           typedef_configs=typedef_configs,
                                                                           symbol_map=symbol_map)

                    copy += _copy
                    host += _host
                    wipe += _wipe
                    for mod in _imports:
                        imports[mod] += as_tuple(_imports[mod])

                    present_vars += as_tuple(v.name for v in analysis if not v in private)

                # replace the `!$loki data` PragmaRegion with the generated deepcopy instructions
                pragma_map.update(self.insert_deepcopy_instructions(region, self.mode, copy, host, wipe, present_vars))

        for mod in imports:
            # de-duplicate while preserving order
            imports[mod] = as_tuple(dict.fromkeys(imports[mod]))
            routine.spec.prepend(as_tuple(ir.Import(module=mod, symbols=imports[mod])))
        routine.body = Transformer(pragma_map).visit(routine.body)

    def wrap_in_loopnest(self, var, body, routine):
        """
        Wrap body in loop nest corresponding to the shape of var.

        One loop per dimension is created over ``LBOUND(var, d):UBOUND(var, d)`` using
        integer variables ``J1``, ``J2``, ... (declared in ``routine`` if missing); the first
        dimension is the innermost loop. References to ``var`` (and pragma strings naming it)
        in ``body`` are indexed with the loop variables. Must only be called for a ``var``
        with a non-empty shape.
        """

        # Don't wrap an empty body
        if not body:
            return ()

        # initialise working variables
        loop_vars = []
        loopbody = body
        var_with_dims = None
        dimensions = var.dimensions

        # build loop-nest one layer at a time
        for dim in range(len(var.type.shape)):

            loop_vars += [get_integer_variable(routine, f'J{dim+1}')]
            if not loop_vars[-1] in routine.variables:
                routine.variables += as_tuple(loop_vars[-1])

            # Create loop bounds
            lstart = sym.InlineCall(function=sym.ProcedureSymbol('LBOUND', scope=routine),
                                    parameters=(var, sym.IntLiteral(dim+1)))
            lend = sym.InlineCall(function=sym.ProcedureSymbol('UBOUND', scope=routine),
                                    parameters=(var, sym.IntLiteral(dim+1)))
            bounds = sym.LoopRange((lstart, lend))

            var_with_dims = var.clone(dimensions=dimensions)
            dimensions += as_tuple(loop_vars[-1])
            vmap = {var_with_dims: var_with_dims.clone(dimensions=dimensions)}
            str_map = {str(k): str(v) for k, v in vmap.items()}

            SubstitutePragmaStrings(str_map).visit(loopbody)
            loopbody = as_tuple(SubstituteExpressions(vmap).visit(loopbody))

            loop = ir.Loop(variable=loop_vars[-1], bounds=bounds, body=loopbody)
            loopbody = loop

        return as_tuple(loop)

    @staticmethod
    def create_memory_status_test(check, var, body, scope):
        """Wrap a given body in a memory status check (e.g. ``ALLOCATED``/``ASSOCIATED``)."""

        # Don't wrap an empty body
        if not body:
            return ()

        condition = sym.InlineCall(function=sym.ProcedureSymbol(check, scope=scope),
                                   parameters=as_tuple(var))
        return as_tuple(ir.Conditional(condition=condition, body=body))

    @staticmethod
    def enter_data_copyin(var):
        """Generate unstructured data copyin instruction."""
        return as_tuple(ir.Pragma(keyword='loki', content=f'unstructured-data in({var})'))

    @staticmethod
    def enter_data_create(var):
        """Generate unstructured data create instruction."""
        return as_tuple(ir.Pragma(keyword='loki', content=f'unstructured-data create({var})'))

    @staticmethod
    def enter_data_attach(var):
        """Generate unstructured data attach instruction."""
        return as_tuple(ir.Pragma(keyword='loki', content=f'unstructured-data attach({var})'))

    @staticmethod
    def exit_data_detach(var):
        """Generate unstructured data detach instruction."""
        return as_tuple(ir.Pragma(keyword='loki', content=f'exit unstructured-data detach({var}) finalize'))

    @staticmethod
    def exit_data_delete(var):
        """Generate unstructured data delete instruction."""
        return as_tuple(ir.Pragma(keyword='loki', content=f'exit unstructured-data delete({var}) finalize'))

    @staticmethod
    def update_self(var):
        """Pull back data to host."""
        return as_tuple(ir.Pragma(keyword='loki', content=f'update host({var})'))

    def create_field_api_offload(self, var, analysis, typedef_config, parent, scope, symbol_map):
        """
        Generate FIELD_API device/host/wipe instructions for a field view pointer ``var``.

        Returns
        -------
        tuple
            ``(device, host, wipe, imports)`` where ``imports`` maps module names to the
            FIELD_API procedure symbols that are not yet in ``symbol_map``.
        """

        #TODO: currently this assumes FIELD objects and their associated pointers are
        # components of the same derived-type. This should be generalised for the case
        # where the two are declared separately.

        # Strip the view pointer prefix, but only from the start of the name
        # (a plain str.replace would also remove occurrences elsewhere in the name)
        view_ptr_prefix = (typedef_config.get('view_ptr_prefix', '') or '').lower()
        var_name = var.name.lower()
        if view_ptr_prefix and var_name.startswith(view_ptr_prefix):
            var_name = var_name[len(view_ptr_prefix):]

        # Get FIELD object name
        if not (field_object_name := typedef_config.get('field_ptr_map', {}).get(var_name, None)):
            field_object_name = typedef_config.get('field_prefix', '') + var_name.replace('_field', '')

        # Create FIELD object
        variable_map = parent.type.dtype.typedef.variable_map
        field_object = variable_map[field_object_name].clone(parent=parent)
        field_ptr = var.clone(dimensions=None, parent=parent)

        if analysis == 'read':
            access_mode = FieldAPITransferType.READ_ONLY
        elif analysis == 'readwrite':
            access_mode = FieldAPITransferType.READ_WRITE
        else:
            access_mode = FieldAPITransferType.WRITE_ONLY

        imports = defaultdict(tuple)
        device = as_tuple(field_get_device_data(field_object, field_ptr, access_mode, scope,
            accessor_type=FieldAPIAccessorType.GENERIC))
        get_device_call_proc_symbol = device[0].name
        if not get_device_call_proc_symbol in symbol_map:
            imports['FIELD_ACCESS_MODULE'] += as_tuple(get_device_call_proc_symbol)
        device += self.enter_data_attach(field_ptr)
        host = as_tuple(field_get_host_data(field_object, field_ptr, FieldAPITransferType.READ_WRITE, scope,
            accessor_type=FieldAPIAccessorType.GENERIC))
        get_host_call_proc_symbol = host[0].name
        if not get_host_call_proc_symbol in symbol_map:
            imports['FIELD_ACCESS_MODULE'] += as_tuple(get_host_call_proc_symbol)
        wipe = self.exit_data_detach(field_ptr)
        wipe += as_tuple(field_delete_device_data(field_object, scope))

        wipe = self.create_memory_status_test('ASSOCIATED', field_object, wipe, scope)

        return device, host, wipe, imports

    def create_dummy_field_array_typedef_config(self, parent):
        """The scheduler will never traverse the FIELD_RANKSUFF_ARRAY type definitions,
           so we create a dummy typedef config here (``None`` if ``parent`` is not such a type)."""

        if self.field_array_match_pattern.match(parent.type.dtype.typedef.name.lower()):
            typedef_config = {
                'field_prefix': 'F_',
                'field_ptr_suffix': '_FIELD',
                'field_ptr_map': {}
            }
            return typedef_config
        return None

    def generate_deepcopy(self, routine, **kwargs):
        """
        Recursively traverse the deepcopy analysis to generate the deepcopy instructions.

        Keyword arguments: ``analysis`` (nested analysis dict), ``parent`` (optional
        derived-type parent of the current level), and the pass-through lists/maps
        ``present``, ``private``, ``temporary``, ``device_resident``, ``typedef_configs``
        and ``symbol_map``.

        Returns
        -------
        tuple
            ``(copy, host, wipe, imports)`` instruction tuples and a module -> symbols map.
        """

        # initialise tuples used to store the deepcopy instructions
        copy, host, wipe, imports = (), (), (), defaultdict(tuple)

        analysis = kwargs.pop('analysis')
        parent = kwargs.pop('parent', None)

        for var in analysis:

            _copy, _host, _wipe, _imports = (), (), (), {}

            # Don't generate a deepcopy for variables marked as present or private
            if var in kwargs['present'] or var in kwargs['private']:
                continue

            # determine if var should be kept on device
            delete = not var in kwargs['device_resident']
            # determine if this is a temporary variable
            temporary = var in kwargs['temporary']

            check = 'ASSOCIATED' if var.type.pointer else None
            if not check:
                check = 'ALLOCATED' if var.type.allocatable else None

            if isinstance(var.type.dtype, DerivedType):

                var_with_parent = var.clone(parent=parent)

                # If we are directly assigning derived-types, rather than operating on members,
                # then we are the lowest level of the analysis and don't want to recurse further
                if not isinstance(analysis[var], str):
                    _copy, _host, _wipe, _imports = self.generate_deepcopy(routine, analysis=analysis[var],
                                                                           parent=var_with_parent, **kwargs)

                #wrap in loop
                if var.type.shape:
                    _copy = self.wrap_in_loopnest(var_with_parent, _copy, routine)
                    _host = self.wrap_in_loopnest(var_with_parent, _host, routine)
                    _wipe = self.wrap_in_loopnest(var_with_parent, _wipe, routine)

                # var must be allocated/deallocated on device
                if not parent or check:
                    _copy = self.enter_data_copyin(var_with_parent) + _copy
                    _wipe += self.exit_data_delete(var_with_parent)

                # wrap in memory status check
                if check:
                    _copy = self.create_memory_status_test(check, var_with_parent, _copy, routine)
                    _host = self.create_memory_status_test(check, var_with_parent, _host, routine)
                    _wipe = self.create_memory_status_test(check, var_with_parent, _wipe, routine)

            else:

                # First determine whether we have a field pointer or a regular array/scalar
                typedef_config = None
                if parent:
                    typedef_config = kwargs['typedef_configs'].get(parent.type.dtype.typedef.name.lower(), None)

                # Create a dummy typedef config for FIELD_RANKSUFF_ARRAY types
                if parent and not typedef_config:
                    typedef_config = self.create_dummy_field_array_typedef_config(parent)

                field = False
                if typedef_config:
                    # Is our pointer in the given list of field ptrs or has the right suffix?
                    suffix = typedef_config['field_ptr_suffix']
                    field = var in typedef_config.get('field_ptrs', [])
                    field = field or re.search(f'{suffix}$', var.name, re.IGNORECASE)

                if field:
                    _copy, _host, _wipe, _imports = self.create_field_api_offload(var, analysis[var], typedef_config,
                                                                                  parent, routine, kwargs['symbol_map'])
                else:
                    # We have a regular array/scalar
                    if not parent or check:
                        if analysis[var] == 'write':
                            _copy = self.enter_data_create(var.clone(parent=parent))
                        else:
                            _copy = self.enter_data_copyin(var.clone(parent=parent))
                        _wipe = self.exit_data_delete(var.clone(parent=parent))

                    # Copy back to host if necessary
                    if analysis[var] != 'read':
                        _host = self.update_self(var.clone(parent=parent))

                    # wrap in memory status check
                    if check:
                        _copy = self.create_memory_status_test(check, var.clone(parent=parent), _copy, routine)
                        _host = self.create_memory_status_test(check, var.clone(parent=parent), _host, routine)
                        _wipe = self.create_memory_status_test(check, var.clone(parent=parent), _wipe, routine)

            copy += as_tuple(_copy)
            # temporaries and device-resident variables are never copied back to host
            if delete and not temporary:
                host += as_tuple(_host)
            # device-resident variables keep their device allocation
            if delete:
                wipe += as_tuple(_wipe)
            for mod in _imports:
                imports[mod] += as_tuple(_imports[mod])

        return copy, host, wipe, imports
loki-ecmwf-0.3.6/loki/transformations/loop_blocking.py0000664000175000017500000002752415167130205023335 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from loki.ir import (
    nodes as ir, Transformer, FindVariables, SubstituteExpressions
)
from loki.subroutine import Subroutine
from loki.expression import (
    symbols as sym, parse_expr, ceil_division, iteration_index
)
from loki.logging import error

__all__ = ['split_loop', 'split_loop_region', 'block_loop_arrays']


class LoopSplittingVariables:
    """
    This class holds the loop splitting variables, e.g. outer loop block sizes and iteration
    bounds. It also holds the original loop variable of the inner loop.

    All splitting variables are clones of the original loop variable whose names
    carry a suffix (``<var>_loop_block_size``, ``<var>_loop_num_blocks``, ...).
    The block size variable is declared as a parameter initialised with the
    given block size.

    Parameters
    ----------
    loop_var : :any:`Variable`
        Loop variable of the loop that is being split.
    block_size : int or :any:`Scalar` or :any:`IntLiteral`
        Size of the blocks; used as the initial value of the block size parameter.

    Raises
    ------
    ValueError
        If ``block_size`` is not an integer, an :any:`IntLiteral` or a :any:`Scalar`.
    """

    def __init__(self, loop_var: sym.Variable, block_size):
        self._loop_var = loop_var

        if isinstance(block_size, int):
            blk_size = sym.IntLiteral(block_size)
        elif isinstance(block_size, (sym.Scalar, sym.IntLiteral)):
            blk_size = block_size
        else:
            error("LoopSplittingVariables: Block size argument must be an integer constant or a scalar variable")
            raise ValueError('Block size must be an integer constant or a scalar variable')

        # The block size becomes a named constant (parameter) initialised with ``blk_size``
        block_size_var = loop_var.clone(
            name=loop_var.name + "_loop_block_size",
            type=loop_var.type.clone(parameter=True, initial=blk_size)  # pylint: disable=possibly-used-before-assignment
        )
        # Remaining splitting variables, in the order exposed by the properties below
        suffixes = ("_loop_num_blocks", "_loop_block_idx", "_loop_local",
                    "_loop_iter_num", "_loop_block_start", "_loop_block_end")
        self._splitting_vars = (block_size_var,) + tuple(
            loop_var.clone(name=loop_var.name + suffix) for suffix in suffixes
        )

    @property
    def loop_var(self):
        """The original loop variable."""
        return self._loop_var

    @property
    def block_size(self):
        """Parameter variable holding the block size."""
        return self._splitting_vars[0]

    @property
    def num_blocks(self):
        """Variable holding the total number of blocks."""
        return self._splitting_vars[1]

    @property
    def block_idx(self):
        """Loop variable of the outer (block) loop."""
        return self._splitting_vars[2]

    @property
    def inner_loop_var(self):
        """Loop variable of the inner loop."""
        return self._splitting_vars[3]

    @property
    def iter_num(self):
        """Variable holding the global iteration number."""
        return self._splitting_vars[4]

    @property
    def block_start(self):
        """Variable holding the first iteration number of the current block."""
        return self._splitting_vars[5]

    @property
    def block_end(self):
        """Variable holding the last iteration number of the current block."""
        return self._splitting_vars[6]

    @property
    def splitting_vars(self):
        """Tuple of all splitting variables, starting with the block size."""
        return self._splitting_vars


def compute_block_indices(splitting_vars, loop, scope):
    """
    Build the assignments computing the first and last iteration number of the
    current *block* inside the outer loop.

    Parameters
    ----------
    splitting_vars : :any:`LoopSplittingVariables`
        Splitting variables of the loop being blocked.
    loop : :any:`Loop`
        The original loop; its bounds provide the total number of iterations.
    scope : :any:`Scope`
        Scope in which the generated expressions are parsed.

    Returns
    -------
    tuple of :any:`Assignment`
        ``block_start = (block_idx - 1) * block_size + 1`` and
        ``block_end = MIN(block_idx * block_size, num_iterations)``.
    """
    start_expr = parse_expr(
        f"({splitting_vars.block_idx} - 1) * {splitting_vars.block_size} + 1",
        scope=scope
    )
    block_upper = sym.Product(children=(splitting_vars.block_idx, splitting_vars.block_size))
    # Clamp the last block to the total iteration count
    end_expr = sym.InlineCall(
        sym.DeferredTypeSymbol('MIN', scope=scope),
        parameters=(block_upper, loop.bounds.num_iterations)
    )
    return (
        ir.Assignment(splitting_vars.block_start, start_expr),
        ir.Assignment(splitting_vars.block_end, end_expr),
    )


def compute_num_blocks(splitting_vars, loop):
    """
    Build the assignment computing the total number of blocks.

    The number of blocks is the total iteration count of ``loop`` divided by the
    block size, rounded up via ``ceil_division``.

    Returns
    -------
    :any:`Assignment`
        Assignment to ``splitting_vars.num_blocks``.
    """
    total_iterations = loop.bounds.num_iterations
    block_count = ceil_division(total_iterations, splitting_vars.block_size)
    return ir.Assignment(splitting_vars.num_blocks, block_count)


def create_inner_loop(splitting_vars: LoopSplittingVariables, loop: ir.Loop, scope):
    """
    Create innermost loop in the loop split.

    The inner loop runs from 1 to ``block_end - block_start + 1`` over
    ``splitting_vars.inner_loop_var``. Its body starts with two assignments that
    reconstruct the global iteration number and the original loop variable,
    followed by the original loop body.

    Parameters
    ----------
    splitting_vars : :any:`LoopSplittingVariables`
        Splitting variables of the loop being blocked.
    loop : :any:`Loop`
        The original loop.
    scope : :any:`Scope`
        Scope in which the generated expressions are parsed.

    Returns
    -------
    :any:`Loop`
        The new inner loop.
    """
    # ``scope`` must be given to ``parse_expr`` (not to the Assignment node) so that
    # the symbols in the expression are attached to the routine's scope
    iter_num_expr = parse_expr(
        f"{splitting_vars.block_start}+{splitting_vars.inner_loop_var}-1",
        scope=scope
    )
    iteration_nums = (
        ir.Assignment(splitting_vars.iter_num, iter_num_expr),
        # Recover the original loop variable's value from the global iteration number
        ir.Assignment(loop.variable,
                      iteration_index(splitting_vars.iter_num, loop.bounds))
    )
    inner_bounds = sym.LoopRange((
        sym.IntLiteral(1),
        parse_expr(f"{splitting_vars.block_end} - {splitting_vars.block_start} + 1",
                   scope=scope)
    ))
    inner_loop = loop.clone(variable=splitting_vars.inner_loop_var,
                            body=iteration_nums + loop.body,
                            bounds=inner_bounds)
    return inner_loop


def split_loop(routine: Subroutine, loop: ir.Loop, block_size: int):
    """
    Block a loop by splitting it into an outer loop over blocks and an inner
    loop of size ``block_size``.

    The original loop in ``routine.body`` is replaced (in place) by the
    assignment computing the number of blocks, followed by the outer loop.

    Parameters
    ----------
    routine: :any:`Subroutine`
        Subroutine object containing the loop. New variables introduced in the
        loop splitting will be declared in the body of routine.
    loop: :any:`Loop`
        Loop to be split.
    block_size: int
        inner loop size (size of blocking blocks)

    Returns
    -------
    tuple
        ``(splitting_vars, inner_loop, outer_loop)``
    """
    loop_vars = LoopSplittingVariables(loop.variable, block_size)
    routine.variables += loop_vars.splitting_vars

    blocked_inner = create_inner_loop(loop_vars, loop, routine)

    block_index_assignments = compute_block_indices(loop_vars, loop, routine)
    blocked_outer = ir.Loop(
        variable=loop_vars.block_idx,
        body=(block_index_assignments, blocked_inner),
        bounds=sym.LoopRange((sym.IntLiteral(1), loop_vars.num_blocks))
    )

    # The block count must be computed right before the new loop nest
    replacement = (compute_num_blocks(loop_vars, loop), blocked_outer)
    Transformer({loop: replacement}, inplace=True).visit(routine.body)

    return loop_vars, blocked_inner, blocked_outer


def split_loop_region(routine: Subroutine, loop: ir.Loop, block_size: int, data_region):
    """
    Block a loop inside a data region and move the data region inside the outer loop.

    A copy of ``data_region`` with ``loop`` replaced by the inner loop becomes
    the body (after the block index computation) of the outer loop. In
    ``routine.body``, the original data region is replaced in place by the
    block count assignment followed by the outer loop.

    Parameters
    ----------
    routine: :any:`Subroutine`
        Subroutine object containing the loop. New variables introduced in the
        loop splitting will be declared in the body of routine.
    loop: :any:`Loop`
        Loop to be split.
    block_size: int
        inner loop size (size of blocking blocks)
    data_region: :any:`PragmaRegion`,
        data region containing the loop to be blocked

    Returns
    -------
    tuple
        ``(splitting_vars, inner_loop, outer_loop, new_data_region)``
    """
    loop_vars = LoopSplittingVariables(loop.variable, block_size)
    routine.variables += loop_vars.splitting_vars

    blocked_inner = create_inner_loop(loop_vars, loop, routine)

    # Out-of-place copy of the region, holding the inner loop instead of the original one
    region_copy = Transformer({loop: blocked_inner}, inplace=False).visit(data_region)

    block_index_assignments = compute_block_indices(loop_vars, loop, routine)
    blocked_outer = ir.Loop(
        variable=loop_vars.block_idx,
        body=(block_index_assignments, region_copy),
        bounds=sym.LoopRange((sym.IntLiteral(1), loop_vars.num_blocks))
    )

    replacement = (compute_num_blocks(loop_vars, loop), blocked_outer)
    Transformer({data_region: replacement}, inplace=True).visit(routine.body)

    return loop_vars, blocked_inner, blocked_outer, region_copy


def blocked_shape(a: sym.Array, blocking_indices, block_size):
    """
    Calculate the dimensions for a blocked version of the array ``a``.

    Every entry of ``a.shape`` that is a :any:`Scalar` containing one of the
    ``blocking_indices`` is replaced by the block size; other entries are kept.

    Parameters
    ----------
    a : :any:`Array`
        The array to be blocked.
    blocking_indices : list
        Index variables (or names) along which the array is blocked.
    block_size : int or :any:`Expression`
        Size of a block. Python integers are wrapped in an :any:`IntLiteral`.
        Symbolic values, such as the block size parameter passed by
        :func:`block_loop_arrays`, are used as-is.

    Returns
    -------
    tuple
        The blocked shape.
    """
    # Only plain Python ints need wrapping: a symbolic block size (e.g. a Scalar
    # parameter) cannot be turned into an integer literal and is used directly
    block_extent = sym.IntLiteral(block_size) if isinstance(block_size, int) else block_size
    shape = tuple(
        block_extent if isinstance(dim, sym.Scalar) and any(
            bidx in dim for bidx in blocking_indices) else dim
        for dim in a.shape
    )
    return shape


def blocked_type(a: sym.Array):
    """
    Return the type for the blocked counterpart of array ``a``.

    This is a copy of ``a``'s type with the ``intent`` attribute removed, because
    the blocked array is a local variable rather than a dummy argument.
    """
    local_type = a.type.clone(intent=None)
    return local_type


def replace_indices(dimensions, indices: list, replacement_index):
    """
    Return new dimensions with every occurrence of ``indices`` replaced by ``replacement_index``.

    Parameters
    ----------
    dimensions:
        Symbolic representation of dimensions or indices.
    indices: list of `Variable`s
        that will be replaced in the new :any:`Dimension` object.
    replacement_index: :any:`Expression`
        replacement for the indices changed.

    Returns
    -------
    tuple
        A tuple with the same length as ``dimensions``. Each entry that is a
        :any:`Scalar` containing one of ``indices`` is substituted by
        ``replacement_index``; all other entries are returned unchanged.
    """
    def _needs_replacement(dim):
        # Only scalar entries referring to one of the blocking variables are substituted
        if not isinstance(dim, sym.Scalar):
            return False
        return any(blocking_var in dim for blocking_var in indices)

    return tuple(replacement_index if _needs_replacement(dim) else dim for dim in dimensions)


def block_loop_arrays(routine: Subroutine, splitting_vars, inner_loop: ir.Loop,
                      outer_loop: ir.Loop, blocking_indices):
    """
    Replaces arrays inside the inner loop with blocked counterparts.

    This routine declares array variables (named ``<array>_block``) to hold the
    blocks of the arrays used inside the loop. It then replaces array variables
    inside the inner loop with their blocked counterparts, indexed by the inner
    loop variable.

    Within the outer loop, copies are inserted around the inner loop:

    * into the block arrays before it, for arrays with ``intent(in)``/``intent(inout)``;
    * back to the original arrays after it, for arrays with
      ``intent(out)``/``intent(inout)``.

    Parameters
    ----------
    routine : :any:`Subroutine`
        routine in which the blocking variables should be added.
    splitting_vars : :any:`LoopSplittingVariables`
        Loop splitting variables, e.g. as returned by :func:`split_loop`.
    inner_loop: :any:`Loop`
        inner loop after loop splitting
    outer_loop : :any:`Loop`
        outer loop after loop splitting
    blocking_indices : tuple or list of str
        Variable names of the indexes that should be blocked if in array
        expressions.
    """
    # FindVariables yields an unordered set; sort so that the generated
    # declarations and copy statements appear in a reproducible order
    arrays = tuple(sorted(
        (var for var in FindVariables().visit(inner_loop.body)
         if isinstance(var, sym.Array) and any(bi in var for bi in blocking_indices)),
        key=str
    ))
    name_map = {a.name: a.name + '_block' for a in arrays}

    # Declare blocked arrays; an array accessed with several different index
    # expressions must still be declared only once
    block_decls = {}
    for a in arrays:
        if a.name not in block_decls:
            block_decls[a.name] = a.clone(
                name=name_map[a.name],
                dimensions=blocked_shape(a, blocking_indices, splitting_vars.block_size),
                type=blocked_type(a)
            )
    routine.variables += tuple(block_decls.values())

    # Replace arrays in loop with blocked arrays indexed by the inner loop variable
    substitutions = {
        a: a.clone(name=name_map[a.name],
                   dimensions=replace_indices(a.dimensions, blocking_indices, inner_loop.variable))
        for a in arrays
    }
    SubstituteExpressions(substitutions, inplace=True).visit(inner_loop.body)

    # memory copies
    block_range = sym.RangeIndex((splitting_vars.block_start, splitting_vars.block_end))
    local_range = sym.RangeIndex(
        (sym.IntLiteral(1),
         parse_expr(f"{splitting_vars.block_end} - {splitting_vars.block_start} + 1",
                    scope=routine)))
    # input variables: copy the current block into the blocked array
    in_vars = (a for a in arrays if a.type.intent in ('in', 'inout'))
    copyins = tuple(
        ir.Assignment(a.clone(name=name_map[a.name],
                              dimensions=replace_indices(a.dimensions, blocking_indices,
                                                         local_range)),
                      a.clone(
                          dimensions=replace_indices(a.dimensions, blocking_indices, block_range)))
        for a in in_vars)
    # output variables: copy the blocked array back into the current block
    out_vars = (a for a in arrays if a.type.intent in ('out', 'inout'))
    copyouts = tuple(
        ir.Assignment(
            a.clone(dimensions=replace_indices(a.dimensions, blocking_indices, block_range)),
            a.clone(name=name_map[a.name],
                    dimensions=replace_indices(a.dimensions, blocking_indices, local_range))
        )
        for a in out_vars)
    change_map = {inner_loop: copyins + (inner_loop,) + copyouts}
    Transformer(change_map, inplace=True).visit(outer_loop)
loki-ecmwf-0.3.6/loki/transformations/dependency.py0000664000175000017500000003401215167130205022620 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation
from loki.ir import nodes as ir, Transformer, FindNodes
from loki.tools.util import as_tuple, CaseInsensitiveDict

__all__ = ['DuplicateKernel', 'RemoveKernel']


class DuplicateKernel(Transformation):
    """
    Duplicate subroutines which includes the creation of new :any:`Item`s
    as well as the addition of the corresponding new dependencies.

    Therefore, this transformation creates a new item and also implements
    the relevant routines for dry-run pipeline planning runs.

    Parameters
    ----------
    duplicate_kernels : str|tuple|list, optional
        Kernel name(s) to be duplicated (matched case-insensitively).
    duplicate_suffix : str, optional
        Suffix appended to the original kernel name(s).
    duplicate_module_suffix : str, optional
        Suffix appended to the original module name(s),
        if defined, otherwise `duplicate_suffix`
    duplicate_subgraph : bool, optional
        Whether or not to duplicate the subgraph beneath the kernel(s)
        that are duplicated.
    """

    creates_items = True
    reverse_traversal = True

    def __init__(self, duplicate_kernels=None, duplicate_suffix='duplicated',
                 duplicate_module_suffix=None, duplicate_subgraph=False):
        self.suffix = duplicate_suffix
        self.module_suffix = duplicate_module_suffix or duplicate_suffix
        self.duplicate_kernels = tuple(kernel.lower() for kernel in as_tuple(duplicate_kernels))
        self.duplicate_subgraph = duplicate_subgraph

    def _get_new_item_name(self, item):
        """
        Get new/duplicated item name, more specifically ``scope_name``,
        ``local_name`` and ``new_item_name``.

        Parameters
        ----------
        item : :any:`Item`
            The item used to derive ``scope_name``,
            ``local_name`` and ``new_item_name``.
        Returns
        -------
        scope_name : str
            New item scope name (empty/None if ``item`` has no scope).
        local_name : str
            New item local name.
        new_item_name : str
            New item name, i.e. ``<scope_name>#<local_name>``.
        """
        # Determine new item name
        scope_name = item.scope_name
        local_name = f'{item.local_name}{self.suffix}'
        if scope_name:
            scope_name = f'{scope_name}{self.module_suffix}'
        new_item_name = f'{scope_name or ""}#{local_name}'
        return scope_name, local_name, new_item_name

    def _get_or_create_or_rename_item(self, item, item_factory, config):
        """
        Get, create or rename item including the scope item if there is a
        scope.

        Parameters
        ----------
        item : :any:`Item`
            Item to duplicate/to use to derive new item.
        item_factory : :any:`ItemFactory`
            The :any:`ItemFactory` to use when creating the items.
        config : :any:`SchedulerConfig`
            The scheduler config to use when instantiating new items.
        Returns
        -------
        :any:`Item`
            Newly created item.
        """
        scope_name, local_name, new_item_name = self._get_new_item_name(item)
        # Try to get existing item from cache
        new_item = item_factory.item_cache.get(new_item_name)
        # Try to get an item for the scope or create that first
        if new_item is None and scope_name:
            scope_item = item_factory.item_cache.get(scope_name)
            if scope_item:
                scope = scope_item.ir
                if local_name not in scope and item.local_name in scope:
                    # Rename the existing item to the new name
                    scope[item.local_name].name = local_name

                if local_name in scope:
                    new_item = item_factory.create_from_ir(
                        scope[local_name], scope, config=config
                    )
        # Create new item
        if new_item is None:
            new_item = item_factory.get_or_create_item_from_item(new_item_name, item, config=config)
        return new_item

    def _modify_sgraph(self, sgraph, item, new_items):
        """
        Add new items to graph.

        Parameters
        ----------
        sgraph : :any:`SGraph`
            Directed graph or rather copy of it to
            be modified.
        item : :any:`Item`
            Node to which add the new items.
        new_items : tuple
            Tuple of :any:`Item` to add to graph.
        """
        sgraph.add_nodes(new_items)
        sgraph.add_edges((item, _item) for _item in new_items)

    def _rename_calls(self, new_item, new_dependencies):
        """
        Rename calls and imports according to the newly created
        duplicated items.

        Parameters
        ----------
        new_item : :any:`Item`
            The newly created item for which calls and imports
            to be renamed.
        new_dependencies : dict
            Dictionary used to get information about how
            to rename calls and imports.
        """
        call_map = {}
        for call in FindNodes(ir.CallStatement).visit(new_item.ir.body):
            call_name = str(call.name).lower()
            new_call_name = f'{call_name}{self.suffix}'.lower()
            if new_call_name in new_dependencies:
                call_new_item = new_dependencies[new_call_name]
                proc_symbol = call_new_item.ir.procedure_symbol.rescope(scope=new_item.ir)
                call_map[call] = call.clone(name=proc_symbol)
        # TODO: imports at module level ...
        imp_map = {}
        for imp in FindNodes(ir.Import).visit(new_item.ir.spec):
            # potentially new symbols
            symbol_map = {symbol: symbol.clone(name=f'{symbol.name}{self.suffix}') for symbol in imp.symbols}
            new_symbols = ()
            orig_symbols = ()
            # distinguish imported symbols that should remain and those which should be altered
            for orig_symbol, new_symbol in symbol_map.items():
                if new_symbol in new_dependencies:
                    new_symbols += (new_symbol,)
                else:
                    orig_symbols += (orig_symbol,)
            new_imports = ()
            if new_symbols:
                new_imports += (imp.clone(module=f'{imp.module.lower()}{self.module_suffix}',
                                         symbols=as_tuple(new_symbols)),)
            if orig_symbols:
                new_imports += (imp.clone(symbols=as_tuple(orig_symbols)),)
            if new_imports:
                imp_map[imp] = new_imports
        if call_map:
            new_item.ir.body = Transformer(call_map).visit(new_item.ir.body)
        if imp_map:
            new_item.ir.spec = Transformer(imp_map).visit(new_item.ir.spec)

    def _create_duplicate_items(self, successors, item_factory, config, item, sub_sgraph,
            rename_calls=False, force_duplicate=False, ignore=None):
        """
        Create new/duplicated items.

        Parameters
        ----------
        successors : tuple
            Tuple of :any:`Item`s representing the successor items for which
            new/duplicated items are created.
        item_factory : :any:`ItemFactory`
            The :any:`ItemFactory` to use when creating the items.
        config : :any:`SchedulerConfig`
            The scheduler config to use when instantiating new items.
        item : :any:`Item`
            Starting point/source item from which the successors
            originate from
        sub_sgraph : :any:`SGraph`
            Sgraph (copy) representing the subgraph of the directed
            overall graph.
        rename_calls : bool, optional
            Rename calls/imports in accordance to the duplicated
            kernels.
        force_duplicate : bool, optional
            Check whether successor is within ``duplicate_kernels``
            or duplicate either way.
        ignore : str|tuple|list, optional
            Local names of successors that are never duplicated.
        Returns
        -------
        tuple
            Tuple of newly created items.
        """
        ignore = as_tuple(ignore)
        new_items = ()
        for child in successors:
            if child.local_name in self.duplicate_kernels or force_duplicate:
                if child.local_name in ignore:
                    continue
                # get/create/rename item
                new_item = self._get_or_create_or_rename_item(child, item_factory, config)
                new_items += as_tuple(new_item)
                # duplicate subgraph?
                if self.duplicate_subgraph:
                    new_item = as_tuple(new_item)[0]
                    # add new_items to sgraph (copy)
                    self._modify_sgraph(sub_sgraph, item, new_items)
                    # get the successors
                    child_ignore = ignore + as_tuple(child.ignore)
                    child_successors = as_tuple([successor for successor in sub_sgraph.successors(child)
                        if successor.local_name not in child_ignore])
                    if child_successors:
                        # create new/duplicated successors
                        new_dependencies = self._create_duplicate_items(child_successors, item_factory, config,
                                sub_sgraph=sub_sgraph, item=new_item, rename_calls=rename_calls, force_duplicate=True,
                                ignore=ignore)
                        new_dependencies_dic = CaseInsensitiveDict((new_item.local_name, new_item)
                                for new_item in new_dependencies)
                        # add dependencies to new/duplicated successors and remove the "old" ones
                        new_item.plan_data.setdefault('additional_dependencies', ())
                        new_item.plan_data.setdefault('removed_dependencies', ())
                        new_item.plan_data['additional_dependencies'] += as_tuple(new_dependencies)
                        new_item.plan_data['removed_dependencies'] += as_tuple(child_successors)
                        sub_sgraph._add_children(new_item, item_factory, config, dependencies=new_dependencies)
                        # rename calls and imports
                        if rename_calls:
                            self._rename_calls(new_item, new_dependencies_dic)
        return tuple(new_items)

    def transform_subroutine(self, routine, **kwargs):
        # Create new dependency items
        item = kwargs.get('item')
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = sub_sgraph.successors(item) if sub_sgraph is not None else ()
        ignore = tuple(str(t).lower() for t in as_tuple(kwargs.get('ignore', None)))
        new_dependencies = self._create_duplicate_items(
            successors=successors,
            item_factory=kwargs.get('item_factory'),
            config=kwargs.get('scheduler_config'),
            item=item,
            sub_sgraph=sub_sgraph,
            rename_calls=True, ignore=ignore
        )
        new_dependencies = CaseInsensitiveDict((new_item.local_name, new_item) for new_item in new_dependencies)

        # Duplicate calls to kernels
        call_map = {}
        new_imports = []
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            call_name = str(call.name).lower()
            if call_name in self.duplicate_kernels:
                new_call_name = f'{call_name}{self.suffix}'.lower()
                if new_call_name not in new_dependencies:
                    # No duplicated item exists for this callee (e.g. it is listed
                    # in ``ignore`` or is not a successor in the graph), so there is
                    # nothing to call; leave the original call untouched
                    continue
                # Duplicate the call
                new_item = new_dependencies[new_call_name]
                proc_symbol = new_item.ir.procedure_symbol.rescope(scope=routine)
                call_map[call] = (call, call.clone(name=proc_symbol))

                # Register the module import
                if new_item.scope_name:
                    new_imports += [ir.Import(module=new_item.scope_name, symbols=(proc_symbol,))]

        if call_map:
            routine.body = Transformer(call_map).visit(routine.body)
            if new_imports:
                routine.spec.prepend(as_tuple(new_imports))

    def plan_subroutine(self, routine, **kwargs):
        item = kwargs.get('item')
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = sub_sgraph.successors(item) if sub_sgraph is not None else ()
        ignore = tuple(str(t).lower() for t in as_tuple(kwargs.get('ignore', None)))
        item.plan_data.setdefault('additional_dependencies', ())
        item.plan_data['additional_dependencies'] += self._create_duplicate_items(
            successors=successors,
            item_factory=kwargs.get('item_factory'),
            config=kwargs.get('scheduler_config'),
            item=item,
            sub_sgraph=sub_sgraph,
            ignore=ignore
        )


class RemoveKernel(Transformation):
    """
    Remove subroutines which includes the removal of the relevant :any:`Item`s
    as well as the removal of the corresponding dependencies.

    Calls to the listed kernels are deleted from the routine body. The planning
    step records the matching successor items as removed dependencies, so the
    transformation also supports dry-run pipeline planning runs.

    Parameters
    ----------
    remove_kernels : str|tuple|list, optional
        Kernel name(s) to be removed (compared in lower case).
    """

    creates_items = True

    def __init__(self, remove_kernels=None):
        self.remove_kernels = tuple(name.lower() for name in as_tuple(remove_kernels))

    def transform_subroutine(self, routine, **kwargs):
        # Mapping a call to ``None`` makes the Transformer drop it from the body
        removed_calls = {}
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if str(call.name).lower() in self.remove_kernels:
                removed_calls[call] = None
        routine.body = Transformer(removed_calls).visit(routine.body)

    def plan_subroutine(self, routine, **kwargs):
        item = kwargs.get('item')
        sub_sgraph = kwargs.get('sub_sgraph', None)
        if sub_sgraph is None:
            successors = ()
        else:
            successors = sub_sgraph.successors(item)

        item.plan_data.setdefault('removed_dependencies', ())
        dropped = tuple(
            child for child in successors if child.local_name in self.remove_kernels
        )
        item.plan_data['removed_dependencies'] += dropped
loki-ecmwf-0.3.6/loki/transformations/block_index_transformations.py0000664000175000017500000010162015167130205026274 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch import Transformation, ProcedureItem
from loki.ir import (
    nodes as ir, FindNodes, Transformer, pragmas_attached,
    pragma_regions_attached, FindVariables, SubstituteExpressions,
    AttachScopes
)
from loki.logging import warning
from loki.module import Module
from loki.tools import as_tuple
from loki.types import SymbolAttributes, BasicType, DerivedType
from loki.expression import (
    symbols as sym, Variable, Array, RangeIndex
)
from loki.transformations.array_indexing import resolve_vector_dimension
from loki.transformations.sanitise import do_resolve_associates
from loki.transformations.utilities import (
    find_driver_loops, recursive_expression_map_update, get_integer_variable,
    check_routine_sequential
)

__all__ = ['BlockViewToFieldViewTransformation', 'InjectBlockIndexTransformation',
        'LowerBlockIndexTransformation', 'LowerBlockLoopTransformation']

class BlockViewToFieldViewTransformation(Transformation):
    """
    A very IFS-specific transformation to replace per-block, i.e. per OpenMP-thread, view pointers with per-field
    view pointers. It should be noted that this transformation only replaces the view pointers but does not actually
    insert the block index into the promoted view pointers. Therefore this transformation must always be followed by
    the :any:`InjectBlockIndexTransformation`.

    For example, the following code:

    .. code-block:: fortran

        do jlon=1,nproma
          mystruct%p(jlon,:) = 0.
        enddo

    is transformed to:

    .. code-block:: fortran

        do jlon=1,nproma
          mystruct%p_field(jlon,:) = 0.
        enddo

    As the rank of ``mystruct%p_field`` is one greater than that of ``mystruct%p``, we would need to also apply
    the :any:`InjectBlockIndexTransformation` to obtain semantically correct code:

    .. code-block:: fortran

        do jlon=1,nproma
          mystruct%p_field(jlon,:,ibl) = 0.
        enddo

    Specific variables in individual routines can also be marked for exclusion from this transformation by assigning
    them to the `exclude_var_names` list in the :any:`SchedulerConfig`.

    This transformation also creates minimal definitions of FIELD API wrappers (i.e. FIELD_RANKSUFF_ARRAY) and
    uses them to enrich the :any:`DataType` of relevant variable declarations and expression nodes. This is
    required because FIELD API can be built independently of library targets Loki would typically operate on.

    Parameters
    ----------
    horizontal : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the horizontal data dimension and iteration space.
    global_gfl_ptr: bool, optional
        Toggle whether thread-local gfl_ptr should be replaced with global (default: `False`).
    """

    _key = 'BlockViewToFieldViewTransformation'
    """Identifier for trafo_data entry"""

    item_filter = (ProcedureItem,)

    def __init__(self, horizontal, global_gfl_ptr=False):
        self.horizontal = horizontal
        self.global_gfl_ptr = global_gfl_ptr

    def transform_subroutine(self, routine, **kwargs):
        """
        Dispatch ``routine`` to :meth:`process_kernel` or :meth:`process_driver` according to its role.

        Raises
        ------
        RuntimeError
            If no ``item`` is provided, because the FIELD API definitions are stored on it.
        """

        if not (item := kwargs.get('item', None)):
            raise RuntimeError('Cannot apply BlockViewToFieldViewTransformation without item to store definitions')
        sub_sgraph = kwargs.get('sub_sgraph', None)
        successors = sub_sgraph.successors(item) if sub_sgraph is not None else ()

        role = kwargs['role']
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))

        exclude_var_names = item.config.get('exclude_var_names', [])

        if role == 'kernel':
            self.process_kernel(routine, item, successors, targets, exclude_var_names)
        if role == 'driver':
            self.process_driver(routine, item, successors, targets, exclude_var_names)

    @staticmethod
    def _get_parkind_suffix(_type):
        """Return the kind suffix of a wrapper type name, e.g. ``'FIELD_2RB_ARRAY'`` -> ``'RB'``."""
        return _type.rsplit('_')[1][1:3]

    def _build_parkind_import(self, field_array_module, wrapper_types):
        """
        Build the ``PARKIND1`` import for the kind parameters used by ``wrapper_types``.

        Kind symbols are de-duplicated and emitted in sorted order so that the generated
        import does not depend on (hash-randomised) set iteration order.
        """

        deferred_type = SymbolAttributes(BasicType.DEFERRED, imported=True)
        suffixes = sorted({self._get_parkind_suffix(t) for t in wrapper_types})
        _vars = tuple(Variable(name='JP' + suff, type=deferred_type, scope=field_array_module)
                      for suff in suffixes)

        return ir.Import(module='PARKIND1', symbols=_vars)

    def _build_field_array_types(self, field_array_module, wrapper_types):
        """
        Build FIELD_RANKSUFF_ARRAY type-definitions.

        Each typedef declares a view pointer ``P`` (rank - 1), a contiguous pointer
        ``P_FIELD`` (full rank) and a polymorphic field-object pointer ``F_P``.
        """

        typedefs = ()
        for _type in wrapper_types:
            suff = self._get_parkind_suffix(_type)
            kind = field_array_module.symbol_map['JP' + suff]
            # the leading digit of the middle name component is the rank, e.g. FIELD_2RB_ARRAY -> 2
            rank = int(_type.rsplit('_')[1][0])

            view_shape = (RangeIndex(children=(None, None)),) * (rank - 1)
            array_shape = (RangeIndex(children=(None, None)),) * rank

            if suff == 'IM':
                basetype = BasicType.INTEGER
            elif suff == 'LM':
                basetype = BasicType.LOGICAL
            else:
                basetype = BasicType.REAL

            pointer_type = SymbolAttributes(basetype, pointer=True, kind=kind, shape=view_shape)
            contig_pointer_type = pointer_type.clone(contiguous=True, shape=array_shape)

            pointer_var = Variable(name='P', type=pointer_type, dimensions=view_shape)
            contig_pointer_var = pointer_var.clone(name='P_FIELD', type=contig_pointer_type, dimensions=array_shape) # pylint: disable=no-member

            decls = (ir.VariableDeclaration(symbols=(pointer_var,)),)
            decls += (ir.VariableDeclaration(symbols=(contig_pointer_var,)),)

            field_object_type = SymbolAttributes(DerivedType(name=_type.lower().replace('_array', '')),
                                                 pointer=True, polymorphic=True)
            field_object = Variable(name='F_P', type=field_object_type)
            decls += (ir.VariableDeclaration(symbols=(field_object,)),)

            typedefs += (ir.TypeDef(name=_type, body=decls, parent=field_array_module),) # pylint: disable=unexpected-keyword-arg

        # attach a scope to the symbols in the newly created typedefs
        for typedef in typedefs:
            typedef.rescope_symbols()

        return typedefs

    def _create_dummy_field_api_defs(self, field_array_mod_imports):
        """
        Create dummy definitions for FIELD_API wrapper-types to enrich typedefs.

        Returns a one-element list holding a synthetic ``FIELD_ARRAY_MODULE`` whose spec
        contains the ``PARKIND1`` import and one typedef per imported wrapper type.
        """

        # sorted for a reproducible typedef order (set iteration order is hash-dependent);
        # the loop name avoids shadowing the module-level ``sym`` alias
        wrapper_types = sorted({s.name for imp in field_array_mod_imports for s in imp.symbols})

        # create dummy module with empty spec
        field_array_module = Module(name='FIELD_ARRAY_MODULE', spec=ir.Section(body=()))

        # build parkind1 import
        parkind_import = self._build_parkind_import(field_array_module, wrapper_types)
        field_array_module.spec.append(parkind_import)

        # build dummy type definitions
        typedefs = self._build_field_array_types(field_array_module, wrapper_types)
        field_array_module.spec.append(typedefs)

        return [field_array_module,]

    @staticmethod
    def propagate_defs_to_children(key, definitions, successors):
        """
        Enrich all successors with the dummy FIELD_API definitions.

        The definitions are also stored in each child's ``trafo_data[key]`` so that the
        child can forward them to its own successors.
        """

        for child in successors:
            child.ir.enrich(definitions)
            child.trafo_data.update({key: {'definitions': definitions}})

    def process_driver(self, routine, item, successors, targets, exclude_var_names):
        """
        Create the FIELD API definitions from the driver's ``FIELD_ARRAY_MODULE`` imports and
        replace view pointers inside the driver loops that call ``targets``.
        """

        # create dummy definitions for field_api wrapper types
        field_array_mod_imports = [imp for imp in routine.imports if imp.module.lower() == 'field_array_module']
        definitions = []
        if field_array_mod_imports:
            definitions += self._create_dummy_field_api_defs(field_array_mod_imports)

        # propagate dummy field_api wrapper definitions to self and children
        routine.enrich(definitions)
        self.propagate_defs_to_children(self._key, definitions, successors)

        with pragmas_attached(routine, ir.Loop):
            for loop in find_driver_loops(routine.body, targets):
                body = self.process_body(loop.body, item, successors, targets, exclude_var_names)
                body_map = {loop.body: body}
                Transformer(body_map, inplace=True).visit(loop)

    def build_ydvars_global_gfl_ptr(self, var):
        """Replace accesses to thread-local ``YDVARS%GFL_PTR`` with global ``YDVARS%GFL_PTR_G``."""

        # rebuild the parent chain first so that nested components hang off the global pointer
        if (parent := var.parent):
            parent = self.build_ydvars_global_gfl_ptr(parent)

        _type = var.type
        if 'gfl_ptr' in var.basename.lower():
            _type = parent.variable_map['gfl_ptr_g'].type

        return var.clone(name=var.name.upper().replace('GFL_PTR', 'GFL_PTR_G'),
                         parent=parent, type=_type)

    def process_body(self, body, item, successors, targets, exclude_var_names):
        """
        Return a copy of ``body`` in which type-bound view pointers are replaced by ``*_FIELD`` pointers.

        Candidates are derived-type array components indexed with the horizontal index, plus
        derived-type components passed to ``targets`` for dummies shaped with a horizontal size.
        Names matching ``exclude_var_names`` are left untouched.
        """

        # build list of type-bound array access using the horizontal index
        _vars = [var for var in FindVariables(unique=False).visit(body)
                if isinstance(var, Array) and var.parents and self.horizontal.index in var.dimensions]

        # build list of type-bound view pointers passed as subroutine arguments
        for call in FindNodes(ir.CallStatement).visit(body):
            if call.name in targets:
                _args = {a: d for d, a in call.arg_map.items() if isinstance(d, Array)}
                _vars += [a for a, d in _args.items()
                          if any(v in d.shape for v in self.horizontal.sizes) and a.parents]

        # filter out variables marked for exclusion
        _vars = [v for v in _vars if not any(e in v for e in exclude_var_names)]

        # replace per-block view pointers with full field pointers
        vmap = {var: var.clone(name=var.name_parts[-1] + '_FIELD',
                               type=var.parent.variable_map[var.name_parts[-1] + '_FIELD'].type)
                for var in _vars}

        # replace thread-private GFL_PTR with global
        if self.global_gfl_ptr:
            # variables that already reference GFL_PTR_G must not be renamed again
            # (that would produce a non-existent GFL_PTR_G_G component)
            gfl_vars = [
                v for v in FindVariables(unique=False).visit(body)
                if 'ydvars%gfl_ptr' in v.name.lower() and 'ydvars%gfl_ptr_g' not in v.name.lower()
            ]
            vmap.update({v: self.build_ydvars_global_gfl_ptr(vmap.get(v, v)) for v in gfl_vars})
            vmap = recursive_expression_map_update(vmap)

        # propagate dummy field_api wrapper definitions to children
        if item.trafo_data.get(self._key, None):
            definitions = item.trafo_data[self._key]['definitions']
            self.propagate_defs_to_children(self._key, definitions, successors)

        # finally we perform the substitution
        return SubstituteExpressions(vmap).visit(body)


    def process_kernel(self, routine, item, successors, targets, exclude_var_names):
        """
        Replace view pointers throughout a kernel body, unless the kernel is marked sequential.
        """

        # Sanitize the subroutine
        do_resolve_associates(routine)

        # Bail if routine is marked as sequential
        if check_routine_sequential(routine):
            return

        resolve_vector_dimension(routine, dimension=self.horizontal)

        # for kernels we process the entire body
        routine.body = self.process_body(routine.body, item, successors, targets, exclude_var_names)


class InjectBlockIndexTransformation(Transformation):
    """
    A transformation pass to inject the block-index in arrays promoted by a previous transformation pass. As such,
    this transformation also relies on the block-index, or a known alias, being *already* present in routines that
    are to be transformed.

    For array access in a :any:`Subroutine` body, it operates by comparing the local shape of an array with its
    declared shape. If the local shape is of rank one less than the declared shape, then the block-index is appended
    to the array's dimensions.

    For :any:`CallStatement` arguments, if the rank of the argument is one less than that of the corresponding
    dummy-argument, the block-index is appended to the argument's dimensions. It should be noted that this logic relies
    on the :any:`CallStatement` being free of any sequence-association.

    For example, the following code:

    .. code-block:: fortran

        subroutine kernel1(nblks, ...)
           ...
           integer, intent(in) :: nblks
           integer :: ibl
           real :: var(jlon,nlev,nblks)

           do ibl=1,nblks
             do jlon=1,nproma
               var(jlon,:) = 0.
             enddo

             call kernel2(var,...)
           enddo
           ...
        end subroutine kernel1

        subroutine kernel2(var, ...)
           ...
           real :: var(jlon,nlev)
        end subroutine kernel2

    is transformed to:

    .. code-block:: fortran

        subroutine kernel1(nblks, ...)
           ...
           integer, intent(in) :: nblks
           integer :: ibl
           real :: var(jlon,nlev,nblks)

           do ibl=1,nblks
             do jlon=1,nproma
               var(jlon,:,ibl) = 0.
             enddo

             call kernel2(var(:,:,ibl),...)
           enddo
           ...
        end subroutine kernel1

        subroutine kernel2(var, ...)
           ...
           real :: var(jlon,nlev)
        end subroutine kernel2

    Specific arrays in individual routines can also be marked for exclusion from this transformation by assigning
    them to the `exclude_arrays` list in the :any:`SchedulerConfig`. Conversely, arrays listed under
    `force_inject_arrays` have the block-index appended regardless of their rank (exclusion still takes
    precedence).

    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the blocking data dimension and iteration space.
    """

    # This trafo only operates on procedures
    item_filter = (ProcedureItem,)

    def __init__(self, block_dim):
        self.block_dim = block_dim

    def transform_subroutine(self, routine, **kwargs):
        """
        Inject the block index into the whole body of kernels, or into the driver loops calling ``targets``.

        Routines without the block index (or one of its aliases) are left unchanged.
        """

        role = kwargs['role']
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))

        exclude_arrays = []
        force_inject_arrays = []
        if (item := kwargs.get('item', None)):
            exclude_arrays = item.config.get('exclude_arrays', [])
            force_inject_arrays = item.config.get('force_inject_arrays', [])

        # we skip routines that do not contain the block index or any known alias
        variable_map = routine.variable_map
        if not (block_index := self.get_block_index(routine, variable_map)):
            return

        # replace strings with variables
        force_inject_arrays = [routine.resolve_typebound_var(var, variable_map) for var in force_inject_arrays]

        if role == 'kernel':
            # for kernels we process the entire subroutine body
            routine.body = self.process_body(routine.body, block_index, targets, exclude_arrays, force_inject_arrays)
        elif role == 'driver':
            with pragmas_attached(routine, ir.Loop):
                for loop in find_driver_loops(routine.body, targets):
                    body = self.process_body(loop.body, block_index, targets, exclude_arrays, force_inject_arrays)
                    body_map = {loop.body: body}
                    Transformer(body_map, inplace=True).visit(loop)

    @staticmethod
    def get_call_arg_rank(arg):
        """
        Utility to retrieve the local rank of a :any:`CallStatement` argument.

        This is the declared rank minus the number of non-range subscripts.
        """

        rank = len(getattr(arg, 'shape', ()))
        if getattr(arg, 'dimensions', None):
            # We assume here that the callstatement is free of sequence association
            rank = rank - len([d for d in arg.dimensions if not isinstance(d, RangeIndex)])

        return rank

    def get_block_index(self, routine, variable_map):
        """
        Utility to retrieve the block-index loop induction variable.

        Tries ``block_dim.index`` first, then the first alias in ``block_dim.indices`` whose
        (parent) name is declared in ``routine``; returns ``None`` if none is found.
        """

        if (block_index := variable_map.get(self.block_dim.index, None)):
            return block_index
        if (block_index := [i for i in self.block_dim.indices
                            if i.split('%', maxsplit=1)[0] in variable_map]):
            return routine.resolve_typebound_var(block_index[0], variable_map)
        return None

    def process_body(self, body, block_index, targets, exclude_arrays, force_inject_arrays):
        """
        Return a copy of ``body`` with ``block_index`` appended to the rank-deficient array accesses.
        """
        # The logic for callstatement args differs from other variables in the body,
        # so we build a list to filter
        call_args = []

        # First get rank mismatched call statement args
        vmap = {}
        for call in FindNodes(ir.CallStatement).visit(body):
            if call.name in targets:
                for dummy, arg in call.arg_map.items():
                    call_args += [arg]
                    arg_rank = self.get_call_arg_rank(arg)
                    dummy_rank = len(getattr(dummy, 'shape', ()))
                    if arg_rank - 1 == dummy_rank:
                        # un-subscripted arguments get explicit ':' ranges before the block index
                        dimensions = getattr(arg, 'dimensions', None) or ((RangeIndex((None, None)),) * (arg_rank - 1))
                        vmap.update({arg: arg.clone(dimensions=dimensions + as_tuple(block_index))})

        # Now get the rest of the variables
        for var in FindVariables(unique=False).visit(body):
            if getattr(var, 'dimensions', None) and var not in call_args:

                local_rank = len(var.dimensions)
                # we assume here that all derived-type components we wish to transform
                # have been parsed
                decl_rank = len(var.shape) if getattr(var, 'shape', None) else local_rank

                if local_rank == decl_rank - 1:
                    # ``var.dimensions`` is non-empty here, so no ':' padding is needed
                    vmap[var] = var.clone(dimensions=var.dimensions + as_tuple(block_index))

        for var in force_inject_arrays:
            dimensions = getattr(var, 'dimensions', None) or ((RangeIndex((None, None)),) * (len(var.shape) - 1))
            vmap.update({var: var.clone(dimensions=dimensions + as_tuple(block_index))})

        # filter out arrays marked for exclusion
        vmap = {k: v for k, v in vmap.items() if not any(e in k for e in exclude_arrays)}

        # finally we perform the substitution
        return SubstituteExpressions(vmap).visit(body)


class LowerBlockIndexTransformation(Transformation):
    """
    Transformation to lower the block index via appending the block index
    to variable dimensions/shape. However, this only handles variable
    declarations/definitions. Therefore this transformation must always be followed by
    the :any:`InjectBlockIndexTransformation`.

    For example, the following code:

    .. code-block:: fortran

        SUBROUTINE driver (nlon, nlev, nb, var)
            INTEGER, INTENT(IN) :: nlon, nlev, nb
            REAL, INTENT(INOUT) :: var(nlon, nlev, nb)
            DO ibl=1,10
                CALL kernel(nlon, nlev, var(:, :, ibl))
            END DO
        END SUBROUTINE driver

        SUBROUTINE kernel (nlon, nlev, var, another_var, icend, lstart, lend)
            INTEGER, INTENT(IN) :: nlon, nlev
            REAL, INTENT(INOUT) :: var(nlon, nlev)
            DO jk=1,nlev
                DO jl=1,nlon
                    var(jl, jk) = 0.
                END DO
            END DO
        END SUBROUTINE kernel

    is transformed to:

    .. code-block:: fortran

        SUBROUTINE driver (nlon, nlev, nb, var)
            INTEGER, INTENT(IN) :: nlon, nlev, nb
            REAL, INTENT(INOUT) :: var(nlon, nlev, nb)
            DO ibl=1,10
                CALL kernel(nlon, nlev, var(:, :, :), ibl=ibl, nb=nb)
            END DO
        END SUBROUTINE driver

        SUBROUTINE kernel (nlon, nlev, var, another_var, icend, lstart, lend, ibl, nb)
            INTEGER, INTENT(IN) :: nlon, nlev, ibl, nb
            REAL, INTENT(INOUT) :: var(nlon, nlev, nb)
            DO jk=1,nlev
                DO jl=1,nlon
                    var(jl, jk) = 0.
                END DO
            END DO
        END SUBROUTINE kernel

    .. warning::

        The block index is not injected by this transformation. To inject the block index
        to the promoted arrays call the :any:`InjectBlockIndexTransformation` after
        this transformation!

    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the blocking data dimension and iteration space.
    recurse_to_kernels : bool, optional
        Recurse/continue with/to (nested) kernels and lower the block index for those
        as well (default: `True`).
    """
    # This trafo only operates on procedures
    item_filter = (ProcedureItem,)

    def __init__(self, block_dim, recurse_to_kernels=True):
        self.block_dim = block_dim
        self.recurse_to_kernels = recurse_to_kernels

    def transform_subroutine(self, routine, **kwargs):

        role = kwargs['role']
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))
        # dispatch driver in any case and recurse to kernels if corresponding flag is set
        if role == 'driver' or (self.recurse_to_kernels and role == 'kernel'):
            self.process(routine, targets, role)

    def _expose_block_dimension(self, arg):
        """
        Return ``arg`` with every subscript equal to the block index replaced by ``:``.

        Non-array expressions are returned unchanged.
        """
        if not isinstance(arg, sym.Array):
            return arg
        block_index = self.block_dim.index.lower()
        new_dim = tuple(
            sym.RangeIndex((None, None)) if str(dim).lower() == block_index else dim
            for dim in arg.dimensions
        )
        return arg.clone(dimensions=new_dim)

    def process(self, routine, targets, role):
        """
        This method adds the blocking index and size as arguments (if not already
        there yet) for calls and corresponding callees and updates the dimension
        and shape of arguments for calls and callees.

        .. note::
            The injection of the block index for promoted arrays is not part of this
            method (and also not part of this transformation!)
        """
        processed_routines = ()
        variable_map = routine.variable_map
        block_dim_index = get_integer_variable(routine, self.block_dim.index)
        block_dim_size = get_integer_variable(routine, self.block_dim.size)
        for call in FindNodes(ir.CallStatement).visit(routine.body):
            if str(call.name).lower() not in targets:
                continue
            if call.routine is BasicType.DEFERRED:
                warning('[LowerBlockIndexTransformation] Not processing routine ' \
                        f'{call.name}. Call statement not enriched')
                continue
            # caller-side argument -> callee dummy argument
            call_arg_map = dict((v,k) for k,v in call.arg_map.items())
            # callee-side block size; falls back to the caller's variable, which is
            # passed under its own name below if not already an argument
            call_block_dim_size = call_arg_map.get(block_dim_size, block_dim_size)
            new_args = tuple(var for var in [block_dim_index, block_dim_size] if var not in call_arg_map)
            if new_args:
                call._update(kwarguments=call.kwarguments+tuple((new_arg.name, new_arg) for new_arg in new_args))
                # extend the callee signature only once, even if it is called several times
                if call.routine.name not in processed_routines:
                    call.routine.arguments += tuple((variable_map[new_arg.name].clone(scope=call.routine,
                        type=new_arg.type.clone(intent='in')) for new_arg in new_args))
            # update dimensions and shape: callee array dummies whose actual argument has a
            # higher rank get the block size appended to their declared dimensions and shape
            var_map = {}
            call_variable_map = call.routine.variable_map
            for arg, call_arg in call.arg_iter():
                if isinstance(arg, Array) and len(call_arg.shape) > len(arg.shape):
                    call_routine_var = call_variable_map[arg.name]
                    new_dims = call_routine_var.dimensions + (call_variable_map[call_block_dim_size.name],)
                    new_shape = call_routine_var.shape + (call_variable_map[call_block_dim_size.name],)
                    new_type = call_routine_var.type.clone(shape=new_shape)
                    var_map[call_routine_var] = call_routine_var.clone(dimensions=new_dims, type=new_type)
            call.routine.spec = SubstituteExpressions(var_map).visit(call.routine.spec)
            if role == 'driver':
                # the driver passes whole arrays; the block index is re-applied inside the callee
                _arguments = tuple(self._expose_block_dimension(arg) for arg in call.arguments)
                _kwarguments = tuple(
                    (kwarg_name, self._expose_block_dimension(kwarg)) for kwarg_name, kwarg in call.kwarguments
                )
                call._update(arguments=_arguments, kwarguments=_kwarguments)
            processed_routines += (call.routine.name,)


class LowerBlockLoopTransformation(Transformation):
    """
    Lower the block loop to calls within this loop.

    For example, the following code:

    .. code-block:: fortran

        subroutine driver(nblks, ...)
            ...
            integer, intent(in) :: nblks
            integer :: ibl
            real :: var(jlon,nlev,nblks)

            do ibl=1,nblks
                call kernel(var,...,nblks,ibl)
            enddo
            ...
        end subroutine driver

        subroutine kernel(var, ..., nblks, ibl)
            ...
            real :: var(jlon,nlev,nblks)

            do jl=1,...
                do jk=1,...
                    var(jk,jl,ibl) = ...
                end do
            end do
        end subroutine kernel

    is transformed to:

    .. code-block:: fortran

        subroutine driver(nblks, ...)
            ...
            integer, intent(in) :: nblks
            integer :: ibl
            real :: var(jlon,nlev,nblks)

            call kernel(var,..., nblks)
            ...
        end subroutine driver

        subroutine kernel(var, ..., nblks)
            ...
            integer :: ibl
            real :: var(jlon,nlev,nblks)

            do ibl=1,nblks
                do jl=1,...
                    do jk=1,...
                        var(jk,jl,ibl) = ...
                    end do
                end do
            end do
        end subroutine kernel

    Parameters
    ----------
    block_dim : :any:`Dimension`
        :any:`Dimension` object describing the variable conventions used in code
        to define the blocking data dimension and iteration space.
    """
    # This trafo only operates on procedures
    item_filter = (ProcedureItem,)

    def __init__(self, block_dim):
        self.block_dim = block_dim

    def transform_subroutine(self, routine, **kwargs):
        role = kwargs['role']
        targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))
        if role == 'driver':
            self.process_driver(routine, targets)

    @staticmethod
    def arg_to_local_var(routine, var):
        """Remove ``var`` from the dummy arguments of ``routine`` and declare it as a local without intent."""
        new_args = tuple(arg for arg in routine.arguments if arg.name.lower() != var.name.lower())
        routine.arguments = new_args
        routine.variables += (routine.variable_map[var.name].clone(scope=routine,
            type=routine.variable_map[var.name].type.clone(intent=None)),)

    def local_var(self, call, variables):
        """
        Make each of ``variables`` a local of the called routine.

        Variables passed as call arguments are turned from dummies into locals; others are
        declared in the callee unless it already has them. Repeated entries (e.g. a scalar
        assigned twice in the loop body) are handled only once to avoid duplicate declarations.
        """
        inv_call_arg_map = {v: k for k, v in call.arg_map.items()}
        call_routine_variables = call.routine.variables
        # dict.fromkeys de-duplicates while preserving first-occurrence order
        for var in dict.fromkeys(variables):
            if var in inv_call_arg_map:
                self.arg_to_local_var(call.routine, inv_call_arg_map[var])
            elif var not in call_routine_variables:
                call.routine.variables += (var.clone(scope=call.routine),)

    @staticmethod
    def generate_pragma(loop):
        return ir.Pragma(keyword="loki", content=f"removed_loop var({loop.variable}) \
                    lower({loop.bounds.lower}) upper({loop.bounds.upper}) \
                    step({loop.bounds.step if loop.bounds.step else 1})")

    def update_call_signature(self, call, loop, loop_defined_symbols, additional_kwargs):
        """
        Drop the loop variable and loop-defined symbols from the call's arguments, append
        ``additional_kwargs`` and attach a ``removed_loop`` pragma describing ``loop``.
        """
        ignore_symbols = {loop.variable.name.lower()}
        ignore_symbols |= {symbol.name.lower() for symbol in loop_defined_symbols}

        def _is_ignored(expr):
            # literals and other unnamed expressions are never dropped
            return hasattr(expr, 'name') and expr.name.lower() in ignore_symbols

        _arguments = tuple(arg for arg in call.arguments if not _is_ignored(arg))
        _kwarguments = tuple(kwarg for kwarg in call.kwarguments if not _is_ignored(kwarg[1]))
        _kwarguments += as_tuple(additional_kwargs.items())
        call_pragmas = (self.generate_pragma(loop),)
        call._update(arguments=_arguments, kwarguments=_kwarguments,
                pragma=(call.pragma if call.pragma else ()) + call_pragmas)

    def process_driver(self, routine, targets):
        """
        Move each block loop of ``routine`` that calls enriched ``targets`` into the called routines.

        The loop itself is replaced by its body in the driver; each callee receives a copy of the
        loop (with other calls removed and its own body inlined in place of the call).
        """
        # find block loops
        with pragma_regions_attached(routine):
            with pragmas_attached(routine, ir.Loop):
                loops = FindNodes(ir.Loop).visit(routine.body)
                loops = [loop for loop in loops if loop.variable in self.block_dim.indices]

                # Remove parallel regions around block loops (keep only the region bodies)
                pragma_region_map = {}
                for pragma_region in FindNodes(ir.PragmaRegion).visit(routine.body):
                    for loop in loops:
                        if loop in pragma_region.body:
                            pragma_region_map[pragma_region] = pragma_region.body
                routine.body = Transformer(pragma_region_map, inplace=True).visit(routine.body)

                driver_loop_map = {}
                processed_routines = ()
                additional_kwargs = {}
                for loop in loops:
                    target_calls = [call for call in FindNodes(ir.CallStatement).visit(loop.body)
                            if str(call.name).lower() in targets]
                    target_calls = [call for call in target_calls if call.routine is not BasicType.DEFERRED]
                    if not target_calls:
                        continue
                    # the driver keeps only the loop body; calls inside it are updated in place below
                    driver_loop_map[loop] = loop.body
                    defined_symbols_loop = [assign.lhs for assign in FindNodes(ir.Assignment).visit(loop.body)]
                    for call in target_calls:
                        if call.routine.name in processed_routines:
                            self.update_call_signature(call, loop, defined_symbols_loop,
                                    additional_kwargs[call.routine.name])
                            continue
                        # 1. Create a copy of the loop with all other call statements removed
                        other_calls = {c: None for c in FindNodes(ir.CallStatement).visit(loop) if c is not call}
                        loop_to_lower = Transformer(other_calls).visit(loop)

                        # 2. Replace all variables according to the caller-callee argument map
                        call_arg_map = dict((v, k) for k, v in call.arg_map.items())
                        loop_to_lower = SubstituteExpressions(call_arg_map).visit(loop_to_lower)

                        # 3. Identify local variables that need to be provided as additional arguments to the call
                        call_routine_variables = {v.name.lower() for v in FindVariables().visit(call.routine.body)}
                        call_routine_variables |= {v.name.lower() for v in call.routine.variables}
                        loop_variables = [
                            v for v in FindVariables().visit(loop_to_lower.body)
                            if v.name.lower() != loop.variable and v.name.lower() not in call_routine_variables
                            and v not in call_arg_map and isinstance(v, sym.Scalar) and v not in defined_symbols_loop
                        ]
                        additional_kwargs[call.routine.name] = {var.name: var for var in loop_variables}

                        # 4. Inject the loop body into the called routine
                        call.routine.arguments += tuple(additional_kwargs[call.routine.name].values())
                        routine_body = Transformer({c: c.routine.body for c in\
                            FindNodes(ir.CallStatement).visit(loop_to_lower)}).visit(loop_to_lower)
                        routine_body = AttachScopes().visit(routine_body, scope=call.routine)
                        call.routine.body = ir.Section(body=as_tuple(routine_body))

                        # 5. Update the call on the caller side
                        processed_routines += (call.routine.name,)
                        self.local_var(call, defined_symbols_loop + [loop.variable])
                        self.update_call_signature(call, loop, defined_symbols_loop,
                                additional_kwargs[call.routine.name])
                routine.body = Transformer(driver_loop_map).visit(routine.body)
loki-ecmwf-0.3.6/loki/batch/0000775000175000017500000000000015167130205015760 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/batch/__init__.py0000664000175000017500000000214515167130205020073 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Batch processing abstraction for processing large source trees with Loki.

This sub-package provides the :any:`Scheduler` class that allows Loki
transformations to be applied over large source trees. For this it
provides the basic :any:`Transformation` and :any:`Pipeline` classes
that provide the core interfaces for batch processing, as well as the
configuration utilities for large call tree traversals.
"""

from loki.batch.configure import * # noqa
from loki.batch.item import * # noqa
from loki.batch.item_factory import * # noqa
from loki.batch.pipeline import * # noqa
from loki.batch.scheduler import * # noqa
from loki.batch.sfilter import * # noqa
from loki.batch.sgraph import * # noqa
from loki.batch.transformation import * # noqa
loki-ecmwf-0.3.6/loki/batch/tests/0000775000175000017500000000000015167130205017122 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/batch/tests/test_batch.py0000664000175000017500000020763715167130205021633 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# pylint: disable=too-many-lines

from collections import deque
from pathlib import Path
import re
import networkx as nx
import pytest

from loki import (
    Sourcefile, Subroutine, as_tuple, RawSource, TypeDef,
    Scalar, ProcedureSymbol
)
from loki.batch import (
    FileItem, ModuleItem, ProcedureItem, TypeDefItem,
    ProcedureBindingItem, ExternalItem, InterfaceItem, SGraph,
    SchedulerConfig, ItemFactory
)
from loki.frontend import HAVE_FP, REGEX, RegexParserClass
from loki.ir import nodes as ir


# Skip every test in this module when the fparser frontend is not available
pytestmark = pytest.mark.skipif(not HAVE_FP, reason='Fparser not available')


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test file."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """The ``tests`` directory located two levels above ``here``."""
    package_root = here.parent.parent
    return package_root / 'tests'


@pytest.fixture(name='default_config', scope='function')
def fixture_default_config():
    """
    Default SchedulerConfig configuration with basic options.

    A new dictionary is built on every call, so a test may modify it
    (e.g. add ``enable_imports`` or change ``strict``) without affecting other tests.
    """
    defaults = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': True,
        'disable': ['abort'],
    }
    return {'default': defaults, 'routines': []}


@pytest.fixture(name='comp1_expected_dependencies')
def fixture_comp1_expected_dependencies():
    """
    Expected dependency map for the ``#comp1`` test case.

    Maps each item name to the tuple of item names it depends on.
    """
    deps = {}
    deps['#comp1'] = ('header_mod', 't_mod', 't_mod#t', '#comp2', 't_mod#t%proc', 't_mod#t%no%way')
    deps['#comp2'] = ('header_mod', 't_mod#t', 'a_mod#a', 'b_mod#b', 't_mod#t%yay%proc')
    deps['a_mod#a'] = ('header_mod',)
    deps['b_mod#b'] = ()
    deps['t_mod'] = ('tt_mod#tt', 'tt_mod#intf')
    deps['t_mod#t'] = ('tt_mod#tt', 't_mod#t1')
    deps['t_mod#t1'] = ()
    deps['t_mod#t%proc'] = ('t_mod#t_proc',)
    deps['t_mod#t_proc'] = ('t_mod#t', 'a_mod#a', 't_mod#t%yay%proc')
    deps['t_mod#t%no%way'] = ('t_mod#t1%way',)
    deps['t_mod#t%yay%proc'] = ('tt_mod#tt%proc',)
    deps['t_mod#t1%way'] = ('t_mod#my_way',)
    deps['t_mod#my_way'] = ('t_mod#t1',)
    deps['tt_mod#tt'] = ()
    deps['tt_mod#tt%proc'] = ('tt_mod#proc',)
    deps['tt_mod#proc'] = ('tt_mod#tt',)
    deps['tt_mod#intf'] = ('tt_mod#proc',)
    deps['header_mod'] = ()
    return deps


@pytest.fixture(name='mod_proc_expected_dependencies')
def fixture_mod_proc_expected_dependencies():
    """
    Expected dependency map for the ``other_mod#mod_proc`` test case.

    Maps each item name to the tuple of item names it depends on.
    """
    deps = {}
    deps['other_mod#mod_proc'] = ('tt_mod#tt', 'tt_mod#tt%proc', 'b_mod#b')
    deps['tt_mod#tt'] = ()
    deps['tt_mod#tt%proc'] = ('tt_mod#proc',)
    deps['tt_mod#proc'] = ('tt_mod#tt',)
    deps['b_mod#b'] = ()
    return deps


@pytest.fixture(name='expected_dependencies')
def fixture_expected_dependencies(comp1_expected_dependencies, mod_proc_expected_dependencies):
    """
    Union of the ``#comp1`` and ``other_mod#mod_proc`` dependency maps.

    Where a key occurs in both, the entry from ``mod_proc_expected_dependencies`` wins.
    """
    return {**comp1_expected_dependencies, **mod_proc_expected_dependencies}


@pytest.fixture(name='no_expected_dependencies')
def fixture_no_expected_dependencies():
    """An empty dependency map."""
    empty_map = {}
    return empty_map


@pytest.fixture(name='file_dependencies')
def fixture_file_dependencies():
    """
    Expected file-level dependency map.

    Maps each relative source file path to the tuple of file paths it depends on.
    """
    comp1, comp2 = 'source/comp1.F90', 'source/comp2.f90'
    t_mod, tt_mod = 'module/t_mod.F90', 'module/tt_mod.F90'
    a_mod, b_mod = 'module/a_mod.F90', 'module/b_mod.F90'
    header_mod = 'headers/header_mod.F90'
    return {
        comp1: (comp2, t_mod, header_mod),
        comp2: (t_mod, header_mod, a_mod, b_mod),
        t_mod: (tt_mod, a_mod),
        tt_mod: (),
        a_mod: (header_mod,),
        b_mod: (),
        header_mod: (),
    }


class VisGraphWrapper:
    """
    Testing utility to parse the generated callgraph visualisation.

    The file at ``path`` is read once at construction. Node and edge names are
    then extracted from its text with regular expressions.
    """

    # A node declaration: an optionally quoted name followed by ``[colo...``.
    # The single named group makes ``findall`` return plain name strings.
    _re_nodes = re.compile(r'\s*\"?(?P<node>[\w%#./]+)\"? \[colo', re.IGNORECASE)
    # An edge: ``parent -> child`` with optionally quoted names.
    # The two named groups make ``findall`` return ``(parent, child)`` tuples.
    _re_edges = re.compile(r'\s*\"?(?P<parent>[\w%#./]+)\"? -> \"?(?P<child>[\w%#./]+)\"?', re.IGNORECASE)

    def __init__(self, path):
        self.text = Path(path).read_text()

    @property
    def nodes(self):
        """List of node names declared in the graph text, in order of appearance."""
        return list(self._re_nodes.findall(self.text))

    @property
    def edges(self):
        """List of ``(parent, child)`` name tuples, in order of appearance."""
        return list(self._re_edges.findall(self.text))


def get_item(cls, path, name, parser_classes, scheduler_config=None):
    """
    Create an item of type ``cls`` called ``name`` for the file at ``path``.

    The file is read with the REGEX frontend, restricted to ``parser_classes``.
    If ``scheduler_config`` is given, it provides the item's config through
    ``create_item_config(name)``. Otherwise the item is created with ``config=None``.
    """
    sourcefile = Sourcefile.from_file(path, frontend=REGEX, parser_classes=parser_classes)
    item_config = scheduler_config.create_item_config(name) if scheduler_config else None
    return cls(name, source=sourcefile, config=item_config)


def test_file_item1(testdir, default_config):
    """
    Check a FileItem for a file holding a single module with one subroutine.

    Covers item naming, comparison and string representation, and lazy parsing
    (definitions are only parsed when queried). It also checks that definition
    items become ExternalItems when the FileItem is not in the item cache.
    """
    proj = testdir/'sources/projBatch'

    # A file with simple module that contains a single subroutine
    item = get_item(FileItem, proj/'module/a_mod.F90', 'module/a_mod.F90', RegexParserClass.EmptyClass)
    assert item.name == 'module/a_mod.F90'
    assert item.local_name == item.name
    assert item.scope_name is None
    assert not item.scope
    assert item.ir is item.source
    # The string representation is '<qualified class name><<item name>>',
    # matching the format checked for ProcedureItem in test_procedure_item1
    assert str(item) == 'loki.batch.FileItem<module/a_mod.F90>'

    # A few checks on the item comparison
    assert item == 'module/a_mod.F90'
    assert item != FileItem('some_name', source=item.source)
    assert item == FileItem(item.name, source=item.source)

    # The file is not parsed at all
    assert not item.source.definitions
    assert isinstance(item.source.ir, ir.Section)
    assert len(item.source.ir.body) == 1
    assert isinstance(item.source.ir.body[0], RawSource)

    # Querying definitions triggers a round of parsing
    assert item.definitions == (item.source['a_mod'],)
    assert len(item.source.definitions) == 1

    # Without the FileItem in the item_cache, the modules will be created as ExternalItem
    assert all(
        isinstance(_item, ExternalItem) and _item.origin_cls is ModuleItem
        for _item in item.create_definition_items(
            item_factory=ItemFactory(), config=SchedulerConfig.from_dict(default_config)
        )
    )

    # Check that external item raises an exception whenever we try to access any IR nodes
    external_item = item.create_definition_items(
        item_factory=ItemFactory(), config=SchedulerConfig.from_dict(default_config)
    )[0]

    for attr in ('ir', 'scope', 'path'):
        with pytest.raises(RuntimeError):
            getattr(external_item, attr)

    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    items = item.create_definition_items(item_factory=item_factory)
    assert len(items) == 1
    assert items[0] != None  # pylint: disable=singleton-comparison  # (intentionally trigger __eq__ here)
    assert items[0].name == 'a_mod'
    assert items[0].definitions == (item.source['a'],)

    # The default behavior would be to have the ProgramUnits parsed already
    item = get_item(FileItem, proj/'module/a_mod.F90', 'module/a_mod.F90', RegexParserClass.ProgramUnitClass)
    assert item.name == 'module/a_mod.F90'
    assert item.definitions == (item.source['a_mod'],)
    assert item.ir is item.source
    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    items = item.create_definition_items(item_factory=item_factory)
    assert len(items) == 1
    assert items[0].name == 'a_mod'
    assert items[0].definitions == (item.source['a'],)


def test_file_item2(testdir):
    """
    Check a FileItem for a module with typedefs, parsed for program units only.

    The typedefs are not available on the module IR until ``definitions`` is
    queried on the module item, which completes the parse. Files themselves
    have no direct dependencies.
    """
    proj = testdir/'sources/projBatch'

    # A file with a simple module that contains a single typedef
    item = get_item(FileItem, proj/'module/t_mod.F90', 'module/t_mod.F90', RegexParserClass.ProgramUnitClass)
    assert item.name == 'module/t_mod.F90'
    assert item.definitions == (item.source['t_mod'],)

    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    items = item.create_definition_items(item_factory=item_factory)
    assert len(items) == 1
    assert items[0].name == 't_mod'
    assert items[0].ir is item.source['t_mod']
    # No typedefs because not selected in parser classes
    assert not items[0].ir.typedefs
    # Calling definitions automatically further completes the source
    assert items[0].definitions == (
        items[0].ir['t_proc'],
        items[0].ir['my_way'],
        items[0].ir.typedef_map['t1'],
        items[0].ir.typedef_map['t'],
    )

    # Files don't have dependencies (direct dependencies, anyway)
    assert isinstance(item.dependencies, tuple) and not item.dependencies


def test_file_item3(testdir):
    """
    Check the same module file as ``test_file_item2``, with typedefs parsed from the start.

    The module's typedefs are available immediately. The test also checks that
    ``create_definition_items(only=...)`` filters the created items by item class.
    """
    proj = testdir/'sources/projBatch'

    # The same file but with typedefs parsed from the get-go
    item = get_item(
        FileItem, proj/'module/t_mod.F90', 'module/t_mod.F90',
        RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass
    )
    assert item.name == 'module/t_mod.F90'
    assert item.definitions == (item.source['t_mod'],)

    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    items = item.create_definition_items(item_factory=item_factory)
    assert len(items) == 1
    assert items[0].name == 't_mod'
    assert len(items[0].ir.typedefs) == 2
    assert items[0].definitions == (
        item.source['t_proc'],
        item.source['my_way'],
        item.source['t1'],
        item.source['t'],
    )

    # Filter items when calling create_definition_items()
    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    # NOTE(review): this result is overwritten below without being read. The call itself
    # may matter for populating item_factory, so confirm before dropping it.
    items = item.create_definition_items(item_factory=item_factory)
    # The file's only direct definition is a module, so filtering for procedures yields nothing
    assert not item.create_definition_items(only=ProcedureItem, item_factory=item_factory)
    items = item.create_definition_items(only=ModuleItem, item_factory=item_factory)
    assert len(items) == 1
    assert isinstance(items[0], ModuleItem)
    assert items[0].ir == item.source['t_mod']


def test_module_item1(testdir):
    """
    Check a ModuleItem for a module with a single subroutine and no module-level dependencies.

    Covers naming, string representation, the IR object, and the single
    ProcedureItem created from the module's definitions.
    """
    proj = testdir/'sources/projBatch'

    # A file with simple module that contains a single subroutine and has no dependencies on
    # the module level
    item = get_item(ModuleItem, proj/'module/a_mod.F90', 'a_mod', RegexParserClass.ProgramUnitClass)
    assert item.name == 'a_mod'
    assert item == 'a_mod'
    # The string representation is '<qualified class name><<item name>>',
    # matching the format checked for ProcedureItem in test_procedure_item1
    assert str(item) == 'loki.batch.ModuleItem<a_mod>'
    assert item.ir is item.source['a_mod']
    assert item.definitions == (item.source['a'],)

    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    items = item.create_definition_items(item_factory=item_factory)
    assert len(items) == 1
    assert isinstance(items[0], ProcedureItem)
    assert items[0].ir == item.source['a']

    assert not item.dependencies


def test_module_item2(testdir):
    """
    Check a module holding one subroutine, with an import dependency declared at module level.
    """
    proj = testdir/'sources/projBatch'

    mod_item = get_item(ModuleItem, proj/'module/b_mod.F90', 'b_mod', RegexParserClass.ProgramUnitClass)
    assert mod_item.name == 'b_mod'
    assert mod_item.ir is mod_item.source['b_mod']
    assert mod_item.definitions == (mod_item.source['b'],)

    factory = ItemFactory()
    factory.item_cache[mod_item.name] = mod_item
    def_items = mod_item.create_definition_items(item_factory=factory)
    assert len(def_items) == 1
    proc_item = def_items[0]
    assert isinstance(proc_item, ProcedureItem)
    assert proc_item.ir == mod_item.source['b']

    # The only dependency is the module-level import of header_mod
    deps = mod_item.dependencies
    assert len(deps) == 1
    header_import = deps[0]
    assert isinstance(header_import, ir.Import)
    assert header_import.module == 'header_mod'


def test_module_item3(testdir):
    """
    Check that module dependencies are found even when definitions have not been parsed first.
    """
    b_mod_file = testdir/'sources/projBatch'/'module/b_mod.F90'
    mod_item = get_item(ModuleItem, b_mod_file, 'b_mod', RegexParserClass.ProgramUnitClass)

    deps = mod_item.dependencies
    assert len(deps) == 1
    assert deps[0].module == 'header_mod'


def test_module_item4(testdir):
    """
    Check that interface blocks in a module count as definitions and spawn an InterfaceItem.
    """
    module_file = testdir/'sources/projInlineCalls'/'some_module.F90'
    mod_item = get_item(ModuleItem, module_file, 'some_module', RegexParserClass.ProgramUnitClass)

    defs = mod_item.definitions
    assert len(defs) == 6
    assert len(mod_item.ir.interfaces) == 1
    assert mod_item.ir.interfaces[0] in defs

    factory = ItemFactory()
    factory.item_cache[mod_item.name] = mod_item

    created = mod_item.create_definition_items(item_factory=factory)
    # The returned tuple contains duplicates: 8 entries but only 6 distinct items
    assert len(created) == 8
    assert len(set(created)) == 6

    intf_name = 'some_module#add_args'
    assert intf_name in factory.item_cache
    assert isinstance(factory.item_cache[intf_name], InterfaceItem)
    assert factory.item_cache[intf_name] in created


def test_procedure_item1(testdir):
    """
    Check a standalone subroutine ``comp1`` with two imports and three call statements.

    Verifies its dependencies and targets, and that an item already in the item
    cache (a pre-built procedure binding item) is re-used when dependency items
    are created.
    """
    proj = testdir/'sources/projBatch'

    # A file with a single subroutine definition that calls a routine via interface block
    item = get_item(ProcedureItem, proj/'source/comp1.F90', '#comp1', RegexParserClass.ProgramUnitClass)
    assert item.name == '#comp1'
    assert item == '#comp1'
    assert str(item) == 'loki.batch.ProcedureItem<#comp1>'
    assert item.ir is item.source['comp1']
    assert isinstance(item.definitions, tuple) and not item.definitions

    assert not item.create_definition_items(item_factory=ItemFactory())

    # For this routine, the two import statements come first, followed by the three calls
    dependencies = item.dependencies
    assert len(dependencies) == 5
    assert isinstance(dependencies[0], ir.Import)
    assert dependencies[0].module == 't_mod'
    assert isinstance(dependencies[1], ir.Import)
    assert dependencies[1].module == 'header_mod'
    assert isinstance(dependencies[2], ir.CallStatement)
    assert dependencies[2].name == 'arg%proc'
    assert isinstance(dependencies[3], ir.CallStatement)
    assert dependencies[3].name == 'comp2'
    assert isinstance(dependencies[4], ir.CallStatement)
    assert dependencies[4].name == 'arg%no%way'

    assert item.targets == ('t_mod', 't', 'nt1', 'header_mod', 'arg%proc', 'comp2', 'arg%no%way')

    # We need to have suitable dependency modules in the cache to spawn the dependency items
    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    # NOTE(review): '#comp2' is registered as a ModuleItem here, while test_procedure_item2
    # creates it as a ProcedureItem. Presumably only the cache lookup matters here; confirm.
    item_factory.item_cache.update({
        (i := get_item(ModuleItem, proj/path, name, RegexParserClass.ProgramUnitClass)).name: i
        for path, name in [
            ('module/t_mod.F90', 't_mod'), ('source/comp2.f90', '#comp2'), ('headers/header_mod.F90', 'header_mod')
        ]
    })

    # To ensure any existing items from the item_cache are re-used, we instantiate one for
    # the procedure binding
    # pylint: disable=unsupported-binary-operation
    t_mod_t_proc = get_item(
        ProcedureBindingItem, proj/'module/t_mod.F90', 't_mod#t%proc',
        RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.DeclarationClass
    )
    item_factory.item_cache[t_mod_t_proc.name] = t_mod_t_proc

    items = item.create_dependency_items(item_factory=item_factory)
    assert items == ('t_mod', 't_mod#t', 'header_mod', 't_mod#t%proc', '#comp2', 't_mod#t%no%way')
    # The pre-built binding item stays in the cache and is returned as the very same object
    assert item_factory.item_cache[t_mod_t_proc.name] is t_mod_t_proc
    assert items[3] is t_mod_t_proc


def test_procedure_item2(testdir):
    """
    Check a standalone subroutine ``comp2`` that calls two routines made available via module imports.

    Also checks that creating the dependency items is repeatable.
    """
    proj = testdir/'sources/projBatch'

    item = get_item(ProcedureItem, proj/'source/comp2.f90', '#comp2', RegexParserClass.ProgramUnitClass)
    assert item.name == '#comp2'
    assert item.ir is item.source['comp2']
    assert isinstance(item.definitions, tuple)
    assert not item.definitions

    factory = ItemFactory()
    assert not item.create_definition_items(item_factory=factory)

    # (IR node type, attribute to inspect, expected value) for each dependency, in order
    expected_deps = (
        (ir.Import, 'module', 't_mod'),
        (ir.Import, 'module', 'header_mod'),
        (ir.Import, 'module', 'a_mod'),
        (ir.Import, 'module', 'b_mod'),
        (ir.CallStatement, 'name', 'a'),
        (ir.CallStatement, 'name', 'b'),
        (ir.CallStatement, 'name', 'arg%yay%proc'),
    )
    deps = item.dependencies
    assert len(deps) == len(expected_deps)
    for dep, (node_type, attr, value) in zip(deps, expected_deps):
        assert isinstance(dep, node_type)
        assert getattr(dep, attr) == value

    assert item.targets == (
        't_mod', 't', 'header_mod', 'k',
        'a_mod', 'a', 'b_mod', 'b', 'arg%yay%proc'
    )

    # Register the procedure and every module it depends on, so dependency items can be spawned
    factory.item_cache[item.name] = item
    for rel_path, mod_name in (
        ('module/t_mod.F90', 't_mod'), ('module/a_mod.F90', 'a_mod'),
        ('module/b_mod.F90', 'b_mod'), ('headers/header_mod.F90', 'header_mod')
    ):
        mod_item = get_item(ModuleItem, proj/rel_path, mod_name, RegexParserClass.ProgramUnitClass)
        factory.item_cache[mod_item.name] = mod_item

    dep_items = item.create_dependency_items(item_factory=factory)
    assert dep_items == ('t_mod#t', 'header_mod', 'a_mod#a', 'b_mod#b', 't_mod#t%yay%proc')

    # A second call must give the same result
    assert dep_items == item.create_dependency_items(item_factory=factory)


def test_procedure_item3(testdir):
    """
    Check a module subroutine calling a typebound procedure whose type is imported in the module scope.
    """
    proj = testdir/'sources/projBatch'

    item = get_item(
        ProcedureItem, proj/'module/other_mod.F90', 'other_mod#mod_proc',
        RegexParserClass.ProgramUnitClass
    )
    deps = item.dependencies
    assert len(deps) == 3
    tt_import, proc_call, b_call = deps
    assert tt_import.module == 'tt_mod'
    assert proc_call.name == 'arg%proc'
    assert b_call.name == 'b'

    assert item.targets == ('tt_mod', 'tt', 'arg%proc', 'b')

    factory = ItemFactory()
    factory.item_cache[item.name] = item
    factory.item_cache['tt_mod'] = get_item(
        ModuleItem, proj/'module/tt_mod.F90', 'tt_mod', RegexParserClass.ProgramUnitClass
    )
    factory.item_cache['b_mod'] = get_item(
        ModuleItem, proj/'module/b_mod.F90', 'b_mod', RegexParserClass.ProgramUnitClass
    )

    expected = ('tt_mod#tt', 'tt_mod#tt%proc', 'b_mod#b')
    assert item.create_dependency_items(item_factory=factory) == expected


def test_procedure_item4(testdir):
    """
    Check a routine making a typebound procedure call whose typedef lives in the same module.
    """
    t_mod_file = testdir/'sources/projBatch'/'module/t_mod.F90'
    item = get_item(ProcedureItem, t_mod_file, 't_mod#my_way', RegexParserClass.ProgramUnitClass)

    deps = item.dependencies
    assert len(deps) == 2
    assert deps[0].name == 't1'
    assert deps[1].name == 'this%way'

    assert item.targets == ('t1', 'this%way')

    factory = ItemFactory()
    factory.item_cache[item.name] = item
    # The enclosing module item is built from the same source file as the procedure
    factory.item_cache['t_mod'] = ModuleItem('t_mod', source=item.source)

    dep_items = item.create_dependency_items(item_factory=factory)
    assert dep_items == ('t_mod#t1', 't_mod#t1%way')


@pytest.mark.parametrize('config,expected_dependencies,expected_targets', [
    (
        {},
        ('t_mod#t', 'header_mod', 'a_mod#a', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'header_mod', 'k', 'a_mod', 'a', 'b_mod', 'b', 'arg%yay%proc')
    ),
    (
        {'default': {'disable': ['#a']}},
        ('t_mod#t', 'header_mod', 'a_mod#a', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'header_mod', 'k', 'a_mod', 'a', 'b_mod', 'b', 'arg%yay%proc')
    ),
    (
        {'default': {'disable': ['a']}},
        ('t_mod#t', 'header_mod', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'header_mod', 'k', 'a_mod', 'b_mod', 'b', 'arg%yay%proc')
    ),
    (
        {'default': {'disable': ['a', 'a_mod']}},
        ('t_mod#t', 'header_mod', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'header_mod', 'k', 'b_mod', 'b', 'arg%yay%proc'),
    ),
    (
        {'default': {'disable': ['a_mod#a']}},
        ('t_mod#t', 'header_mod', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'header_mod', 'k', 'a_mod', 'b_mod', 'b', 'arg%yay%proc')
    ),
    (
        {'default': {'disable': ['a_mod']}},
        ('t_mod#t', 'header_mod', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'header_mod', 'k', 'b_mod', 'b', 'arg%yay%proc')
    ),
    (
        {'default': {'disable': ['t%yay%proc']}},
        ('t_mod#t', 'header_mod', 'a_mod#a', 'b_mod#b'),
        ('t_mod', 't', 'header_mod', 'k', 'a_mod', 'a', 'b_mod', 'b')
    ),
    (
        {'default': {'disable': ['t_mod#t%yay%proc']}},
        ('t_mod#t', 'header_mod', 'a_mod#a', 'b_mod#b'),
        ('t_mod', 't', 'header_mod', 'k', 'a_mod', 'a', 'b_mod', 'b')
    ),
    (
        {'default': {'disable': ['t_mod#t']}},
        ('header_mod', 'a_mod#a', 'b_mod#b'),
        ('t_mod', 'header_mod', 'k', 'a_mod', 'a', 'b_mod', 'b')
    ),
    (
        {'default': {'disable': ['t_mod']}},
        ('header_mod', 'a_mod#a', 'b_mod#b'),
        ('header_mod', 'k', 'a_mod', 'a', 'b_mod', 'b')
    ),
    (
        {'default': {'disable': ['header_mod']}},
        ('t_mod#t', 'a_mod#a', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'a_mod', 'a', 'b_mod', 'b', 'arg%yay%proc')
    ),
    (
        {'default': {'disable': ['k']}},
        ('t_mod#t', 'a_mod#a', 'b_mod#b', 't_mod#t%yay%proc'),
        ('t_mod', 't', 'header_mod', 'a_mod', 'a', 'b_mod', 'b', 'arg%yay%proc')
    ),
])
def test_procedure_item_with_config(testdir, config, expected_dependencies, expected_targets):
    """
    Check that the dependencies and targets of ``#comp2`` follow the ``disable`` list in the scheduler config.
    """
    proj = testdir/'sources/projBatch'
    scheduler_config = SchedulerConfig.from_dict(config)

    def load(item_cls, rel_path, item_name):
        # Every item in this test is created with the same scheduler config
        return get_item(
            item_cls, proj/rel_path, item_name,
            RegexParserClass.ProgramUnitClass, scheduler_config=scheduler_config
        )

    # A file with a single subroutine definition that calls two routines via module imports
    item = load(ProcedureItem, 'source/comp2.f90', '#comp2')

    # Register the procedure and every module it depends on, so dependency items can be spawned
    factory = ItemFactory()
    factory.item_cache[item.name] = item
    for rel_path, mod_name in (
        ('module/t_mod.F90', 't_mod'), ('module/a_mod.F90', 'a_mod'),
        ('module/b_mod.F90', 'b_mod'), ('headers/header_mod.F90', 'header_mod')
    ):
        mod_item = load(ModuleItem, rel_path, mod_name)
        factory.item_cache[mod_item.name] = mod_item

    dep_items = item.create_dependency_items(item_factory=factory, config=scheduler_config)
    assert dep_items == expected_dependencies

    expected_disable = config.get('default', {}).get('disable', [])
    assert as_tuple(item.disable) == as_tuple(expected_disable)
    assert item.targets == as_tuple(expected_targets)


@pytest.mark.parametrize('disable', ['#comp2', 'comp2'])
def test_procedure_item_with_config2(testdir, disable):
    """
    Check disabling a scope-less subroutine, written with or without the leading ``#``.

    The disabled ``comp2`` must be absent from the dependency items and targets of ``#comp1``.
    """
    proj = testdir/'sources/projBatch'
    scheduler_config = SchedulerConfig.from_dict({'default': {'disable': [disable]}})

    item = get_item(
        ProcedureItem, proj/'source/comp1.F90', '#comp1',
        RegexParserClass.ProgramUnitClass, scheduler_config=scheduler_config
    )

    factory = ItemFactory()
    factory.item_cache[item.name] = item
    # NB: the t_mod item is created without the scheduler config
    factory.item_cache['t_mod'] = get_item(
        ModuleItem, proj/'module/t_mod.F90', 't_mod', RegexParserClass.ProgramUnitClass
    )
    factory.item_cache['header_mod'] = get_item(
        ModuleItem, proj/'headers/header_mod.F90', 'header_mod',
        RegexParserClass.ProgramUnitClass, scheduler_config=scheduler_config
    )

    expected_items = ('t_mod', 't_mod#t', 'header_mod', 't_mod#t%proc', 't_mod#t%no%way')
    assert item.create_dependency_items(item_factory=factory, config=scheduler_config) == expected_items

    assert item.targets == ('t_mod', 't', 'nt1', 'header_mod', 'arg%proc', 'arg%no%way')


@pytest.mark.parametrize('enable_imports', [False, True])
def test_procedure_item_external_item(tmp_path, enable_imports, default_config):
    """
    Test that dependencies to external module procedures are marked as external item
    """
    fcode = """
subroutine procedure_item_external_item
    use external_mod, only: external_proc, unused_external_proc, external_type, external_var
    implicit none
    type(external_type) :: my_type

    call external_proc(1)

    my_type%my_val = external_var
end subroutine procedure_item_external_item
    """
    source_path = tmp_path/'procedure_item_external_item.F90'
    source_path.write_text(fcode)

    default_config['default']['enable_imports'] = enable_imports
    scheduler_config = SchedulerConfig.from_dict(default_config)
    item = get_item(
        ProcedureItem, source_path, '#procedure_item_external_item',
        RegexParserClass.ProgramUnitClass, scheduler_config
    )
    factory = ItemFactory()
    factory.item_cache[item.name] = item
    dep_items = item.create_dependency_items(item_factory=factory, config=scheduler_config)

    # NB: imported symbols are not added as external items, because their type cannot be
    #     determined. Instead, the external module itself becomes a dependency, whether or
    #     not imports are enabled. The procedure invoked by a call statement, however, is
    #     recognised as an external procedure and included in the dependency tree.
    assert dep_items == ('external_mod', 'external_mod#external_proc')
    assert all(isinstance(dep_item, ExternalItem) for dep_item in dep_items)
    assert [dep_item.origin_cls for dep_item in dep_items] == [ModuleItem, ProcedureItem]


@pytest.mark.parametrize('strict', [False, True])
def test_procedure_item_external_item_intfb(tmp_path, strict, default_config):
    """
    Test that dependencies to external procedures are marked as external item
    """
    fcode = """
subroutine procedure_item_external_item
    implicit none
#include "external_proc.intfb.h"

    call external_proc(1)
end subroutine procedure_item_external_item
    """
    source_path = tmp_path/'procedure_item_external_item.F90'
    source_path.write_text(fcode)

    default_config['default']['strict'] = strict
    scheduler_config = SchedulerConfig.from_dict(default_config)
    item = get_item(
        ProcedureItem, source_path, '#procedure_item_external_item',
        RegexParserClass.ProgramUnitClass, scheduler_config
    )
    factory = ItemFactory()
    factory.item_cache[item.name] = item

    if not strict:
        # Non-strict mode: the procedure known only via its interface header becomes an ExternalItem
        dep_items = item.create_dependency_items(item_factory=factory, config=scheduler_config)
        assert dep_items == ('#external_proc',)
        assert isinstance(dep_items[0], ExternalItem)
        assert dep_items[0].origin_cls == ProcedureItem
    else:
        # Strict mode: creating the dependency on the unavailable procedure raises
        with pytest.raises(RuntimeError):
            item.create_dependency_items(item_factory=factory, config=scheduler_config)


def test_procedure_item_from_item1(testdir, default_config):
    """
    Check duplicating a standalone procedure item under a new name with ``get_or_create_item_from_item``.
    """
    proj = testdir/'sources/projBatch'

    factory = ItemFactory()
    scheduler_config = SchedulerConfig.from_dict(default_config)
    file_item = factory.get_or_create_file_item_from_path(proj/'source/comp1.F90', config=scheduler_config)
    item = file_item.create_definition_items(item_factory=factory, config=scheduler_config)[0]
    assert item.name == '#comp1'
    assert isinstance(item, ProcedureItem)

    cache_keys = {str(proj/'source/comp1.F90').lower(), '#comp1'}
    assert set(factory.item_cache) == cache_keys

    # Duplicating the item also registers a cache entry keyed by the new file path
    new_item = factory.get_or_create_item_from_item('#new_comp1', item, config=scheduler_config)
    cache_keys.update({str(proj/'source/new_comp1.F90').lower(), '#new_comp1'})
    assert set(factory.item_cache) == cache_keys

    # The duplicate carries the new name; the original item keeps its name
    assert new_item.name == '#new_comp1'
    assert isinstance(new_item, ProcedureItem)
    assert new_item.ir.name == 'new_comp1'
    assert item.ir.name == 'comp1'

    # Both items report equal dependencies, but as distinct objects
    assert item.dependencies == new_item.dependencies
    assert all(orig is not dup for orig, dup in zip(item.dependencies, new_item.dependencies))


def test_procedure_item_from_item2(testdir, default_config):
    """
    Check duplicating a module procedure item under a new name in a different module.

    The duplicate ``my_mod#new_proc`` gets a new parent module ``my_mod``, and new
    cache entries are registered for the module and its file path. The original
    ``other_mod#mod_proc`` is left unchanged, and both items have equal but distinct
    dependency objects.
    """
    proj = testdir/'sources/projBatch'

    # A file with a single subroutine declared in a module that calls a typebound procedure
    # where the type is imported via an import statement in the module scope
    item_factory = ItemFactory()
    scheduler_config = SchedulerConfig.from_dict(default_config)
    file_item = item_factory.get_or_create_file_item_from_path(proj/'module/other_mod.F90', config=scheduler_config)
    mod_item = file_item.create_definition_items(item_factory=item_factory, config=scheduler_config)[0]
    assert mod_item.name == 'other_mod'
    assert isinstance(mod_item, ModuleItem)
    item = mod_item.create_definition_items(item_factory=item_factory, config=scheduler_config)[0]
    assert item.name == 'other_mod#mod_proc'
    assert isinstance(item, ProcedureItem)

    expected_cache = {str(proj/'module/other_mod.F90').lower(), 'other_mod', 'other_mod#mod_proc'}
    assert set(item_factory.item_cache) == expected_cache

    # Create a new item by duplicating the existing item
    # NOTE(review): test_procedure_item_from_item1 uses the return value of
    # get_or_create_item_from_item directly, whereas here it is indexed with [0].
    # Confirm the return type for module procedures.
    new_item = item_factory.get_or_create_item_from_item('my_mod#new_proc', item, config=scheduler_config)[0]
    expected_cache |= {str(proj/'module/my_mod.F90').lower(), 'my_mod', 'my_mod#new_proc'}
    assert set(item_factory.item_cache) == expected_cache

    # Assert the new item differs from the existing item in the name, with the original
    # item unchanged
    assert new_item.name == 'my_mod#new_proc'
    assert isinstance(new_item, ProcedureItem)
    assert new_item.ir.name == 'new_proc'
    assert new_item.ir.parent.name == 'my_mod'
    assert item.ir.name == 'mod_proc'
    assert item.ir.parent.name == 'other_mod'

    # Make sure both items have the same dependencies but the dependency
    # objects are distinct objects
    assert item.dependencies == new_item.dependencies
    assert all(d is not new_d for d, new_d in zip(item.dependencies, new_item.dependencies))


def test_typedef_item(testdir):
    """
    Check a :any:`TypeDefItem` for derived type ``t`` in ``t_mod``.

    The test covers:
    * the IR handles of the item;
    * that its procedure bindings become definition items;
    * how its dependencies (an import from ``tt_mod`` and the local type ``t1``)
      turn into items, both without and with the relevant module items in the cache.
    """
    proj = testdir/'sources/projBatch'

    # A file with multiple type definitions, of which we pick one
    item = get_item(
        TypeDefItem, proj/'module/t_mod.F90', 't_mod#t',
        RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass
    )
    assert item.name == 't_mod#t'
    assert str(item) == 'loki.batch.TypeDefItem'
    assert item.ir is item.source['t']
    assert item.scope_ir is item.source['t']
    assert item.transformation_ir is item.source['t_mod']
    assert 'proc' in item.ir.variable_map
    assert item.definitions == item.ir.declarations

    # Without module items in the cache, the definition items will be externals
    assert all(
        isinstance(_item, ExternalItem) and _item.origin_cls is ProcedureBindingItem
        for _item in item.create_definition_items(item_factory=ItemFactory())
    )

    # The import providing `tt` is trimmed to its first symbol, followed by the local type `t1`
    tt_import = item.scope.import_map['tt'].clone(symbols=item.scope.import_map['tt'].symbols[:1])
    assert item.dependencies == (tt_import, item.ir.parent['t1'])

    # Without module items in the cache, the dependency items will be externals.
    # Use the factory that has only the item itself registered.
    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    items = item.create_dependency_items(item_factory=item_factory)
    assert items == ('tt_mod', 't_mod#t1')
    assert all(
        isinstance(_item, ExternalItem) and _item.origin_cls in (ModuleItem, TypeDefItem)
        for _item in items
    )

    # Need to add the modules of the dependent types
    item_factory = ItemFactory()
    item_factory.item_cache[item.name] = item
    item_factory.item_cache['t_mod'] = ModuleItem('t_mod', source=item.source)
    item_factory.item_cache['tt_mod'] = get_item(
        ModuleItem, proj/'module/tt_mod.F90', 'tt_mod', RegexParserClass.ProgramUnitClass
    )
    assert 'tt_mod#tt' not in item_factory.item_cache
    assert 't_mod#t1' not in item_factory.item_cache
    items = item.create_dependency_items(item_factory=item_factory)
    assert 'tt_mod#tt' in item_factory.item_cache
    assert 't_mod#t1' in item_factory.item_cache
    assert items == (item_factory.item_cache['tt_mod#tt'], item_factory.item_cache['t_mod#t1'])
    # Both dependencies are derived types
    assert all(isinstance(i, TypeDefItem) for i in items)
    assert not items[1].dependencies


def test_interface_item_in_module(testdir):
    """
    An :any:`InterfaceItem` for a generic interface declared in a module defines
    nothing by itself and depends on the module procedures it lists.
    """
    proj = testdir/'sources/projInlineCalls'
    parser_classes = RegexParserClass.ProgramUnitClass | RegexParserClass.InterfaceClass

    # some_module declares an interface that groups several functions under a common name
    item = get_item(InterfaceItem, proj/'some_module.F90', 'some_module#add_args', parser_classes)

    assert item.name == 'some_module#add_args'
    assert str(item) == 'loki.batch.InterfaceItem'
    assert item.ir is item.source['some_module'].interface_map['add_args']
    assert {'add_args', 'add_two_args', 'add_three_args'} == set(item.ir.symbols)

    # Nothing is defined by the interface itself...
    assert not item.definitions
    assert not item.create_definition_items(item_factory=ItemFactory())

    # ...but it depends on every routine it declares
    assert item.dependencies == ('add_two_args', 'add_three_args')

    def _factory(extra=None):
        # Fresh factory that knows the interface item plus any `extra` cache entries
        factory = ItemFactory()
        factory.item_cache[item.name] = item
        factory.item_cache.update(extra or {})
        return factory

    # No module item cached -> the dependencies can only be externals
    strict_config = SchedulerConfig.from_dict({'default': {'strict': True}})
    item_factory = _factory()
    externals = item.create_dependency_items(item_factory=item_factory, config=strict_config)
    assert all(isinstance(ext, ExternalItem) and ext.origin_cls is ProcedureItem for ext in externals)

    # With the module item available, real procedure items are created on demand
    item_factory = _factory({'some_module': ModuleItem('some_module', source=item.source)})
    proc_names = ('some_module#add_two_args', 'some_module#add_three_args')
    assert not any(proc_name in item_factory.item_cache for proc_name in proc_names)
    items = item.create_dependency_items(item_factory=item_factory)
    assert all(proc_name in item_factory.item_cache for proc_name in proc_names)
    assert items == tuple(item_factory.item_cache[proc_name] for proc_name in proc_names)
    assert all(isinstance(proc, ProcedureItem) for proc in items)


def test_interface_item_in_subroutine(testdir):
    """
    A driver subroutine that declares an inline call through an interface block
    lists that interface among its dependencies. Creating its dependency items
    raises ``RuntimeError`` until both the imported modules and the interfaced
    routine are present in the cache.
    """
    proj = testdir/'sources/projInlineCalls'
    unit_cls = RegexParserClass.ProgramUnitClass

    item = get_item(ProcedureItem, proj/'driver.F90', '#driver', unit_cls)

    # One dependency per import and per interface, plus one for the call
    assert len(item.dependencies) == len(item.ir.imports + item.ir.interfaces) + 1

    item_factory = ItemFactory()
    cache = item_factory.item_cache
    cache[item.name] = item

    # Modules not known yet
    with pytest.raises(RuntimeError):
        item.create_dependency_items(item_factory=item_factory)

    for mod_name in ('some_module', 'vars_module'):
        mod_item = get_item(ModuleItem, proj/f'{mod_name}.F90', mod_name, unit_cls)
        cache[mod_item.name] = mod_item

    # Modules are there, but the routine declared through the interface is still missing
    with pytest.raises(RuntimeError):
        item.create_dependency_items(item_factory=item_factory)

    routine_item = get_item(ProcedureItem, proj/'double_real.F90', '#double_real', unit_cls)
    cache[routine_item.name] = routine_item

    # Everything resolvable now
    dependency_items = item.create_dependency_items(item_factory=item_factory)
    assert set(dependency_items) == {
        'some_module#return_one', 'some_module', 'some_module#add_args', 'some_module#some_type',
        'vars_module', '#double_real', 'some_module#some_type%do_something'
    }


def test_procedure_binding_item1(testdir):
    """
    1. A direct procedure binding: ``t%proc`` in ``t_mod`` depends on the module
    procedure ``t_proc``, which becomes a :any:`ProcedureItem` once ``t_mod`` is cached.
    """
    proj = testdir/'sources/projBatch'
    # pylint: disable=unsupported-binary-operation
    parser_classes = (
        RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.DeclarationClass
    )

    # The typedef `t` has a procedure binding and nested types that again have bindings
    item = get_item(ProcedureBindingItem, proj/'module/t_mod.F90', 't_mod#t%proc', parser_classes)
    assert item.name == 't_mod#t%proc'
    assert str(item) == 'loki.batch.ProcedureBindingItem'
    assert item.ir is item.source['t'].variable_map['proc']

    # A binding defines nothing
    assert isinstance(item.definitions, tuple)
    assert not item.definitions
    assert not item.create_definition_items(item_factory=ItemFactory())
    assert item.dependencies == as_tuple(item.source['t_proc'])

    item_factory = ItemFactory()
    item_factory.item_cache['t_mod'] = ModuleItem('t_mod', source=item.source)
    dependency_items = item.create_dependency_items(item_factory=item_factory)
    assert len(dependency_items) == 1
    bound_proc = dependency_items[0]
    assert isinstance(bound_proc, ProcedureItem)
    assert bound_proc.ir is item.source['t_proc']


def test_procedure_binding_item2(testdir, default_config):
    """
    2. An indirect procedure binding through a nested type member whose type is
    declared in the same module. ``t%no%way`` resolves to the binding
    ``t1%way``, which in turn resolves to the module procedure ``my_way``.
    """
    proj = testdir/'sources/projBatch'
    # pylint: disable=unsupported-binary-operation
    parser_classes = (
        RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.DeclarationClass
    )

    item = get_item(ProcedureBindingItem, proj/'module/t_mod.F90', 't_mod#t%no%way', parser_classes)
    assert item.name == 't_mod#t%no%way'
    assert isinstance(item.ir, Scalar)
    assert isinstance(item.definitions, tuple)
    assert not item.definitions
    assert not item.create_definition_items(item_factory=ItemFactory())
    assert item.dependencies == ('no%way',)

    def _factory_with_item():
        # Fresh factory that knows only the binding item under test
        factory = ItemFactory()
        factory.item_cache[item.name] = item
        return factory

    # t_mod is not cached, so the dependency can only become an ExternalItem
    scheduler_config = SchedulerConfig.from_dict(default_config)
    externals = item.create_dependency_items(item_factory=_factory_with_item(), config=scheduler_config)
    assert all(
        isinstance(ext, ExternalItem) and ext.origin_cls is ProcedureBindingItem
        for ext in externals
    )

    # With t_mod cached, the member resolves to the binding t1%way...
    item_factory = _factory_with_item()
    item_factory.item_cache['t_mod'] = ModuleItem('t_mod', source=item.source)
    binding_items = item.create_dependency_items(item_factory=item_factory)
    assert len(binding_items) == 1
    binding = binding_items[0]
    assert isinstance(binding, ProcedureBindingItem)
    assert binding.name == 't_mod#t1%way'
    assert 't_mod#t1%way' in item_factory.item_cache

    # ...which in turn resolves to the module procedure my_way
    assert 't_mod#my_way' not in item_factory.item_cache
    proc_items = binding.create_dependency_items(item_factory=item_factory)
    assert len(proc_items) == 1
    assert isinstance(proc_items[0], ProcedureItem)
    assert proc_items[0].ir is item.source['my_way']
    assert 't_mod#my_way' in item_factory.item_cache


def test_procedure_binding_item3(testdir):
    """
    3. An indirect procedure binding through a nested type member whose type is
    declared in another module. ``t%yay%proc`` resolves to the binding
    ``tt_mod#tt%proc``, and from there to the procedure ``tt_mod#proc``.
    """
    proj = testdir/'sources/projBatch'
    # pylint: disable=unsupported-binary-operation
    parser_classes = (
        RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.DeclarationClass
    )

    item = get_item(ProcedureBindingItem, proj/'module/t_mod.F90', 't_mod#t%yay%proc', parser_classes)
    assert item.name == 't_mod#t%yay%proc'
    assert isinstance(item.ir, Scalar)
    assert isinstance(item.definitions, tuple)
    assert not item.definitions
    assert not item.create_definition_items(item_factory=ItemFactory())
    assert item.dependencies == ('yay%proc',)

    item_factory = ItemFactory()
    cache = item_factory.item_cache
    cache[item.name] = item
    cache['tt_mod'] = get_item(ModuleItem, proj/'module/tt_mod.F90', 'tt_mod', parser_classes)

    # First hop: the binding on the imported type tt
    binding_items = item.create_dependency_items(item_factory=item_factory)
    assert len(binding_items) == 1
    binding = binding_items[0]
    assert isinstance(binding, ProcedureBindingItem)
    assert binding.name == 'tt_mod#tt%proc'
    assert 'tt_mod#tt%proc' in cache

    # Second hop: the module procedure bound to it
    assert 'tt_mod#proc' not in cache
    proc_items = binding.create_dependency_items(item_factory=item_factory)
    assert len(proc_items) == 1
    assert isinstance(proc_items[0], ProcedureItem)
    assert proc_items[0].ir is binding.source['proc']
    assert 'tt_mod#proc' in cache


@pytest.mark.parametrize('config,expected_dependencies', [
    ({}, (('tt_mod#tt%proc',), ('tt_mod#proc',))),
    ({'default': {'disable': ['tt_mod#proc']}}, (('tt_mod#tt%proc',), ())),
    ({'default': {'disable': ['proc']}}, (('tt_mod#tt%proc',), ())),
    ({'default': {'disable': ['tt%proc']}}, ((),)),
    ({'default': {'disable': ['tt_mod#tt%proc']}}, ((),)),
])
def test_procedure_binding_with_config(testdir, config, expected_dependencies):
    """
    Walk the dependency chain of the binding ``t_mod#t%yay%proc`` one hop at a time.

    At each hop the created dependency items are compared against
    ``expected_dependencies``, using a scheduler config built from ``config``.
    ``disable`` entries cut the chain short.
    """
    proj = testdir/'sources/projBatch'
    # pylint: disable=unsupported-binary-operation
    parser_classes = (
        RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.DeclarationClass
    )

    item = get_item(ProcedureBindingItem, proj/'module/t_mod.F90', 't_mod#t%yay%proc', parser_classes)

    # tt_mod has to be cached so that dependency items can be spawned
    item_factory = ItemFactory()
    cache = item_factory.item_cache
    cache[item.name] = item
    cache['tt_mod'] = get_item(ModuleItem, proj/'module/tt_mod.F90', 'tt_mod', RegexParserClass.ProgramUnitClass)
    scheduler_config = SchedulerConfig.from_dict(config)

    current = item
    for expected in expected_dependencies:
        created = current.create_dependency_items(item_factory, config=scheduler_config)
        assert created == expected
        # Descend into the first dependency; stay on the current item when there is none
        if created:
            current = created[0]


def test_item_graph(testdir, comp1_expected_dependencies):
    """
    Build a :any:`nx.Digraph` from a dummy call hierarchy to check the incremental parsing and
    discovery behaves as expected.

    Items are discovered breadth-first from ``#comp1`` via ``create_dependency_items``.
    The resulting nodes and edges are compared against ``comp1_expected_dependencies``,
    extended with the cyclic edge that the fixture leaves out. The fixture object
    itself is not modified.
    """
    proj = testdir/'sources/projBatch'
    path_list = [path for suffix in ('.f90', '.F90') for path in proj.glob(f'**/*{suffix}')]
    assert len(path_list) == 8

    # Map item names to items
    item_factory = ItemFactory()
    cache = item_factory.item_cache

    # Instantiate the basic list of items (files, modules, subroutines)
    for src_path in path_list:
        rel_name = str(src_path.relative_to(proj))
        file_item = get_item(FileItem, src_path, rel_name, RegexParserClass.ProgramUnitClass)
        cache[rel_name] = file_item
        cache.update(
            (definition.name, definition)
            for definition in file_item.create_definition_items(item_factory=item_factory)
        )

    # Breadth-first discovery starting from the seed routine
    seed_item = cache['#comp1']
    full_graph = nx.DiGraph()
    full_graph.add_node(seed_item)
    queue = deque([seed_item])
    while queue:
        current = queue.popleft()
        dependencies = current.create_dependency_items(item_factory=item_factory)
        new_items = [dep for dep in dependencies if dep not in full_graph]
        if new_items:
            full_graph.add_nodes_from(new_items)
            queue.extend(new_items)
        full_graph.add_edges_from((current, dep) for dep in dependencies)

    # Add the cyclic dependency (which isn't included in the fixture) to a copy,
    # so the fixture-provided dict is left untouched
    expected = dict(comp1_expected_dependencies)
    expected['t_mod#my_way'] = (*expected['t_mod#my_way'], 't_mod#t1%way')

    assert set(full_graph) == set(expected)
    assert {(a.name, b.name) for a, b in full_graph.edges} == {
        (a, b) for a, deps in expected.items() for b in deps
    }

    # Debugging aid: the graph can be drawn with matplotlib, e.g.
    # `nx.draw_planar(full_graph, with_labels=True)` followed by `plt.show()`
    # or `plt.savefig('test_item_graph.png')`


@pytest.mark.parametrize('seed,dependencies_fixture', [
    ('#comp1', 'comp1_expected_dependencies'),
    ('other_mod#mod_proc', 'mod_proc_expected_dependencies'),
    (['#comp1', 'other_mod#mod_proc'], 'expected_dependencies'),
    ('#foobar', 'no_expected_dependencies'),
    # Not fully-qualified procedure name for a free subroutine
    ('comp1', 'comp1_expected_dependencies'),
    # Not fully-qualified procedure name for a module procedure
    ('mod_proc', 'mod_proc_expected_dependencies'),
])
def test_sgraph_from_seed(tmp_path, testdir, default_config, seed, dependencies_fixture, request):
    """
    Build an :any:`SGraph` from one or several seeds, given as qualified or bare names.

    Nodes and edges are compared against the dependency fixture named by
    ``dependencies_fixture``, both in memory and in the exported graphviz file.
    """
    expected_dependencies = request.getfixturevalue(dependencies_fixture)
    proj = testdir/'sources/projBatch'
    path_list = [path for suffix in ('.f90', '.F90') for path in proj.glob(f'**/*{suffix}')]
    assert len(path_list) == 8

    scheduler_config = SchedulerConfig.from_dict(default_config)
    item_factory = ItemFactory()
    cache = item_factory.item_cache

    # Register file items and what they define (modules, subroutines)
    for src_path in path_list:
        rel_name = str(src_path.relative_to(proj))
        file_item = get_item(FileItem, src_path, rel_name, RegexParserClass.ProgramUnitClass, scheduler_config)
        cache[rel_name] = file_item
        definitions = file_item.create_definition_items(item_factory=item_factory, config=scheduler_config)
        cache.update({definition.name: definition for definition in definitions})

    sgraph = SGraph.from_seed(seed, item_factory, scheduler_config)

    expected_edges = {
        (node, dependency)
        for node, dependencies in expected_dependencies.items()
        for dependency in dependencies
    }
    assert set(sgraph.items) == set(expected_dependencies)
    assert set(sgraph.dependencies) == expected_edges

    # Export the graph; a rendered PDF is expected alongside the .dot file
    graph_file = tmp_path/'sgraph_from_seed.dot'
    sgraph.export_to_file(graph_file)
    assert graph_file.exists()
    assert graph_file.with_suffix('.dot.pdf').exists()

    # The visualisation labels nodes with upper-case item names
    vgraph = VisGraphWrapper(graph_file)
    assert set(vgraph.nodes) == {name.upper() for name in expected_dependencies}
    assert set(vgraph.edges) == {(node.upper(), dep.upper()) for node, dep in expected_edges}


@pytest.mark.parametrize('seed,disable,active_nodes', [
    ('#comp1', ('comp2', 'a'), (
        '#comp1', 't_mod', 't_mod#t', 'header_mod', 't_mod#t%proc', 't_mod#t%no%way',
        't_mod#t_proc', 't_mod#t%yay%proc', 'tt_mod#tt%proc', 'tt_mod#proc',
        't_mod#t1%way', 't_mod#my_way', 'tt_mod#tt', 't_mod#t1', 'tt_mod#intf'
    )),
    ('#comp1', ('comp2', 'a', 't_mod#t%no%way'), (
        '#comp1', 't_mod', 't_mod#t', 'header_mod', 't_mod#t%proc',
        't_mod#t_proc', 't_mod#t%yay%proc', 'tt_mod#tt%proc', 'tt_mod#proc',
        'tt_mod#tt', 't_mod#t1', 'tt_mod#intf'
    )),
    ('#comp1', ('#comp2', 't1%way'), (
        '#comp1', 't_mod', 't_mod#t', 'header_mod', 't_mod#t%proc', 't_mod#t%no%way',
        't_mod#t_proc', 't_mod#t%yay%proc', 'tt_mod#tt%proc', 'tt_mod#proc',
        'tt_mod#tt', 't_mod#t1', 'a_mod#a', 'tt_mod#intf'
    )),
    ('t_mod#t_proc', ('t_mod#t1', 'proc'), (
        't_mod#t_proc', 't_mod#t', 'tt_mod#tt', 'a_mod#a', 'header_mod',
        't_mod#t%yay%proc', 'tt_mod#tt%proc'
    ))
])
def test_sgraph_disable(testdir, default_config, expected_dependencies, seed, disable, active_nodes):
    """
    With names listed under ``default.disable`` in the scheduler config, the
    :any:`SGraph` grown from ``seed`` contains exactly ``active_nodes``. Its edges
    are the expected edges restricted to those nodes.
    """
    proj = testdir/'sources/projBatch'
    path_list = [path for suffix in ('.f90', '.F90') for path in proj.glob(f'**/*{suffix}')]
    assert len(path_list) == 8

    default_config['default']['disable'] = disable
    scheduler_config = SchedulerConfig.from_dict(default_config)
    item_factory = ItemFactory()
    cache = item_factory.item_cache

    # Register file items and what they define (modules, subroutines)
    for src_path in path_list:
        rel_name = str(src_path.relative_to(proj))
        file_item = get_item(FileItem, src_path, rel_name, RegexParserClass.ProgramUnitClass, scheduler_config)
        cache[rel_name] = file_item
        definitions = file_item.create_definition_items(item_factory=item_factory, config=scheduler_config)
        cache.update({definition.name: definition for definition in definitions})

    sgraph = SGraph.from_seed(seed, item_factory, scheduler_config)

    active = set(active_nodes)
    assert set(sgraph.items) == active
    assert set(sgraph.dependencies) == {
        (node, dependency)
        for node, dependencies in expected_dependencies.items() if node in active
        for dependency in dependencies if dependency in active
    }


@pytest.mark.parametrize('seed,routines,active_nodes', [
    (
        '#comp1', {
            'comp1': {'expand': False}
        }, (
            '#comp1',
        )
    ),
    (
        '#comp2', {
            'comp2': {'block': ['a', 'b']},
            't_mod': {'block': ['a']}
        }, (
            '#comp2', 't_mod#t', 'header_mod', 't_mod#t%yay%proc',
            'tt_mod#tt', 't_mod#t1', 'tt_mod#tt%proc', 'tt_mod#proc'
        )
    ),
    (
        '#comp2', {
            'comp2': {'ignore': ['a'], 'block': ['b']},
            't_mod': {'ignore': ['a']}
        }, (
            '#comp2', 't_mod#t', 'header_mod', 't_mod#t%yay%proc',
            'tt_mod#tt', 't_mod#t1', 'tt_mod#tt%proc', 'tt_mod#proc',
            'a_mod#a'
        )
    ),
])
def test_sgraph_routines(testdir, default_config, expected_dependencies, seed, routines, active_nodes):
    """
    Per-routine scheduler options (``expand``, ``block``, ``ignore``) shape the
    :any:`SGraph` grown from ``seed``. A ``block`` list also removes names from the
    seed item's ``targets``.

    Targets are checked twice: on the regex-parsed source, and again after a full
    parse with enrichment against ``t_mod`` and ``header_mod``.
    """
    proj = testdir/'sources/projBatch'
    path_list = [path for suffix in ('.f90', '.F90') for path in proj.glob(f'**/*{suffix}')]
    assert len(path_list) == 8

    default_config['routines'] = routines
    scheduler_config = SchedulerConfig.from_dict(default_config)
    item_factory = ItemFactory()
    cache = item_factory.item_cache

    # Register file items and what they define (modules, subroutines)
    for src_path in path_list:
        rel_name = str(src_path.relative_to(proj))
        file_item = get_item(FileItem, src_path, rel_name, RegexParserClass.ProgramUnitClass, scheduler_config)
        cache[rel_name] = file_item
        definitions = file_item.create_definition_items(item_factory=item_factory, config=scheduler_config)
        cache.update({definition.name: definition for definition in definitions})

    sgraph = SGraph.from_seed(seed, item_factory, scheduler_config)

    assert set(sgraph.items) == set(active_nodes)
    assert set(sgraph.dependencies) == {
        (node, dependency)
        for node, dependencies in expected_dependencies.items() if node in active_nodes
        for dependency in dependencies if dependency in active_nodes
    }

    # Expected targets: map the `t_mod#t%` prefix to the local `arg%` name,
    # then drop the module scope from each name
    targets = [
        dep.replace('t_mod#t%', 'arg%').rsplit('#', maxsplit=1)[-1]
        for dep in expected_dependencies[seed]
    ]

    # Without full parse and enriching (as done in the Scheduler before processing),
    # the type of an imported symbol is unknown, so global variables like `nt1`
    # and parameters like `k` are also listed as targets
    regex_only_extras = {'#comp1': ['nt1'], '#comp2': ['t_mod', 'b_mod', 'a_mod', 'k']}
    targets += regex_only_extras.get(seed, [])

    routine_config = routines[seed[1:]]
    if 'block' in routine_config:
        targets = [t for t in targets if t not in routine_config['block']]
    assert set(cache[seed].targets) == set(targets)

    # Fully parse the relevant sources and enrich the seed with its module dependencies
    for name in ('t_mod', 'header_mod', seed):
        cache[name].source.make_complete()
    cache[seed].ir.enrich([cache['t_mod'].ir, cache['header_mod'].ir])

    # After enrichment, imported globals and parameters are no longer targets
    for resolved in ('nt1', 'k'):
        if resolved in targets:
            targets.remove(resolved)
    assert set(cache[seed].targets) == set(targets)


def test_sgraph_filegraph(testdir, default_config, file_dependencies):
    """
    Derive a file-level graph from the :any:`SGraph` seeded at ``#comp1`` via
    :meth:`SGraph.as_filegraph`, and compare it against ``file_dependencies``.
    File items are keyed by their full path under the project directory.
    """
    proj = testdir/'sources/projBatch'
    path_list = [path for suffix in ('.f90', '.F90') for path in proj.glob(f'**/*{suffix}')]
    assert len(path_list) == 8

    scheduler_config = SchedulerConfig.from_dict(default_config)
    item_factory = ItemFactory()
    cache = item_factory.item_cache

    # Register file items (named by full path) and what they define
    for src_path in path_list:
        path_name = str(src_path)
        source = Sourcefile.from_file(
            src_path, frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass
        )
        file_item = FileItem(
            name=path_name, source=source, config=scheduler_config.create_item_config(path_name)
        )
        cache[file_item.name] = file_item
        definitions = file_item.create_definition_items(item_factory=item_factory, config=scheduler_config)
        cache.update({definition.name: definition for definition in definitions})

    sgraph = SGraph.from_seed('#comp1', item_factory, scheduler_config)
    file_graph = SGraph.as_filegraph(sgraph, item_factory, scheduler_config)

    def _full(name):
        # Fixture names are relative to the project directory
        return str(proj/name)

    assert set(file_graph.items) == {_full(name) for name in file_dependencies}
    assert set(file_graph.dependencies) == {
        (_full(node), _full(dependency))
        for node, dependencies in file_dependencies.items()
        for dependency in dependencies
    }

def discover_proj_typebound_item_factory(testdir, scheduler_config):
    """
    Build an :any:`ItemFactory` for the ``projTypeBound`` sources using the REGEX frontend.

    For each source file, the cache receives:
    * the file item;
    * the items defined in the file;
    * the items defined in each module of that file;
    * the items defined in each derived type of those modules.

    Parameters
    ----------
    testdir : path-like
        Test directory that contains ``sources/projTypeBound``
    scheduler_config : :any:`SchedulerConfig`
        Passed to every ``create_definition_items`` call

    Returns
    -------
    :any:`ItemFactory`
        The populated factory
    """
    proj = testdir/'sources/projTypeBound'
    path_list = [path for suffix in ('.f90', '.F90') for path in proj.glob(f'**/*{suffix}')]
    assert len(path_list) == 3

    item_factory = ItemFactory()
    cache = item_factory.item_cache

    def _register_definitions(parent_names):
        # Create the definition items of the cached parents and add them to the cache
        created = {
            child.name: child
            for parent_name in parent_names
            for child in cache[parent_name].create_definition_items(
                item_factory=item_factory, config=scheduler_config
            )
        }
        cache.update(created)
        return created

    for src_path in path_list:
        file_item = FileItem(
            name=str(src_path),
            source=Sourcefile.from_file(src_path, frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass),
        )
        cache[file_item.name] = file_item

        file_defs = _register_definitions([file_item.name])
        module_defs = _register_definitions(
            [child.name for child in file_defs.values() if isinstance(child, ModuleItem)]
        )
        _register_definitions(
            [child.name for child in module_defs.values() if isinstance(child, TypeDefItem)]
        )

    return item_factory


@pytest.mark.parametrize('name,config_override,item_type,ir_type,attrs_to_check,dependency_items', [
    (
        ##########
        '#driver',
        ##########
        # This depends on  modules (via unqualified imports), typebound procedures,
        # and typebound procedures in nested derived types
        {},
        ProcedureItem,
        Subroutine,
        {
            'calls': (
                'some_type%other_routine',
                'some_type%some_routine',
                'header_type%member_routine',
                'header_type%routine',
                'other%member',
                'other%var%member_routine'
            ),
            'targets': (
                'typebound_item',
                'typebound_header',
                'typebound_other',
                'other',
                'obj%other_routine',
                'obj2%some_routine',
                'header%member_routine',
                'header%routine',
                'other_obj%member',
                'derived%var%member_routine'
            )
        },
        (
            'typebound_item',
            'typebound_header',
            'typebound_other#other_type',
            'typebound_item#some_type%other_routine',
            'typebound_item#some_type%some_routine',
            'typebound_header#header_type%member_routine',
            'typebound_header#header_type%routine',
            'typebound_other#other_type%member',
            'typebound_other#other_type%var%member_routine'
        )
    ),
    (
        ###############################
        'typebound_item#other_routine',
        ###############################
        # This is a module routine that depends on a subroutine from an unqualified import,
        # and typebound procedures in the same module
        {},
        ProcedureItem,
        Subroutine,
        {
            'calls': ('abor1', 'some_type%routine1', 'some_type%routine2'),
            'targets': ('some_type', 'abor1', 'self%routine1', 'self%routine2'),
        },
        (
            'typebound_item#some_type',
            'typebound_header#abor1',
            'typebound_item#some_type%routine1',
            'typebound_item#some_type%routine2'
        ),
    ),
    (
        #########################
        'typebound_item#routine',
        #########################
        # This is a module routine that depends on a type bound procedure,
        # which is listed as disabled in the scheduler config with fully qualified name
        {
            'disable': ['typebound_item#some_type%some_routine'],
        },
        ProcedureItem,
        Subroutine,
        {
            'calls': ('some_type%some_routine',),
            'targets': ('some_type',)
        },
        ('typebound_item#some_type',),
    ),
    (
        #########################
        'typebound_item#routine',
        #########################
        # This is a module routine that depends on a type bound procedure,
        # which is listed as disabled in the scheduler config without providing scope
        {
            'disable': ['some_type%some_routine'],
        },
        ProcedureItem,
        Subroutine,
        {
            'calls': ('some_type%some_routine',),
            'targets': ('some_type',)
        },
        ('typebound_item#some_type',),
    ),
    (
        #########################
        'typebound_item#routine1',
        #########################
        # This is a module routine that depends on a module procedure,
        # which is listed as disabled in the scheduler config without providing scope
        {
            'disable': ['module_routine'],
        },
        ProcedureItem,
        Subroutine,
        {
            'calls': ('module_routine',),
            'targets': ('some_type',)
        },
        ('typebound_item#some_type',),
    ),
    (
        #######################################
        'typebound_item#some_type%some_routine',
        #######################################
        # This is a procedure binding, where the bound procedure is listed as disabled
        # without providing fully qualified name - This means that the dependency item
        # is not created
        {
            'disable': ['some_routine'],
        },
        ProcedureBindingItem,
        ProcedureSymbol,
        {
            'calls': (),
            'targets': (),
        },
        (),
    ),
    (
        #######################################
        'typebound_item#some_type%some_routine',
        #######################################
        # This is a procedure binding, where the bound procedure is listed as disabled
        # with fully qualified name provided - this means that the dependency item
        # is not created
        {
            'disable': ['typebound_item#some_routine'],
        },
        ProcedureBindingItem,
        ProcedureSymbol,
        {
            'calls': (),
            'targets': (),
        },
        (),
    ),
    (
        #######################################
        'typebound_item#some_type%some_routine',
        #######################################
        # This is a procedure binding, where the bound procedure is listed as ignored,
        # which still includes it in the targets list
        {
            'ignore': ['some_routine'],
        },
        ProcedureBindingItem,
        ProcedureSymbol,
        {
            'calls': (),
            'targets': ('some_routine',)
        },
        ('typebound_item#some_routine',),
    ),
    (
        #######################################
        'typebound_item#some_type%some_routine',
        #######################################
        # This is a procedure binding, where the bound procedure is listed as blocked,
        # which excludes it from the targets list
        {
            'block': ['some_routine'],
        },
        ProcedureBindingItem,
        ProcedureSymbol,
        {
            'calls': (),
            'targets': ()
        },
        (),
    ),
    (
        ###################################
        'typebound_item#some_type%routine',
        ###################################
        # This is a procedure binding with renaming
        {},
        ProcedureBindingItem,
        ProcedureSymbol,
        {
            'calls': (),
            'targets': ('module_routine',),
        },
        ('typebound_item#module_routine',),
    ),
    (
        #############################
        'typebound_other#other_type',
        #############################
        # This is a derived type definition that has a dependency on another
        # type that is imported from another module, and renamed upon import
        {},
        TypeDefItem,
        TypeDef,
        {
            'calls': (),
            'targets': ('typebound_header', 'header')
        },
        ('typebound_header#header_type',),
    ),
])
def test_batch_typebound_item(
    testdir, default_config,
    name, config_override, item_type, ir_type, attrs_to_check, dependency_items
):
    """
    Check the :any:`Item` created by the regex frontend for a type-bound
    procedure (or a related derived type definition).

    The item's class, the type of its IR node, a selection of its attributes
    and the dependency items it creates are compared against the
    parametrized expectations, after applying ``config_override`` on top of
    the default scheduler configuration.
    """
    default_config['default'].update(config_override)
    scheduler_config = SchedulerConfig.from_dict(default_config)
    factory = discover_proj_typebound_item_factory(testdir, scheduler_config)

    typebound_item = factory.item_cache[name]
    assert isinstance(typebound_item, item_type)
    assert isinstance(typebound_item.ir, ir_type)

    for attr_name, expected in attrs_to_check.items():
        assert getattr(typebound_item, attr_name) == expected

    created_items = typebound_item.create_dependency_items(factory, scheduler_config)
    assert created_items == dependency_items


def test_batch_typebound_nested_item(testdir, default_config):
    """
    Test the basic regex frontend nodes in :any:`Item` objects for fast dependency detection
    for type-bound procedures for calls to nested derived type bindings.

    Starting from a procedure that calls a binding on a derived-type component
    (``other_type%var%member_routine``), the test follows the chain of
    dependency items step by step: procedure -> nested binding item ->
    binding in the component's type -> module routine. It then does the same
    for a doubly nested call (``outer_type%other%var%member_routine``).
    """
    scheduler_config = SchedulerConfig.from_dict(default_config)
    item_factory = discover_proj_typebound_item_factory(testdir, scheduler_config)

    item = item_factory.item_cache['typebound_other#other_member']
    assert isinstance(item, ProcedureItem)
    assert isinstance(item.ir, Subroutine)
    assert len(item.dependencies) == 4

    # The first dependency node is the import from the header module
    assert isinstance(item.dependencies[0], ir.Import)
    assert item.dependencies[0].module == 'typebound_header'

    # Verify that the call to the nested type's routine is added when creating
    # dependency items
    # (the nested binding item does not exist yet and is created on demand here)
    assert 'typebound_other#other_type%var%member_routine' not in item_factory
    assert item.create_dependency_items(item_factory, scheduler_config) == (
        'typebound_other#other_type',
        'typebound_header#header_member_routine',
        'typebound_other#other_type%var%member_routine',
    )
    assert 'typebound_other#other_type%var%member_routine' in item_factory

    # Verify that the nested binding item can correctly resolve this to the binding
    # in the type
    # (its IR is the component variable 'var', not a procedure symbol)
    proc_bind_item = item_factory.item_cache['typebound_other#other_type%var%member_routine']
    assert isinstance(proc_bind_item, ProcedureBindingItem)
    assert isinstance(proc_bind_item.ir, Scalar)
    assert proc_bind_item.ir == 'var'
    assert proc_bind_item.dependencies == ('var%member_routine',)
    assert proc_bind_item.create_dependency_items(item_factory, scheduler_config) == (
        'typebound_header#header_type%member_routine',
    )

    # Verify that the binding in the type correctly resolves to the module routine
    nested_bind_item = item_factory.item_cache['typebound_header#header_type%member_routine']
    assert isinstance(nested_bind_item, ProcedureBindingItem)
    assert isinstance(nested_bind_item.ir, ProcedureSymbol)
    assert nested_bind_item.ir == 'member_routine'
    assert nested_bind_item.create_dependency_items(item_factory, scheduler_config) == (
        'typebound_header#header_member_routine',
    )

    # Verify that we're now at the module routine
    routine_item = item_factory.item_cache['typebound_header#header_member_routine']
    assert isinstance(routine_item, ProcedureItem)
    assert isinstance(routine_item.ir, Subroutine)

    # Lastly, look at the deeply nested call...
    nested_call_item = item_factory.item_cache['typebound_other#nested_call']
    assert isinstance(nested_call_item, ProcedureItem)
    assert nested_call_item.create_dependency_items(item_factory, scheduler_config) == (
        'typebound_other#outer_type', 'typebound_other#outer_type%other%var%member_routine'
    )

    # ...and see if we can chase the deeply nested dependencies correctly
    # (the outer binding item peels off one component level and points back to
    # the already-known 'other_type%var%member_routine' binding item)
    other_var_member_item = item_factory.item_cache['typebound_other#outer_type%other%var%member_routine']
    assert isinstance(other_var_member_item, ProcedureBindingItem)
    assert isinstance(other_var_member_item.ir, Scalar)
    assert other_var_member_item.dependencies == ('other%var%member_routine',)
    assert other_var_member_item.create_dependency_items(item_factory, scheduler_config) == (
        'typebound_other#other_type%var%member_routine',
    )


def test_batch_typebound_item_targets(default_config):
    """
    Check that calls to type-bound procedures on a variable whose type comes
    from a disabled module do not produce any targets or dependency items.

    ``timer_mod`` is added to the ``disable`` list. The driver's dependencies
    still record the import and the two ``TIMER%...`` calls, but no dependency
    items are created for them, neither at module nor at procedure level.
    """
    default_config['default']['disable'] += ['timer_mod']

    fcode = """
MODULE TYPEBOUND_ITEM_TARGETS_MOD
    USE TIMER_MOD, ONLY: PERFORMANCE_TIMER
    IMPLICIT NONE
CONTAINS
    SUBROUTINE DRIVER
        IMPLICIT NONE
        TYPE(PERFORMANCE_TIMER) :: TIMER

        CALL TIMER%START()

        ! DO SOMETHING

        CALL TIMER%END()
    END SUBROUTINE DRIVER
END MODULE TYPEBOUND_ITEM_TARGETS_MOD
    """.strip()

    # Parse with the lightweight REGEX frontend only (program units, no full IR)
    source = Sourcefile.from_source(fcode, parser_classes=RegexParserClass.ProgramUnitClass, frontend=REGEX)
    # NOTE(review): presumably a placeholder so the file item can be named; verify
    source.path = 'None'
    item_factory = ItemFactory()
    scheduler_config = SchedulerConfig.from_dict(default_config)

    # File level: defines the module, depends on nothing
    file_item = item_factory.get_or_create_file_item_from_source(source, scheduler_config)
    assert file_item.targets == ()
    assert file_item.definitions == (source['typebound_item_targets_mod'],)
    assert file_item.dependencies == ()

    file_definitions = file_item.create_definition_items(item_factory, scheduler_config)
    assert file_definitions == ('typebound_item_targets_mod',)

    # Module level: the import is a dependency, but the disabled module yields no item
    module_item = file_definitions[0]
    assert module_item.targets == ()
    assert module_item.definitions == (source['driver'],)
    assert module_item.dependencies == module_item.ir.imports

    module_definitions = module_item.create_definition_items(item_factory, scheduler_config)
    assert module_definitions == ('typebound_item_targets_mod#driver',)
    assert module_item.create_dependency_items(item_factory, scheduler_config) == ()

    # Procedure level: the import and both type-bound calls are recorded as
    # dependencies, but none of them becomes a dependency item
    driver_item = module_definitions[0]
    assert driver_item.targets == ()
    assert driver_item.definitions == ()
    assert ('timer_mod', 'timer%start', 'timer%end') == (
        (driver_item.dependencies[0].module.lower(),) +
        tuple(dep.name for dep in driver_item.dependencies[1:])
    )

    assert driver_item.create_definition_items(item_factory, scheduler_config) == ()
    assert driver_item.create_dependency_items(item_factory, scheduler_config) == ()
loki-ecmwf-0.3.6/loki/batch/tests/__init__.py0000664000175000017500000000057015167130205021235 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/batch/tests/test_transformation.py0000664000175000017500000006074515167130205023615 0ustar  alastairalastairfrom functools import partial
from pathlib import Path
import pytest

from loki import (
    Sourcefile, Subroutine, FindInlineCalls, fgen, IntLiteral, Module, Function
)
from loki.batch import Transformation, Pipeline, ProcedureItem, TransformationError
from loki.jit_build import jit_compile, clean_test
from loki.frontend import available_frontends, OMNI, REGEX
from loki.ir import nodes as ir, FindNodes
from loki.transformations import (
    replace_selected_kind, FileWriteTransformation
)


@pytest.fixture(scope='module', name='rename_transform')
def fixture_rename_transform():
    """
    Provide a trivial :any:`Transformation` that tags source files with a
    marker comment and appends ``_test`` to subroutine and module names.
    """

    class RenameTransform(Transformation):
        """
        Simple `Transformation` object that renames subroutine and modules.
        """

        def transform_file(self, sourcefile, **kwargs):
            # Mark the file so tests can detect that transform_file ran
            marker = ir.Comment(text="! [Loki] RenameTransform applied")
            sourcefile.ir.prepend(marker)

        def transform_subroutine(self, routine, **kwargs):
            routine.name = f'{routine.name}_test'

        def transform_module(self, module, **kwargs):
            module.name = f'{module.name}_test'

    return RenameTransform()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('method', ['source', 'transformation'])
@pytest.mark.parametrize('lazy', [False, True])
def test_transformation_apply(rename_transform, frontend, method, lazy, tmp_path):
    """
    Check that a renaming transformation can be applied to a whole
    :any:`Sourcefile` and to its individual program units. It is invoked either
    through the object's ``apply`` method or through the transformation's own
    ``apply`` method.

    When the source was constructed lazily (REGEX frontend), applying the
    transformation must raise :any:`TransformationError` until the source has
    been completed.
    """
    fcode = """
module mymodule
  real(kind=4) :: myvar
end module mymodule

subroutine myroutine(a, b)
  real(kind=4), intent(inout) :: a, b

  a = a + b
end subroutine myroutine
"""
    # Both invocation styles share one calling convention: apply_fn(obj, **kwargs)
    appliers = {
        'source': lambda obj, **kwargs: obj.apply(rename_transform, **kwargs),
        'transformation': lambda obj, **kwargs: rename_transform.apply(obj, **kwargs),
    }

    # Let source apply transformation to all items and verify
    source = Sourcefile.from_source(fcode, frontend=REGEX if lazy else frontend, xmods=[tmp_path])
    assert source._incomplete is lazy

    apply_fn = appliers.get(method)
    if apply_fn is None:
        raise ValueError(f'Unknown method "{method}"')

    if lazy:
        # An incomplete source must be refused until it is fully parsed
        with pytest.raises(TransformationError):
            apply_fn(source)
        source.make_complete(frontend=frontend, xmods=[tmp_path])
    apply_fn(source)
    assert not source._incomplete

    # File-level application added the marker but kept program-unit names
    first_node = source.ir.body[0]
    assert isinstance(first_node, ir.Comment)
    assert first_node.text == '! [Loki] RenameTransform applied'

    assert source.modules[0].name == 'mymodule'
    assert source.subroutines[0].name == 'myroutine'

    # Now target the program units directly
    apply_fn(source.modules[0], recurse_to_contained_nodes=True)
    apply_fn(source.subroutines[0], recurse_to_contained_nodes=True)

    assert source.modules[0].name == 'mymodule_test'
    assert source['mymodule_test'] == source.modules[0]
    assert source.subroutines[0].name == 'myroutine_test'
    assert source['myroutine_test'] == source.subroutines[0]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('target, apply_method', [
    ('module_routine', lambda transform, obj, **kwargs: obj.apply(transform, **kwargs)),
    ('myroutine', lambda transform, obj, **kwargs: transform.apply_subroutine(obj, **kwargs))
])
@pytest.mark.parametrize('lazy', [False, True])
def test_transformation_apply_subroutine(rename_transform, frontend, target, apply_method, lazy, tmp_path):
    """
    Apply the renaming transformation to a single subroutine (either a module
    procedure or a free-standing routine) and check that only that routine is
    renamed and re-parsed. Contained member functions are not renamed.
    """
    fcode = """
module mymodule
  real(kind=4) :: myvar

contains

  subroutine module_routine(argument)
    real(kind=4), intent(inout) :: argument

    argument = member_func()

  contains
    function member_func() result(res)
      real(kind=4) :: res

      res = 4.
    end function member_func
  end subroutine module_routine
end module mymodule

subroutine myroutine(a, b)
  real(kind=4), intent(inout) :: a, b

  a = a + b
end subroutine myroutine
"""
    source = Sourcefile.from_source(fcode, frontend=REGEX if lazy else frontend, xmods=[tmp_path])
    assert source._incomplete is lazy
    assert source[target]._incomplete is lazy

    if lazy:
        # Lazily-constructed routines must be completed before transformation
        with pytest.raises(TransformationError):
            apply_method(rename_transform, source[target])
        source[target].make_complete(frontend=frontend, xmods=[tmp_path])
    apply_method(rename_transform, source[target])

    # Only the transformation target itself should have been re-parsed
    assert source._incomplete is lazy
    assert not source[f'{target}_test']._incomplete
    assert source.modules[0].name == 'mymodule'
    assert source['mymodule'] == source.modules[0]

    # Member functions are not counted in all_subroutines
    assert len(source.all_subroutines) == 2

    if target == 'module_routine':
        # The free-standing routine is untouched; the module procedure is renamed
        assert source.subroutines[0].name == 'myroutine'
        assert source['myroutine'] == source.subroutines[0]
        assert source.all_subroutines[1].name == 'module_routine_test'
        assert source['module_routine_test'] == source.all_subroutines[1]
        assert len(source['module_routine_test'].members) == 1
        assert source['module_routine_test'].members[0].name == 'member_func'
    elif target == 'myroutine':
        # The free-standing routine is renamed; the module procedure is untouched
        assert source.subroutines[0].name == 'myroutine_test'
        assert source['myroutine_test'] == source.subroutines[0]
        assert source.all_subroutines[1].name == 'module_routine'
        assert source['module_routine'] == source.all_subroutines[1]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('apply_method', [
    lambda transform, obj, **kwargs: obj.apply(transform, **kwargs),
    lambda transform, obj, **kwargs: transform.apply_module(obj, **kwargs)
])
@pytest.mark.parametrize('lazy', [False, True])
def test_transformation_apply_module(rename_transform, frontend, apply_method, lazy, tmp_path):
    """
    Apply the renaming transformation to a module only and check that it is
    the only program unit that gets renamed and re-parsed. Neither the
    free-standing subroutine nor the procedure contained in the module changes
    its name.
    """
    fcode = """
module mymodule
  real(kind=4) :: myvar

contains

  subroutine module_routine(argument)
    real(kind=4), intent(inout) :: argument

    argument = argument  + 1.
  end subroutine module_routine
end module mymodule

subroutine myroutine(a, b)
  real(kind=4), intent(inout) :: a, b

  a = a + b
end subroutine myroutine
"""
    source = Sourcefile.from_source(fcode, frontend=REGEX if lazy else frontend, xmods=[tmp_path])
    assert source._incomplete is lazy
    assert source['mymodule']._incomplete is lazy
    assert source['myroutine']._incomplete is lazy

    if lazy:
        # Lazily-constructed modules must be completed before transformation
        with pytest.raises(TransformationError):
            apply_method(rename_transform, source['mymodule'])
        source['mymodule'].make_complete(frontend=frontend, xmods=[tmp_path])
    apply_method(rename_transform, source['mymodule'])

    # Only the module has been completed; the free-standing routine stays lazy
    assert source._incomplete is lazy
    assert not source['mymodule_test']._incomplete
    assert source['myroutine']._incomplete is lazy
    assert source.modules[0].name == 'mymodule_test'
    assert source['mymodule_test'] == source.modules[0]
    assert len(source.all_subroutines) == 2

    # The free-standing subroutine is not part of the module, so it is untouched
    assert source.subroutines[0].name == 'myroutine'
    assert source['myroutine'] == source.subroutines[0]

    # The module procedure keeps its name too: recurse_to_contained_nodes was
    # not requested, so only the module itself was renamed
    assert source.all_subroutines[1].name == 'module_routine'
    assert source['module_routine'] == source.all_subroutines[1]


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_replace_selected_kind(tmp_path, frontend):
    """
    Test that :any:`replace_selected_kind` replaces all ``selected_int_kind``
    and ``selected_real_kind`` calls with the corresponding
    ``iso_fortran_env`` kind constants. The replacement constants are added to
    the existing ``iso_fortran_env`` import. The compiled routine must produce
    the same results before and after the transformation.
    """
    fcode = """
subroutine transform_replace_selected_kind(i, a)
  use iso_fortran_env, only: int8
  implicit none
  integer, parameter :: jprm = selected_real_kind(6,37)
  integer(kind=selected_int_kind(9)), intent(out) :: i
  real(kind=selected_real_kind(13,300)), intent(out) :: a
  integer(kind=int8) :: j = 1
  integer(kind=selected_int_kind(1)) :: k = 9
  real(kind=selected_real_kind(7)) :: b = 5._jprm
  real(kind=selected_real_kind(r=2, p=4)) :: c = 1.

  i = j + k
  a = b + c + real(4, kind=selected_real_kind(6, r=37))
end subroutine transform_replace_selected_kind
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    imports = FindNodes(ir.Import).visit(routine.spec)
    assert len(imports) == 1 and imports[0].module.lower() == 'iso_fortran_env'
    assert len(imports[0].symbols) == 1 and imports[0].symbols[0].name.lower() == 'int8'

    # Test the original implementation
    filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
    function = jit_compile(routine, filepath=filepath, objname=routine.name)

    i, a = function()
    assert i == 10
    assert a == 10.

    # Apply transformation and check imports
    replace_selected_kind(routine)
    assert not [call for call in FindInlineCalls().visit(routine.ir)
                if call.name.lower().startswith('selected')]

    imports = FindNodes(ir.Import).visit(routine.spec)
    assert len(imports) == 1 and imports[0].module.lower() == 'iso_fortran_env'

    source = fgen(routine).lower()
    assert 'selected_real_kind' not in source
    assert 'selected_int_kind' not in source

    if frontend == OMNI:
        # OMNI already replaces some (but not all) selected_real_kind calls
        # with (incorrect) integer kinds while parsing, so no int32 import
        # is required afterwards
        symbols = {'int8', 'real32', 'real64'}
    else:
        symbols = {'int8', 'int32', 'real32', 'real64'}

    assert len(imports[0].symbols) == len(symbols)
    assert {s.name.lower() for s in imports[0].symbols} == symbols

    # Test the transformed implementation
    iso_filepath = tmp_path/(f'{routine.name}_replaced_{frontend}.f90')
    iso_function = jit_compile(routine, filepath=iso_filepath, objname=routine.name)

    i, a = iso_function()
    assert i == 10
    assert a == 10.

    clean_test(filepath)
    clean_test(iso_filepath)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('post_apply_rescope_symbols', [True, False])
def test_transformation_post_apply_subroutine(tmp_path, frontend, post_apply_rescope_symbols):
    """
    Verify that post_apply is called for subroutines.

    The transformation below deliberately attaches a new variable to a
    foreign scope. With ``post_apply_rescope_symbols`` enabled, the variable
    must end up scoped to the transformed routine. Without it, the foreign
    scope must remain. The generated code has to compile and run either way.
    """

    #### Test that rescoping is applied and effective ####

    foreign_scope = Subroutine('some_routine')

    class ScopingErrorTransformation(Transformation):
        """Intentionally idiotic transformation that introduces a scoping error."""

        def transform_subroutine(self, routine, **kwargs):
            var_i = routine.variable_map['i']
            # Clone 'i' into a local 'j' without intent, attached to the wrong scope
            var_j = var_i.clone(name='j', scope=foreign_scope, type=var_i.type.clone(intent=None))
            routine.variables += (var_j,)
            routine.body.append(ir.Assignment(lhs=var_j, rhs=IntLiteral(2)))
            routine.body.append(ir.Assignment(lhs=var_i, rhs=var_j))
            routine.name += '_transformed'
            assert routine.variable_map['j'].scope is foreign_scope

    fcode = """
subroutine transformation_post_apply(i)
  integer, intent(out) :: i
  i = 1
end subroutine transformation_post_apply
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Compile and run the untransformed routine as a reference
    ref_path = tmp_path/(f'{routine.name}_{frontend}.f90')
    ref_func = jit_compile(routine, filepath=ref_path, objname=routine.name)
    assert ref_func() == 1

    # Apply the transformation; the resulting scope depends on the rescope flag
    routine.apply(ScopingErrorTransformation(), post_apply_rescope_symbols=post_apply_rescope_symbols)
    expected_scope = routine if post_apply_rescope_symbols else foreign_scope
    assert routine.variable_map['j'].scope is expected_scope

    # The routine was renamed, so this is a different file from the reference
    new_path = tmp_path/(f'{routine.name}_{frontend}.f90')
    new_func = jit_compile(routine, filepath=new_path, objname=routine.name)
    assert new_func() == 2

    clean_test(ref_path)
    clean_test(new_path)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('post_apply_rescope_symbols', [True, False])
def test_transformation_post_apply_module(tmp_path, frontend, post_apply_rescope_symbols):
    """
    Verify that post_apply is called for modules.

    A module variable is deliberately declared with a foreign scope. With
    ``post_apply_rescope_symbols`` enabled, it must be rescoped to the module
    itself. The transformed module must compile and run either way.
    """

    #### Test that rescoping is applied and effective ####

    foreign_scope = Module('some_module')

    class ScopingErrorTransformation(Transformation):
        """Intentionally idiotic transformation that introduces a scoping error."""

        def transform_module(self, module, **kwargs):
            var_i = module.variable_map['i']
            # Clone 'i' into a module variable 'j', attached to the wrong scope
            var_j = var_i.clone(name='j', scope=foreign_scope, type=var_i.type.clone(intent=None))
            module.variables += (var_j,)
            procedure = module.subroutines[0]
            # Prepending in reverse order yields: j = 2; i = j; <original body>
            procedure.body.prepend(ir.Assignment(lhs=var_i, rhs=var_j))
            procedure.body.prepend(ir.Assignment(lhs=var_j, rhs=IntLiteral(2)))
            module.name += '_transformed'
            assert module.variable_map['j'].scope is foreign_scope

    fcode = """
module module_post_apply
  integer :: i = 0
contains
  subroutine test_post_apply(ret)
    integer, intent(out) :: ret
    i = i + 1
    ret = i
  end subroutine test_post_apply
end module module_post_apply
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Compile and run the untransformed module as a reference
    ref_path = tmp_path/(f'{module.name}_{frontend}_{post_apply_rescope_symbols!s}.f90')
    ref_mod = jit_compile(module, filepath=ref_path, objname=module.name)
    assert ref_mod.test_post_apply() == 1

    # Apply the transformation; the resulting scope depends on the rescope flag
    module.apply(ScopingErrorTransformation(), post_apply_rescope_symbols=post_apply_rescope_symbols)
    expected_scope = module if post_apply_rescope_symbols else foreign_scope
    assert module.variable_map['j'].scope is expected_scope

    # The module was renamed, so this is a different file from the reference
    new_path = tmp_path/(f'{module.name}_{frontend}_{post_apply_rescope_symbols!s}.f90')
    new_mod = jit_compile(module, filepath=new_path, objname=module.name)
    assert new_mod.test_post_apply() == 3

    clean_test(ref_path)
    clean_test(new_path)


def test_transformation_file_write(tmp_path):
    """
    Verify that files get written with correct filenames.

    Covers three cases:

    * the item's ``mode`` config combined with a custom suffix;
    * the default ``.loki.F90`` naming with a single ``item``;
    * the same naming when only ``items`` is given, as in file graph traversal.

    Also checks that a missing item raises :any:`TransformationError`.
    """

    fcode = """
subroutine rick()
  print *, "PRINT ME!"
end subroutine rick
"""
    source = Sourcefile.from_source(fcode)
    source.path = Path('rick.F90')

    def assert_writes(filename, transformation, **apply_kwargs):
        # Start from a clean slate, apply the transformation, check the output, clean up
        expected_path = tmp_path/filename
        expected_path.unlink(missing_ok=True)
        transformation.apply(source=source, build_args={'output_dir': tmp_path}, **apply_kwargs)
        assert expected_path.exists()
        expected_path.unlink()

    # Test mode and suffix overrides
    item = ProcedureItem(name='#rick', source=source, config={'mode': 'roll'})
    assert_writes('rick.roll.java', FileWriteTransformation(suffix='.java'), item=item)

    # Test default file writes
    item = ProcedureItem(name='#rick', source=source)
    assert_writes('rick.loki.F90', FileWriteTransformation(), item=item)

    # Test writing with "items" only (as in file graph traversal)
    assert_writes('rick.loki.F90', FileWriteTransformation(), items=(item,))

    # Check error behaviour if no item provided
    with pytest.raises(TransformationError):
        FileWriteTransformation().apply(source=source)


def test_transformation_pipeline_simple():
    """
    Test the instantiation of a :any:`Pipeline` from a partial definition.

    Keyword arguments given to the partial definition and to the final
    constructor call are forwarded to each transformation class that accepts
    them.
    """

    class PrependTrafo(Transformation):
        def __init__(self, name='Rick', relaxed=False):
            self.name = name
            self.relaxed = relaxed

        def transform_subroutine(self, routine, **kwargs):
            greeting = 'Whazzup' if self.relaxed else 'Hello'
            routine.body.prepend(ir.Comment(text=f'! {greeting} {self.name}'))

    class AppendTrafo(Transformation):
        def __init__(self, name='Dave', in_french=False):
            self.name = name
            self.in_french = in_french

        def transform_subroutine(self, routine, **kwargs):
            greeting = 'Au revoir' if self.in_french else 'Goodbye'
            routine.body.append(ir.Comment(text=f'! {greeting}, {self.name}'))

    # Define a pipeline as a combination of transformation classes
    # and a set of pre-defined constructor flags
    GreetingPipeline = partial(
        Pipeline, classes=(PrependTrafo, AppendTrafo), relaxed=True
    )

    # Instantiate the pipeline object with additional constructor flags
    pipeline = GreetingPipeline(name='Bob', in_french=True)

    assert pipeline.transformations and len(pipeline.transformations) == 2
    assert isinstance(pipeline.transformations[0], PrependTrafo)
    assert pipeline.transformations[0].name == 'Bob'
    # The flag from the partial definition must reach the class that accepts it
    assert pipeline.transformations[0].relaxed
    assert isinstance(pipeline.transformations[1], AppendTrafo)
    assert pipeline.transformations[1].name == 'Bob'
    assert pipeline.transformations[1].in_french

    # Now apply the pipeline to a simple subroutine
    fcode = """
subroutine test_pipeline
  integer :: i
  real :: a, b

  do i=1,3
    a = a + b
  end do
end subroutine test_pipeline
"""
    routine = Subroutine.from_source(fcode)
    pipeline.apply(routine)

    assert isinstance(routine.body.body[0], ir.Comment)
    assert routine.body.body[0].text == '! Whazzup Bob'
    assert isinstance(routine.body.body[-1], ir.Comment)
    assert routine.body.body[-1].text == '! Au revoir, Bob'


def test_transformation_pipeline_constructor():
    """
    Test the correct argument handling when instantiating a
    :any:`Pipeline` from partial definitions.

    Each transformation receives only the keyword arguments its constructor
    accepts, including arguments accepted via ``**kwargs`` and forwarded to a
    parent class constructor.
    """

    class DoSomethingTrafo(Transformation):
        def __init__(self, a, b=None, c=True, d='yes'):
            self.a = a
            self.b = b
            self.c = c
            self.d = d

    class DoSomethingElseTrafo(Transformation):
        def __init__(self, b=None, d='no'):
            self.b = b
            self.d = d

    MyPipeline = partial(
        Pipeline, classes=(
            DoSomethingTrafo,
            DoSomethingElseTrafo,
        ),
        a=42
    )

    p1 = MyPipeline(b=66, d='yes')
    assert p1.transformations[0].a == 42
    assert p1.transformations[0].b == 66
    assert p1.transformations[0].c is True
    assert p1.transformations[0].d == 'yes'
    assert p1.transformations[1].b == 66
    assert p1.transformations[1].d == 'yes'

    # Now we use inheritance to propagate defaults

    class DoSomethingDifferentTrafo(DoSomethingTrafo):
        def __init__(self, e=1969, **kwargs):
            super().__init__(**kwargs)
            self.e = e

    MyOtherPipeline = partial(
        Pipeline, classes=(
            DoSomethingDifferentTrafo,
            DoSomethingElseTrafo,
        ),
        a=42
    )

    # Now check if inheritance works
    p2 = MyOtherPipeline(b=66, d='yes', e=1977)
    assert isinstance(p2.transformations[0], DoSomethingDifferentTrafo)
    assert p2.transformations[0].a == 42
    assert p2.transformations[0].b == 66
    assert p2.transformations[0].c is True
    assert p2.transformations[0].d == 'yes'
    assert p2.transformations[0].e == 1977
    assert p2.transformations[1].b == 66
    assert p2.transformations[1].d == 'yes'


def test_transformation_pipeline_compose():
    """
    Test append / prepend functionalities of :any:`Pipeline` objects.
    """

    fcode = """
subroutine test_pipeline_compose(a)
  implicit none
  real, intent(inout) :: a
  a = a + 1.0
end subroutine test_pipeline_compose
"""

    class YesTrafo(Transformation):
        def transform_subroutine(self, routine, **kwargs):
            routine.body.append( ir.Comment(text='! Yes !') )

    class NoTrafo(Transformation):
        def transform_subroutine(self, routine, **kwargs):
            routine.body.append( ir.Comment(text='! No !') )

    class MaybeTrafo(Transformation):
        def transform_subroutine(self, routine, **kwargs):
            routine.body.append( ir.Comment(text='! Maybe !') )

    class MaybeNotTrafo(Transformation):
        def transform_subroutine(self, routine, **kwargs):
            routine.body.append( ir.Comment(text='! Maybe not !') )

    pipeline = Pipeline(classes=(YesTrafo, NoTrafo))
    pipeline.prepend(MaybeTrafo())
    pipeline.append(MaybeNotTrafo())

    routine = Subroutine.from_source(fcode)
    pipeline.apply(routine)

    comments = FindNodes(ir.Comment).visit(routine.body)
    assert len(comments) == 4
    assert comments[0].text == '! Maybe !'
    assert comments[1].text == '! Yes !'
    assert comments[2].text == '! No !'
    assert comments[3].text == '! Maybe not !'

    # Now try the same trick, but with the native addition API
    pipe_a = Pipeline(classes=(MaybeTrafo,))
    pipe_b = Pipeline(classes=(MaybeNotTrafo,YesTrafo))
    pipe = YesTrafo() + pipe_a + pipe_b + NoTrafo()

    with pytest.raises(TypeError):
        pipe += lambda t: t

    routine = Subroutine.from_source(fcode)
    pipe.apply(routine)

    comments = FindNodes(ir.Comment).visit(routine.body)
    assert len(comments) == 5
    assert comments[0].text == '! Yes !'
    assert comments[1].text == '! Maybe !'
    assert comments[2].text == '! Maybe not !'
    assert comments[3].text == '! Yes !'
    assert comments[4].text == '! No !'

    # Check that the string representation is sane
    assert '[\w%#./-]+)\"? \[colo', re.IGNORECASE)
    _re_edges = re.compile(r'\s*\"?(?P[\w%#./-]+)\"? -> \"?(?P[\w%#./-]+)\"?', re.IGNORECASE)

    def __init__(self, path):
        # Keep the full text of the graph file for the regex-based accessors
        self.text = Path(path).read_text()

    @property
    def nodes(self):
        # A fresh list of all matches of the class-level node regex in the text
        return [*self._re_nodes.findall(self.text)]

    @property
    def edges(self):
        # A fresh list of all matches of the class-level edge regex in the text
        return [*self._re_edges.findall(self.text)]


def test_scheduler_enrichment(testdir, config, frontend, tmp_path):
    """
    Check that call statements in every procedure item are attached to the
    IR of the matching successor item in the scheduler graph. Matching uses
    the successor's local name, case-insensitively.
    """
    projA = testdir/'sources/projA'

    scheduler = Scheduler(
        paths=projA, includes=projA/'include', config=config,
        seed_routines=['driverA'], frontend=frontend, xmods=[tmp_path]
    )

    for item in SFilter(scheduler.sgraph, item_filter=ProcedureItem):
        successors_by_name = CaseInsensitiveDict(
            (successor.local_name, successor) for successor in scheduler.sgraph.successors(item)
        )
        for call in FindNodes(ir.CallStatement).visit(item.ir.body):
            call_item = successors_by_name.get(str(call.name))
            if not call_item:
                # Calls without a graph successor are not checked
                continue
            assert call.routine is call_item.ir


def test_scheduler_empty_config(testdir, frontend, tmp_path):
    """
    Regression test for #373: the :any:`Scheduler` can be instantiated
    without a config. In that case only the seed routine ends up as an item.
    """
    project_dir = testdir/'sources/projA'

    scheduler = Scheduler(
        paths=project_dir, includes=project_dir/'include',
        seed_routines=['driverA'], frontend=frontend, xmods=[tmp_path]
    )
    assert scheduler.items == ('drivera_mod#drivera',)


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
@pytest.mark.parametrize('with_file_graph', [True, False, 'filegraph_simple'])
@pytest.mark.parametrize('with_legend', [True, False])
@pytest.mark.parametrize('seed', ['driverA', 'driverA_mod#driverA'])
def test_scheduler_graph_simple(
        tmp_path, testdir, config, frontend, driverA_dependencies,
        with_file_graph, with_legend, seed
):
    """
    Create a simple task graph from a single sub-project:

    projA: driverA -> kernelA -> compute_l1 -> compute_l2
                           |
                           | --> another_l1 -> another_l2

    ``with_file_graph`` can take two forms:

    * a boolean toggle, in which case the file graph is expected next to the
      callgraph with a ``_file_graph`` suffix;
    * a file name, which is resolved relative to ``tmp_path`` and passed to
      ``callgraph`` as the file graph location.
    """

    # Combine directory globbing and explicit file paths for lookup
    projA = testdir/'sources/projA'
    paths = [projA/'module', projA/'source/another_l1.F90', projA/'source/another_l2.F90']

    scheduler = Scheduler(
        paths=paths, includes=projA/'include', config=config,
        seed_routines=seed, frontend=frontend, xmods=[tmp_path]
    )

    assert set(scheduler.items) == {item.lower() for item in driverA_dependencies}
    assert set(scheduler.dependencies) == {
        (item.lower(), child.lower())
        for item, children in driverA_dependencies.items()
        for child in children
    }

    # File-level dependencies, given relative to projA
    expected_file_dependencies = {
        'module/driverA_mod.f90': ('module/kernelA_mod.F90', 'module/header_mod.f90'),
        'module/kernelA_mod.F90': ('module/compute_l1_mod.f90', 'source/another_l1.F90'),
        'module/compute_l1_mod.f90': ('module/compute_l2_mod.f90',),
        'module/compute_l2_mod.f90': (),
        'source/another_l1.F90': ('source/another_l2.F90', 'module/header_mod.f90'),
        'source/another_l2.F90': ('module/header_mod.f90',),
        'module/header_mod.f90': (),
    }

    if with_file_graph:
        file_graph = scheduler.file_graph
        assert set(file_graph.items) == {str(projA/name).lower() for name in expected_file_dependencies}
        assert set(file_graph.dependencies) == {
            (str(projA/a).lower(), str(projA/b).lower())
            for a, deps in expected_file_dependencies.items() for b in deps
        }

    # Resolve the file graph location exactly once. For a boolean toggle it is
    # derived from the callgraph path; otherwise it is the explicitly requested
    # file under tmp_path, which is also what gets passed to ``callgraph``.
    cg_path = tmp_path/'callgraph_simple'
    if isinstance(with_file_graph, bool):
        fg_path = cg_path.with_name(f'{cg_path.stem}_file_graph{cg_path.suffix}')
        file_graph_arg = with_file_graph
    else:
        fg_path = tmp_path/with_file_graph
        file_graph_arg = fg_path

    # Testing of callgraph visualisation
    scheduler.callgraph(cg_path, with_file_graph=file_graph_arg, with_legend=with_legend)

    vgraph = VisGraphWrapper(cg_path)
    if with_legend:
        assert set(vgraph.nodes) == {item.upper() for item in driverA_dependencies} | {
            'FileItem', 'ModuleItem', 'ProcedureItem', 'TypeDefItem',
            'ProcedureBindingItem', 'InterfaceItem', 'ExternalItem'
        }
    else:
        assert set(vgraph.nodes) == {item.upper() for item in driverA_dependencies}
    assert set(vgraph.edges) == {
        (item.upper(), child.upper())
        for item, children in driverA_dependencies.items()
        for child in children
    }

    if with_file_graph:
        fgraph = VisGraphWrapper(fg_path)
        assert set(fgraph.nodes) == {name.lower() for name in expected_file_dependencies}
        assert set(fgraph.edges) == {
            (a.lower(), b.lower())
            for a, deps in expected_file_dependencies.items() for b in deps
        }

        fg_path.unlink()
        fg_path.with_suffix('.pdf').unlink(missing_ok=True)

    cg_path.unlink()
    cg_path.with_suffix('.pdf').unlink(missing_ok=True)


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
@pytest.mark.parametrize('seed', ['compute_l1', 'compute_l1_mod#compute_l1'])
def test_scheduler_graph_partial(tmp_path, testdir, config, frontend, seed):
    """
    Create a sub-graph from a select set of branches in a single project:

    projA: compute_l1 -> compute_l2

           another_l1 -> another_l2

    Only ``compute_l1`` and ``another_l1`` are configured as driver routines,
    so ``driverA`` and ``kernelA`` must not appear in the graph.
    """
    projA = testdir/'sources/projA'

    config['routines'] = {
        seed: {
            'role': 'driver',
            'expand': True,
        },
        'another_l1': {
            'role': 'driver',
            'expand': True,
        },
    }

    scheduler = Scheduler(paths=projA, includes=projA/'include', config=config, frontend=frontend, xmods=[tmp_path])

    expected_items = [
        'compute_l1_mod#compute_l1', 'compute_l2_mod#compute_l2', '#another_l1', '#another_l2'
    ]
    expected_dependencies = [
        ('compute_l1_mod#compute_l1', 'compute_l2_mod#compute_l2'),
        ('#another_l1', '#another_l2')
    ]

    # Check the correct sub-graph is generated
    assert all(n in scheduler.items for n in expected_items)
    assert all(e in scheduler.dependencies for e in expected_dependencies)
    # Item names are lower-case and scope-qualified ("<scope>#<name>"), so a plain
    # ``'driverA' not in scheduler.items`` could never fail; match substrings instead
    assert not any('drivera' in name for name in scheduler.items)
    assert not any('kernela' in name for name in scheduler.items)

    # Testing of callgraph visualisation
    cg_path = tmp_path/'callgraph_partial'
    scheduler.callgraph(cg_path)

    vgraph = VisGraphWrapper(cg_path)
    assert all(n.upper() in vgraph.nodes for n in expected_items)
    assert all((e[0].upper(), e[1].upper()) in vgraph.edges for e in expected_dependencies)
    # Node labels are the upper-cased qualified names, hence the substring match
    assert not any('DRIVERA' in node for node in vgraph.nodes)
    assert not any('KERNELA' in node for node in vgraph.nodes)

    cg_path.unlink()
    cg_path.with_suffix('.pdf').unlink(missing_ok=True)


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
def test_scheduler_graph_config_file(tmp_path, testdir, frontend):
    """
    Create a sub-graph from selected branches using a config file:

    projA: compute_l1 -> compute_l2

           another_l1 -> another_l2

    The config file blocks ``compute_l2``. It must therefore not be a scheduler
    item, but it still shows up as a node in the callgraph visualisation.
    """
    projA = testdir/'sources/projA'
    config = projA/'scheduler_partial.config'

    scheduler = Scheduler(paths=projA, includes=projA/'include', config=config, frontend=frontend, xmods=[tmp_path])

    expected_dependencies = {
        'compute_l1_mod#compute_l1': (),
        '#another_l1': ('#another_l2', 'header_mod'),
        '#another_l2': ('header_mod',),
        'header_mod': (),
    }

    # Check the correct sub-graph is generated
    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items() for b in deps
    }
    # We're blocking `compute_l2` in the config file; item names are scope-qualified,
    # so check against the full name rather than the bare routine name
    assert 'compute_l2_mod#compute_l2' not in scheduler.items

    # Testing of callgraph visualisation
    cg_path = tmp_path/'callgraph_config_file'
    scheduler.callgraph(cg_path)
    vgraph = VisGraphWrapper(cg_path)

    # We're blocking compute_l2 but it's still in the VGraph
    assert set(vgraph.nodes) == {name.upper() for name in expected_dependencies} | {'COMPUTE_L2'}
    assert set(vgraph.edges) == {
        (a.upper(), b.upper()) for a, deps in expected_dependencies.items() for b in deps
    } | {('COMPUTE_L1_MOD#COMPUTE_L1', 'COMPUTE_L2')}

    cg_path.unlink()
    cg_path.with_suffix('.pdf').unlink(missing_ok=True)


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
@pytest.mark.parametrize('seed', ['driverA', 'driverA_mod#driverA'])
def test_scheduler_graph_blocked(tmp_path, testdir, config, frontend, seed):
    """
    Create a simple task graph with a single branch blocked:

    projA: driverA -> kernelA -> compute_l1 -> compute_l2
                           |
                           X --> another_l1 -> another_l2   (blocked)

    The blocked branch must be absent from the scheduler graph, while the
    blocked routine itself still appears as a node in the callgraph.
    """
    projA = testdir/'sources/projA'

    config['default']['block'] = ['another_l1']

    scheduler = Scheduler(
        paths=projA, includes=projA/'include', config=config,
        seed_routines=[seed], frontend=frontend, xmods=[tmp_path]
    )

    expected_dependencies = {
        'drivera_mod#drivera': ('kernela_mod#kernela', 'header_mod', 'header_mod#header_type'),
        'kernela_mod#kernela': ('compute_l1_mod#compute_l1',),
        'compute_l1_mod#compute_l1': ('compute_l2_mod#compute_l2',),
        'compute_l2_mod#compute_l2': (),
        'header_mod#header_type': (),
        'header_mod': ()
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items() for b in deps
    }

    # Dependencies use lower-case, scope-qualified item names, so the negative
    # checks must use those names to be able to fail at all
    assert '#another_l1' not in scheduler.items
    assert '#another_l2' not in scheduler.items
    assert ('kernela_mod#kernela', '#another_l1') not in scheduler.dependencies
    assert ('#another_l1', '#another_l2') not in scheduler.dependencies

    # Testing of callgraph visualisation
    cg_path = tmp_path/'callgraph_block'
    fg_path = cg_path.with_name(cg_path.name + '_file_graph')
    scheduler.callgraph(cg_path, with_file_graph=True)
    vgraph = VisGraphWrapper(cg_path)

    # We're blocking another_l1, but it's still in the VGraph
    assert set(vgraph.nodes) == {n.upper() for n in expected_dependencies} | {'ANOTHER_L1'}
    assert set(vgraph.edges) == {
        (a.upper(), b.upper()) for a, deps in expected_dependencies.items() for b in deps
    } | {('KERNELA_MOD#KERNELA', 'ANOTHER_L1')}

    file_dependencies = {
        'drivera_mod.f90': ('kernela_mod.f90', 'header_mod.f90'),
        'kernela_mod.f90': ('compute_l1_mod.f90',),
        'compute_l1_mod.f90': ('compute_l2_mod.f90',),
        'compute_l2_mod.f90': (),
        'header_mod.f90': ()
    }

    vgraph = VisGraphWrapper(fg_path)
    assert set(vgraph.nodes) == set(file_dependencies)
    assert set(vgraph.edges) == {(a, b) for a, deps in file_dependencies.items() for b in deps}

    cg_path.unlink()
    cg_path.with_suffix('.pdf').unlink(missing_ok=True)
    fg_path.unlink()
    fg_path.with_suffix('.pdf').unlink(missing_ok=True)



@pytest.mark.parametrize('seed', ['driverA', 'driverA_mod#driverA'])
def test_scheduler_definitions(testdir, config, frontend, seed, tmp_path):
    """
    Create a simple task graph and inject type info via `definitions`.

    ``header_mod`` is parsed up front and handed to the scheduler as a
    definition. The parents of the first two arguments of the driver's first
    call must then resolve to a type definition (not ``BasicType.DEFERRED``),
    and the arguments must carry the expected shapes.

    projA: driverA -> kernelA -> compute_l1 -> compute_l2
                           |
                           | --> another_l1 -> another_l2
    """
    projA = testdir/'sources/projA'
    header_source = Sourcefile.from_file(projA/'module/header_mod.f90', frontend=frontend)

    scheduler = Scheduler(
        paths=projA, definitions=header_source['header_mod'], includes=projA/'include',
        config=config, seed_routines=[seed], frontend=frontend, xmods=[tmp_path]
    )

    driver_routine = scheduler.item_factory.item_cache['drivera_mod#drivera'].ir
    first_call = FindNodes(ir.CallStatement).visit(driver_routine.body)[0]

    # Expected shape of each of the first two call arguments, by position
    expected_shapes = ('(:,)', '(3, 3)')
    for position, shape in enumerate(expected_shapes):
        argument = first_call.arguments[position]
        assert argument.parent.type.dtype.typedef is not BasicType.DEFERRED
        assert fexprgen(argument.shape) == shape


@pytest.mark.parametrize('seed', ['compute_l1', 'compute_l1_mod#compute_l1'])
def test_scheduler_process(testdir, config, frontend, seed, tmp_path):
    """
    Create a simple task graph from a single sub-project
    and apply a simple transformation to it.

    The transformation prepends each routine's scheduler role as a comment,
    which is then checked for the two driver routines and their callees.

    projA: driverA -> kernelA -> compute_l1 -> compute_l2
                           |
                           | --> another_l1 -> another_l2
    """
    projA = testdir/'sources/projA'

    driver_settings = {'role': 'driver', 'expand': True}
    config['routines'] = {
        seed: dict(driver_settings),
        'another_l1': dict(driver_settings),
    }

    scheduler = Scheduler(paths=projA, includes=projA/'include', config=config, frontend=frontend, xmods=[tmp_path])

    class RoleComment(Transformation):
        """
        Prepend the ``role`` passed by the scheduler as a comment to the routine body.
        """
        def transform_subroutine(self, routine, **kwargs):
            routine.body.prepend(ir.Comment(f"! {kwargs.get('role', None)}"))

    # Inject the role comments, then verify them item by item
    scheduler.process(transformation=RoleComment())

    expected_roles = (
        ('compute_l1_mod#compute_l1', 'driver'),
        ('compute_l2_mod#compute_l2', 'kernel'),
        ('#another_l1', 'driver'),
        ('#another_l2', 'kernel'),
    )
    for item_name, expected_role in expected_roles:
        leading_node = scheduler[item_name].ir.body.body[0]
        assert isinstance(leading_node, ir.Comment)
        assert leading_node.text == f'! {expected_role}'


@pytest.mark.parametrize('seed', ['driverE_single', 'driverE_mod#driverE_single'])
def test_scheduler_process_filter(testdir, config, frontend, seed, tmp_path):
    """
    Applies simple kernels over complex callgraphs to check that we
    only apply to the entities requested and only once!

    projA: driverE_single -> kernelE -> compute_l1 -> compute_l2
                                  |
                                  | --> ghost_busters

    This test never renders a callgraph, so (unlike the visualisation tests)
    it is not skipped when Graphviz is unavailable.
    """
    projA = testdir/'sources/projA'
    projB = testdir/'sources/projB'

    config['routines'] = {
        seed: {'role': 'driver', 'expand': True,},
    }

    scheduler = Scheduler(
        paths=[projA, projB], includes=projA/'include', config=config, frontend=frontend, xmods=[tmp_path]
    )

    class XMarksTheSpot(Transformation):
        """
        Prepend an 'X' comment to a given :any:`Subroutine`
        """
        def transform_subroutine(self, routine, **kwargs):
            routine.body.prepend(ir.Comment('! X'))

    # Apply transformation and check result
    scheduler.process(transformation=XMarksTheSpot())

    key_x_map = {
        'drivere_mod#drivere_single': True,
        'drivere_mod#drivere_multiple': False,
        'kernele_mod#kernele': True,
        'kernele_mod#kernelet': False,
        'compute_l1_mod#compute_l1': True,
        'compute_l2_mod#compute_l2': True,
    }

    # Internal member procedure is not included
    assert not any(
        item.name.endswith('#ghost_busters')
        for item in scheduler.item_factory.item_cache.values()
    )

    for key, is_transformed in key_x_map.items():
        item = scheduler[key]
        if is_transformed:
            item_ir = item.ir
        else:
            # key should not be found in the callgraph but scope should still exist in the
            # item_cache because the file has been indexed
            assert item is None
            scope_name, local_name = key.split('#')
            assert scope_name in scheduler.item_factory.item_cache
            item_ir = scheduler.item_factory.item_cache[scope_name].ir[local_name]
        first_node = item_ir.body.body[0]
        first_node_is_x = isinstance(first_node, ir.Comment) and first_node.text == '! X'
        assert first_node_is_x == is_transformed


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
def test_scheduler_graph_multiple_combined(tmp_path, testdir, config, driverB_dependencies, frontend):
    """
    Build one task graph spanning two projects and compare both the scheduler
    graph and its callgraph visualisation against ``driverB_dependencies``.

    projA: driverB -> kernelB -> compute_l1 -> compute_l2
                         |
    projB:          ext_driver -> ext_kernel
    """
    sources = testdir/'sources'
    projA = sources/'projA'
    projB = sources/'projB'

    scheduler = Scheduler(
        paths=[projA, projB], includes=projA/'include', config=config,
        seed_routines=['driverB_mod#driverB'], frontend=frontend, xmods=[tmp_path]
    )

    # Flatten the fixture into (parent, child) pairs once
    edge_pairs = [
        (parent, child)
        for parent, children in driverB_dependencies.items()
        for child in children
    ]

    assert set(scheduler.items) == {name.lower() for name in driverB_dependencies}
    assert set(scheduler.dependencies) == {(p.lower(), c.lower()) for p, c in edge_pairs}

    # Render the callgraph and parse the generated file back in
    cg_path = tmp_path/'callgraph_multiple_combined'
    scheduler.callgraph(cg_path)

    vgraph = VisGraphWrapper(cg_path)
    assert set(vgraph.nodes) == {name.upper() for name in driverB_dependencies}
    assert set(vgraph.edges) == {(p.upper(), c.upper()) for p, c in edge_pairs}

    cg_path.unlink()
    pdf_path = cg_path.with_suffix('.pdf')
    if pdf_path.exists():
        pdf_path.unlink()


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
def test_scheduler_graph_multiple_separate(tmp_path, testdir, config, frontend):
    """
    Tests combining two scheduler graphs, where an individual sub-branch is
    marked as ignored in the driver schedule, while IPA meta-info is still
    injected (``enrich``) to create a seamless jump between two distinct
    schedules for projA and projB.

    projA: driverB -> kernelB -> compute_l1 -> compute_l2
                         |
                 (ignored, enriched)
                         |
    projB:            ext_driver -> ext_kernel
    """
    projA = testdir/'sources/projA'
    projB = testdir/'sources/projB'

    def _remove_graph(path):
        # Delete the graph file and, if one was rendered, the matching PDF
        path.unlink()
        pdf_path = path.with_suffix('.pdf')
        if pdf_path.exists():
            pdf_path.unlink()

    # First scheduler: kernelB ignores ext_driver but still enriches calls to it
    configA = config.copy()
    configA['routines'] = {
        'kernelB': {'role': 'kernel', 'ignore': ['ext_driver'], 'enrich': ['ext_driver']},
    }

    schedulerA = Scheduler(
        paths=[projA, projB], includes=projA/'include', config=configA,
        seed_routines=['driverB'], frontend=frontend, xmods=[tmp_path]
    )

    expected_dependenciesA = {
        'driverb_mod#driverb': ('kernelb_mod#kernelb', 'header_mod#header_type', 'header_mod'),
        'kernelb_mod#kernelb': ('compute_l1_mod#compute_l1', 'ext_driver_mod#ext_driver'),
        'compute_l1_mod#compute_l1': ('compute_l2_mod#compute_l2',),
        'compute_l2_mod#compute_l2': (),
        'header_mod#header_type': (),
        'header_mod': (),
    }
    ignored_dependenciesA = {
        'ext_driver_mod#ext_driver': ('ext_kernel_mod', 'ext_kernel_mod#ext_kernel'),
        'ext_kernel_mod': (),
        'ext_kernel_mod#ext_kernel': (),
    }

    itemsA = set(chain(expected_dependenciesA, ignored_dependenciesA))
    edgesA = {
        (parent, child)
        for parent, children in chain(expected_dependenciesA.items(), ignored_dependenciesA.items())
        for child in children
    }

    assert set(schedulerA.items) == itemsA
    assert set(schedulerA.dependencies) == edgesA
    assert all(schedulerA[name].is_ignored for name in ignored_dependenciesA)
    assert not any(schedulerA[name].is_ignored for name in expected_dependenciesA)

    # Callgraph of the first scheduler shows the ignored items too
    cg_path = tmp_path/'callgraph_multiple_separate_A'
    schedulerA.callgraph(cg_path)

    vgraph = VisGraphWrapper(cg_path)
    assert set(vgraph.nodes) == {name.upper() for name in itemsA}
    assert set(vgraph.edges) == {(parent.upper(), child.upper()) for parent, child in edgesA}

    _remove_graph(cg_path)

    # Second scheduler instance holds the receiver items
    configB = config.copy()
    configB['routines'] = {'ext_driver': {'role': 'kernel'}}

    schedulerB = Scheduler(
        paths=projB, config=configB, seed_routines=['ext_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    ext_driver = 'ext_driver_mod#ext_driver'
    ext_kernel = 'ext_kernel_mod#ext_kernel'

    # TODO: Technically we should check that the role=kernel has been honoured in B
    assert ext_driver in schedulerB.items
    assert ext_kernel in schedulerB.items
    assert (ext_driver, ext_kernel) in schedulerB.dependencies

    # The call from kernelB to ext_driver must carry the enriched IPA meta-info
    kernelb_calls = FindNodes(ir.CallStatement).visit(schedulerA['kernelb_mod#kernelb'].ir.body)
    enriched_call = kernelb_calls[1]
    assert isinstance(enriched_call.routine, Subroutine)
    assert fexprgen(enriched_call.routine.arguments) == '(vector(:), matrix(:, :))'

    # Callgraph of the second scheduler
    cg_path = tmp_path/'callgraph_multiple_separate_B'
    schedulerB.callgraph(cg_path)

    vgraphB = VisGraphWrapper(cg_path)
    assert ext_driver.upper() in vgraphB.nodes
    assert ext_kernel.upper() in vgraphB.nodes
    assert (ext_driver.upper(), ext_kernel.upper()) in vgraphB.edges

    _remove_graph(cg_path)


@pytest.mark.parametrize('strict', [True, False])
def test_scheduler_graph_multiple_separate_enrich_fail(testdir, config, frontend, strict, tmp_path):
    """
    Tests that explicit enrichment in "strict" mode will fail because it can't
    find ext_driver. projB is deliberately left off the search paths. In
    non-strict mode, processing must go through without error.

    projA: driverB -> kernelB -> compute_l1 -> compute_l2
                         |
                 (ignored, enriched)
                         |
    projB:            ext_driver -> ext_kernelfail
    """
    projA = testdir/'sources/projA'

    # ``dict.copy`` is shallow: replace the nested 'default' dict rather than
    # mutating it in place, so the ``config`` fixture itself is left untouched
    configA = config.copy()
    configA['default'] = {**config['default'], 'strict': strict}
    # NOTE(review): the other tests here configure routines via a 'routines' dict
    # keyed by routine name; confirm this list-style 'routine' entry is still honoured
    configA['routine'] = [
        {
            'name': 'kernelB',
            'role': 'kernel',
            'ignore': ['ext_driver'],
            'enrich': ['ext_driver'],
        },
    ]

    schedulerA = Scheduler(
        paths=[projA], includes=projA/'include', config=configA,
        seed_routines=['driverB'], frontend=frontend, xmods=[tmp_path]
    )

    expected_dependenciesA = {
        'driverB_mod#driverB': ('kernelB_mod#kernelB', 'header_mod', 'header_mod#header_type'),
        'kernelB_mod#kernelB': ('compute_l1_mod#compute_l1', 'ext_driver_mod#ext_driver'),
        'compute_l1_mod#compute_l1': ('compute_l2_mod#compute_l2',),
        'compute_l2_mod#compute_l2': (),
        'header_mod': (),
        'header_mod#header_type': (),
        'ext_driver_mod#ext_driver': (),
    }

    assert set(schedulerA.items) == {node.lower() for node in expected_dependenciesA}
    assert set(schedulerA.dependencies) == {
        (a.lower(), b.lower()) for a, deps in expected_dependenciesA.items() for b in deps
    }

    class DummyTrafo(Transformation):
        pass

    if strict:
        with pytest.raises(RuntimeError):
            schedulerA.process(transformation=DummyTrafo())
    else:
        schedulerA.process(transformation=DummyTrafo())


def test_scheduler_module_dependency(testdir, config, frontend, tmp_path):
    """
    Ensure dependency chasing is done correctly, even with subroutines
    whose names do not match the name of their enclosing module.

    projA: driverC -> kernelC -> compute_l1 -> compute_l2
                           |
    projC:                 | --> routine_one -> routine_two
    """
    sources = testdir/'sources'
    projA = sources/'projA'
    projC = sources/'projC'

    scheduler = Scheduler(
        paths=[projA, projC], includes=projA/'include', config=config,
        seed_routines=['driverC_mod#driverC'], frontend=frontend, xmods=[tmp_path]
    )

    expected_dependencies = {
        'driverc_mod#driverc': ('header_mod', 'header_mod#header_type', 'kernelc_mod#kernelc'),
        'kernelc_mod#kernelc': ('compute_l1_mod#compute_l1', 'proj_c_util_mod#routine_one'),
        'compute_l1_mod#compute_l1': ('compute_l2_mod#compute_l2',),
        'compute_l2_mod#compute_l2': (),
        'proj_c_util_mod#routine_one': ('proj_c_util_mod#routine_two',),
        'proj_c_util_mod#routine_two': (),
        'header_mod#header_type': (),
        'header_mod': (),
    }
    expected_edges = {
        (parent, child)
        for parent, children in expected_dependencies.items()
        for child in children
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == expected_edges

    # Both routines must come from proj_c_util_mod, despite the name mismatch
    for routine_name in ('routine_one', 'routine_two'):
        assert scheduler[f'proj_c_util_mod#{routine_name}'].ir.name == routine_name


def test_scheduler_module_dependencies_unqualified(testdir, config, frontend, tmp_path):
    """
    Ensure dependency chasing is done correctly for unqualified module imports,
    i.e. a ``USE proj_c_util_mod`` without an ``ONLY`` list.

    projA: driverD -> kernelD -> compute_l1 -> compute_l2
                           |
                   <proj_c_util_mod>
                           |
    projC:                 | --> routine_one -> routine_two
    """
    sources = testdir/'sources'
    projA = sources/'projA'
    projC = sources/'projC'

    scheduler = Scheduler(
        paths=[projA, projC], includes=projA/'include', config=config,
        seed_routines=['driverD_mod#driverD'], frontend=frontend, xmods=[tmp_path]
    )

    expected_dependencies = {
        'driverd_mod#driverd': ('kerneld_mod#kerneld', 'header_mod', 'header_mod#header_type'),
        'kerneld_mod#kerneld': ('compute_l1_mod#compute_l1', 'proj_c_util_mod#routine_one'),
        'compute_l1_mod#compute_l1': ('compute_l2_mod#compute_l2',),
        'compute_l2_mod#compute_l2': (),
        'proj_c_util_mod#routine_one': ('proj_c_util_mod#routine_two',),
        'proj_c_util_mod#routine_two': (),
        'header_mod#header_type': (),
        'header_mod': (),
    }
    expected_edges = {
        (parent, child)
        for parent, children in expected_dependencies.items()
        for child in children
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == expected_edges

    # Ensure that we got the right routines from the module
    for routine_name in ('routine_one', 'routine_two'):
        assert scheduler[f'proj_c_util_mod#{routine_name}'].ir.name == routine_name


@pytest.mark.parametrize('strict', [True, False])
def test_scheduler_missing_files(testdir, config, frontend, strict, tmp_path):
    """
    Ensure that ``strict=True`` triggers failure if source paths are
    missing and that ``strict=False`` goes through gracefully.

    projC is left off the search paths, so ``routine_one`` can only become an
    :any:`ExternalItem` and its own callee is never discovered.

    projA: driverC -> kernelC -> compute_l1 -> compute_l2
                           |
    projC:                 < cannot find path >
    """
    projA = testdir/'sources/projA'

    config['default']['strict'] = strict
    scheduler = Scheduler(
        paths=[projA], includes=projA/'include', config=config,
        seed_routines=['driverC_mod#driverC'], frontend=frontend, xmods=[tmp_path]
    )

    expected_dependencies = {
        'driverc_mod#driverc': ('kernelc_mod#kernelc', 'header_mod#header_type', 'header_mod'),
        'kernelc_mod#kernelc': ('compute_l1_mod#compute_l1', 'proj_c_util_mod#routine_one'),
        'compute_l1_mod#compute_l1': ('compute_l2_mod#compute_l2',),
        'compute_l2_mod#compute_l2': (),
        'header_mod#header_type': (),
        'header_mod': (),
        'proj_c_util_mod#routine_one': (),
    }
    expected_edges = {
        (parent, child)
        for parent, children in expected_dependencies.items()
        for child in children
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == expected_edges

    # The unresolvable routine is a placeholder; its callee is not in the graph
    assert isinstance(scheduler['proj_c_util_mod#routine_one'], ExternalItem)
    assert 'proj_c_util_mod#routine_two' not in scheduler.items

    class CheckApply(Transformation):
        """
        Guard transformation asserting it is never applied to an :any:`ExternalItem`.
        """

        def apply(self, source, post_apply_rescope_symbols=False, plan_mode=False, **kwargs):
            assert 'item' in kwargs
            assert not isinstance(kwargs['item'], ExternalItem)
            super().apply(
                source, post_apply_rescope_symbols=post_apply_rescope_symbols,
                plan_mode=plan_mode, **kwargs
            )

    # Non-strict processing must succeed; strict processing must fail
    if not strict:
        scheduler.process(CheckApply())
        return

    with pytest.raises(RuntimeError):
        scheduler.process(CheckApply())


@pytest.mark.parametrize('preprocess', [False, True])   # NB: With preprocessing, ext_driver is no longer
                                                        #     wrapped inside a module but instead imported
                                                        #     via an intfb.h
def test_scheduler_dependencies_ignore(tmp_path, testdir, preprocess, frontend):
    """
    Test multi-lib transformation by applying the :any:`DependencyTransformation`
    over two distinct projects with two distinct invocations.

    projA: driverB -> kernelB -> compute_l1 -> compute_l2
                         |
    projB:          ext_driver -> ext_kernel

    kernelB ignores ``ext_driver``. The projB items are therefore part of
    schedulerA but flagged as ignored, and they make up the whole graph of schedulerB.
    """
    projA = testdir/'sources/projA'
    projB = testdir/'sources/projB'

    configA = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel', 'expand': True, 'strict': True, 'enable_imports': True
        },
        'routines': {
            'driverB': {'role': 'driver'},
            'kernelB': {'ignore': ['ext_driver']},
        }
    })

    configB = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel', 'expand': True, 'strict': True, 'enable_imports': True
        },
        'routines': {
            'ext_driver': {'role': 'kernel'}
        }
    })

    schedulerA = Scheduler(
        paths=[projA, projB], includes=projA/'include', config=configA,
        frontend=frontend, preprocess=preprocess, xmods=[tmp_path]
    )

    schedulerB = Scheduler(
        paths=projB, includes=projB/'include', config=configB,
        frontend=frontend, preprocess=preprocess, xmods=[tmp_path]
    )

    expected_items_a = [
        'driverB_mod#driverB', 'kernelB_mod#kernelB',
        'compute_l1_mod#compute_l1', 'compute_l2_mod#compute_l2',
        'header_mod', 'header_mod#header_type'
    ]
    if preprocess:
        expected_items_b = [
            '#ext_driver', 'ext_kernel_mod', 'ext_kernel_mod#ext_kernel'
        ]
    else:
        expected_items_b = [
            'ext_driver_mod#ext_driver', 'ext_kernel_mod', 'ext_kernel_mod#ext_kernel'
        ]

    assert set(schedulerA.items) == {n.lower() for n in expected_items_a + expected_items_b}
    assert all(not schedulerA[name].is_ignored for name in expected_items_a)
    assert all(schedulerA[name].is_ignored for name in expected_items_b)

    assert set(schedulerB.items) == {n.lower() for n in expected_items_b}

    # Testing of callgraph visualisation
    cg_path = tmp_path/'callgraph_dependencies_ignore'
    fg_path = cg_path.with_name(cg_path.name + '_file_graph')
    schedulerA.callgraph(cg_path, with_file_graph=True)

    vgraph = VisGraphWrapper(cg_path)
    assert set(vgraph.nodes) == {n.upper() for n in expected_items_a + expected_items_b}

    file_dependencies = {
        'proja/module/driverb_mod.f90': ('proja/module/header_mod.f90', 'proja/module/kernelb_mod.f90'),
        'proja/module/header_mod.f90': (),
        'proja/module/kernelb_mod.f90': ('proja/module/compute_l1_mod.f90', 'projb/external/ext_driver_mod.f90'),
        'proja/module/compute_l1_mod.f90': ('proja/module/compute_l2_mod.f90',),
        'proja/module/compute_l2_mod.f90': (),
        'projb/external/ext_driver_mod.f90': ('projb/module/ext_kernel.f90',),
        'projb/module/ext_kernel.f90': (),
    }

    vgraph = VisGraphWrapper(fg_path)
    assert set(vgraph.nodes) == set(file_dependencies)
    assert set(vgraph.edges) == {(a, b) for a, deps in file_dependencies.items() for b in deps}

    cg_path.unlink()
    cg_path.with_suffix('.pdf').unlink(missing_ok=True)
    fg_path.unlink()
    fg_path.with_suffix('.pdf').unlink(missing_ok=True)

    # Apply dependency injection transformation and ensure only the root driver is not transformed
    transformations = (
        ModuleWrapTransformation(module_suffix='_mod'),
        DependencyTransformation(suffix='_test', module_suffix='_mod')
    )
    for transformation in transformations:
        schedulerA.process(transformation)

    assert schedulerA.items[0].source.all_subroutines[0].name == 'driverB'
    assert schedulerA.items[1].source.all_subroutines[0].name == 'kernelB_test'
    assert schedulerA.items[4].source.all_subroutines[0].name == 'compute_l1_test'
    assert schedulerA.items[5].source.all_subroutines[0].name == 'compute_l2_test'

    # Note that 'ext_driver' and 'ext_kernel' are no longer part of the dependency graph because the
    # renaming makes it impossible to discover the non-transformed routines
    assert all(name not in schedulerA for name in expected_items_b)
    assert 'ext_driver_test_mod#ext_driver_test' not in schedulerA

    # For the second target lib, we want the driver to be converted
    for transformation in transformations:
        schedulerB.process(transformation=transformation)

    # Repeat processing to ensure DependencyTransform is idempotent
    for transformation in transformations:
        schedulerB.process(transformation=transformation)

    assert schedulerB.items[0].source.all_subroutines[0].name == 'ext_driver_test'

    # This is the untransformed original module
    assert schedulerB['ext_kernel_mod'].source.all_subroutines[0].name == 'ext_kernel'

    # This is the module-wrapped procedure
    assert schedulerB['ext_kernel_test_mod#ext_kernel_test'].source.all_subroutines[0].name == 'ext_kernel_test'


def test_scheduler_cmake_planner(tmp_path, testdir, frontend):
    """
    Verify CMake plan generation for a call tree that spans two projects.

    projA: driverB -> kernelB -> compute_l1 -> compute_l2
                         |
    projB:          ext_driver -> ext_kernel

    ``kernelB`` ignores ``ext_driver``, so only the four projA sources may show
    up in the plan. Because the default ``mode`` is ``'foobar'``, every planned
    source must also be listed with a ``.foobar`` suffix in the append list.
    """
    source_root = testdir/'sources'
    proj_a_dir = source_root/'projA'
    proj_b_dir = source_root/'projB'

    scheduler_config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'ignore': ('header_mod',),
            'mode': 'foobar'
        },
        'routines': {
            'driverB': {'role': 'driver'},
            'kernelB': {'ignore': ['ext_driver']},
        }
    })

    output_dir = tmp_path/'scheduler_cmake_planner_dummy_dir'
    output_dir.mkdir(exist_ok=True)

    # Identical set-up to SchedulerA in test_scheduler_dependencies_ignore, which
    # already validates the discovered graph; only the plan output is checked here
    scheduler = Scheduler(
        paths=[proj_a_dir, proj_b_dir], includes=proj_a_dir/'include',
        config=scheduler_config, frontend=frontend, xmods=[tmp_path],
        output_dir=output_dir
    )

    # Run the file-writing transformation in planning mode and emit the plan file
    plan_path = output_dir/'loki_plan.cmake'
    scheduler.process(FileWriteTransformation(), proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_path, rootpath=source_root)

    # Turn every ``set(VAR entry1 entry2 ...)`` into ``{VAR: {stem1, stem2, ...}}``
    set_statement = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan = {
        var_name: {Path(entry).stem for entry in values.split()}
        for var_name, values in set_statement.findall(plan_path.read_text())
    }

    proj_a_stems = {'driverB_mod', 'kernelB_mod', 'compute_l1_mod', 'compute_l2_mod'}

    for var_name in ('LOKI_SOURCES_TO_TRANSFORM', 'LOKI_SOURCES_TO_REMOVE'):
        assert var_name in plan
        assert plan[var_name] == proj_a_stems

    assert 'LOKI_SOURCES_TO_APPEND' in plan
    assert plan['LOKI_SOURCES_TO_APPEND'] == {stem + '.foobar' for stem in proj_a_stems}

    # Remove the generated artefacts again
    plan_path.unlink()
    output_dir.rmdir()


@pytest.mark.parametrize('prec', ['DP', 'SP'])
def test_scheduler_cmake_planner_libs(tmp_path, testdir, frontend, prec):
    """
    Verify CMake plan generation when items are assigned to libraries via ``lib``.

    projA: driverB -> kernelB -> compute_l1 -> compute_l2
                         |
    projB:          ext_driver -> ext_kernel

    All items default to ``projAlib.<prec>``, and ``ext_driver`` is moved to
    ``projBlib.<prec>``. The plan must contain the combined source lists plus
    one set of lists per library, with the library name's ``.`` written as
    ``_`` in the variable name.
    """
    source_root = testdir/'sources'
    proj_a_dir = source_root/'projA'
    proj_b_dir = source_root/'projB'

    scheduler_config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'ignore': ('header_mod',),
            'mode': 'foobar',
            'lib': f'projAlib.{prec}'
        },
        'routines': {
            'driverB': {'role': 'driver'},
            'ext_driver': {'lib': f'projBlib.{prec}'}
        }
    })

    output_dir = tmp_path/'scheduler_cmake_planner_libs_dummy_dir'
    output_dir.mkdir(exist_ok=True)

    scheduler = Scheduler(
        paths=[proj_a_dir, proj_b_dir], includes=proj_a_dir/'include',
        config=scheduler_config, frontend=frontend, xmods=[tmp_path],
        output_dir=output_dir
    )

    # Run the file-writing transformation in planning mode and emit the plan file
    plan_path = output_dir/'loki_plan_libs.cmake'
    scheduler.process(FileWriteTransformation(), proc_strategy=ProcessingStrategy.PLAN)
    scheduler.write_cmake_plan(filepath=plan_path, rootpath=source_root)

    # Turn every ``set(VAR entry1 entry2 ...)`` into ``{VAR: {stem1, stem2, ...}}``
    set_statement = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
    plan = {
        var_name: {Path(entry).stem for entry in values.split()}
        for var_name, values in set_statement.findall(plan_path.read_text())
    }

    stems_a = {'driverB_mod', 'kernelB_mod', 'compute_l1_mod', 'compute_l2_mod'}
    stems_b = {'ext_driver_mod', 'ext_kernel'}

    # Variable-name suffix -> source stems expected under it ('' is the combined list)
    stems_by_suffix = {
        '': stems_a | stems_b,
        f'_projAlib_{prec}': stems_a,
        f'_projBlib_{prec}': stems_b,
    }
    categories = ('TRANSFORM', 'APPEND', 'REMOVE')

    assert set(plan.keys()) == {
        f'LOKI_SOURCES_TO_{category}{suffix}'
        for category in categories for suffix in stems_by_suffix
    }

    for suffix, stems in stems_by_suffix.items():
        assert plan[f'LOKI_SOURCES_TO_TRANSFORM{suffix}'] == stems
        assert plan[f'LOKI_SOURCES_TO_REMOVE{suffix}'] == stems
        assert plan[f'LOKI_SOURCES_TO_APPEND{suffix}'] == {stem + '.foobar' for stem in stems}

    # Remove the generated artefacts again
    plan_path.unlink()
    output_dir.rmdir()


def test_scheduler_item_dependencies(testdir, tmp_path):
    """
    Check that every item in projHoist reports exactly the expected dependencies,
    each listed once and in the expected order.
    """
    scheduler_config = SchedulerConfig.from_dict({
        'default': {'role': 'kernel', 'expand': True, 'strict': True},
        'routines': {
            'driver': {'role': 'driver'},
            'another_driver': {'role': 'driver'}
        }
    })

    scheduler = Scheduler(
        paths=testdir/'sources/projHoist', config=scheduler_config, xmods=[tmp_path]
    )

    # Item name -> ordered dependency names; an empty tuple means "no dependencies"
    expected_deps = {
        'transformation_module_hoist#driver': ('kernel1', 'kernel2'),
        'transformation_module_hoist#another_driver': ('kernel1',),
        'subroutines_mod#kernel1': (),
        'subroutines_mod#kernel2': ('device1', 'device2'),
        'subroutines_mod#device1': ('device2',),
        'subroutines_mod#device2': (),
    }

    for item_name, dep_names in expected_deps.items():
        dependencies = scheduler[item_name].dependencies
        if not dep_names:
            assert not dependencies
            continue
        assert tuple(dep.name for dep in dependencies) == dep_names


@pytest.fixture(name='loki_69_dir')
def fixture_loki_69_dir(tmp_path):
    """
    Provide a directory holding the single source file for the LOKI-69 test.

    The file ``test.F90`` contains the free subroutines ``random_call_0``,
    ``random_call_2`` and ``test``, plus a fully commented-out ``random_call_1``.
    ``test`` calls ``random_call_0`` inside an ``if`` block, and calls
    ``random_call_2`` without an argument list from a one-line logical ``if``.

    The directory lives in the per-test ``tmp_path`` rather than the shared
    test-source tree. As a result the fixture never writes into the repository
    checkout, cannot clash with concurrently running test workers, and needs no
    hand-written teardown. The old teardown also failed with ``OSError`` if the
    directory was not empty.
    """
    fcode = """
subroutine random_call_0(v_out,v_in,v_inout)
implicit none

    real(kind=jprb),intent(in)  :: v_in
    real(kind=jprb),intent(out)  :: v_out
    real(kind=jprb),intent(inout)  :: v_inout


end subroutine random_call_0

!subroutine random_call_1(v_out,v_in,v_inout)
!implicit none
!
!  real(kind=jprb),intent(in)  :: v_in
!  real(kind=jprb),intent(out)  :: v_out
!  real(kind=jprb),intent(inout)  :: v_inout
!
!
!end subroutine random_call_1

subroutine random_call_2(v_out,v_in,v_inout)
implicit none

    real(kind=jprb),intent(in)  :: v_in
    real(kind=jprb),intent(out)  :: v_out
    real(kind=jprb),intent(inout)  :: v_inout


end subroutine random_call_2

subroutine test(v_out,v_in,v_inout,some_logical)
implicit none

    real(kind=jprb),intent(in   )  :: v_in
    real(kind=jprb),intent(out  )  :: v_out
    real(kind=jprb),intent(inout)  :: v_inout

    logical,intent(in)             :: some_logical

    v_inout = 0._jprb
    if(some_logical)then
        call random_call_0(v_out,v_in,v_inout)
    endif

    if(some_logical) call random_call_2

end subroutine test
    """.strip()

    dirname = tmp_path/'loki69'
    dirname.mkdir(exist_ok=True)
    (dirname/'test.F90').write_text(fcode)
    return dirname


def test_scheduler_loki_69(loki_69_dir, tmp_path):
    """
    Check scheduler discovery for the edge cases reported in LOKI-69.

    The commented-out ``random_call_1`` must not be discovered. The call in an
    ``if`` block and the argument-less call in a one-line logical ``if`` must
    both become dependencies of ``test``.
    """
    scheduler_config = {
        'default': {
            'expand': True,
            'strict': True,
        },
    }
    scheduler = Scheduler(
        paths=loki_69_dir, seed_routines=['test'], config=scheduler_config, xmods=[tmp_path]
    )

    # The item cache holds the three live subroutines plus the (lower-cased) file path
    source_key = str(loki_69_dir/'test.f90').lower()
    assert sorted(scheduler.item_factory.item_cache.keys()) == [
        '#random_call_0', '#random_call_2', '#test', source_key
    ]
    assert '#random_call_1' not in scheduler.item_factory

    expected_successors = {
        '#test': {'#random_call_0', '#random_call_2'},
        '#random_call_0': set(),
        '#random_call_2': set(),
    }
    assert len(scheduler.items) == len(expected_successors)
    for item in scheduler.items:
        assert set(scheduler.sgraph.successors(item)) == expected_successors[item.name]


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
def test_scheduler_scopes(tmp_path, testdir, config, frontend):
    """
    Test discovery with import renames and duplicate names in separate scopes

      driver ----> kernel1_mod#kernel ----> kernel1_impl#kernel_impl
        |
        +--------> kernel2_mod#kernel ----> kernel2_impl#kernel_impl

    The two ``kernel`` and the two ``kernel_impl`` routines share their names,
    so items must be distinguished by their enclosing module. Each kernel also
    depends on the module item of its implementation. The expected graph is
    checked on the scheduler and, upper-cased, on the rendered callgraph.
    """
    proj = testdir/'sources/projScopes'

    scheduler = Scheduler(paths=proj, seed_routines=['driver'], config=config, frontend=frontend, xmods=[tmp_path])

    expected_dependencies = {
        '#driver': (
            'kernel1_mod#kernel',
            'kernel2_mod#kernel',
        ),
        'kernel1_mod#kernel': (
            'kernel1_impl',
            'kernel1_impl#kernel_impl',
        ),
        'kernel1_impl': (),
        'kernel1_impl#kernel_impl': (),
        'kernel2_mod#kernel': (
            'kernel2_impl',
            'kernel2_impl#kernel_impl',
        ),
        'kernel2_impl': (),
        'kernel2_impl#kernel_impl': (),
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items()
        for b in deps
    }

    # Testing of callgraph visualisation
    cg_path = tmp_path/'callgraph_scopes'
    scheduler.callgraph(cg_path)

    vgraph = VisGraphWrapper(cg_path)
    assert set(vgraph.nodes) == {
        n.upper() for n in expected_dependencies
    }
    assert set(vgraph.edges) == {
        (a.upper(), b.upper()) for a, deps in expected_dependencies.items()
        for b in deps
    }

    # Only the callgraph file itself is inspected above, so a missing rendered
    # PDF must not turn clean-up into a FileNotFoundError
    cg_path.unlink()
    pdf_path = cg_path.with_suffix('.pdf')
    if pdf_path.exists():
        pdf_path.unlink()


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
def test_scheduler_typebound(tmp_path, testdir, config, frontend, proj_typebound_dependencies):
    """
    Test correct dependency chasing for typebound procedure calls.

    projTypeBound: driver -> some_type%other_routine -> other_routine -> some_type%routine1 -> routine1
                 | | | | | |                                          |                                |
                 | | | | | |       +- routine <- some_type%routine2 <-+                                +---------+
                 | | | | | |       |                                                                             |
                 | | | | | +--> some_type%some_routine -> some_routine -> some_type%routine -> module_routine  <-+
                 | | | +------> header_type%member_routine -> header_member_routine
                 | | +--------> header_type%routine -> header_type%routine_real -> header_routine_real
                 | |                           |
                 | |                           +---> header_type%routine_integer -> routine_integer
                 | +---------->other_type%member -> other_member -> header_member_routine   <--+
                 |                                                                             |
                 +------------>other_type%var%%member_routine -> header_type%member_routine  --+

    The expected graph comes from the ``proj_typebound_dependencies`` fixture. It is
    checked on the scheduler and, upper-cased, on the rendered callgraph.
    """
    proj = testdir/'sources/projTypeBound'

    scheduler = Scheduler(
        paths=proj, seed_routines=['driver'], config=config,
        full_parse=False, frontend=frontend, xmods=[tmp_path]
    )

    assert set(scheduler.items) == set(proj_typebound_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in proj_typebound_dependencies.items() for b in deps
    }

    # Testing of callgraph visualisation
    cg_path = tmp_path/'callgraph_typebound'
    scheduler.callgraph(cg_path)

    vgraph = VisGraphWrapper(cg_path)
    assert set(vgraph.nodes) == {n.upper() for n in proj_typebound_dependencies}
    assert set(vgraph.edges) == {
        (a.upper(), b.upper()) for a, deps in proj_typebound_dependencies.items() for b in deps
    }

    # Only the callgraph file itself is inspected above, so a missing rendered
    # PDF must not turn clean-up into a FileNotFoundError
    cg_path.unlink()
    pdf_path = cg_path.with_suffix('.pdf')
    if pdf_path.exists():
        pdf_path.unlink()


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
def test_scheduler_typebound_ignore(tmp_path, testdir, config, frontend, proj_typebound_dependencies):
    """
    Test correct dependency chasing for typebound procedure calls with ignore working for
    typebound procedures correctly.

    projTypeBound: driver -> some_type%other_routine -> other_routine -> some_type%routine1 -> routine1
                   | | | | |                                          |                                |
                   | | | | |       +- routine <- some_type%routine2 <-+                                +---------+
                   | | | | |       |                                                                             |
                   | | | | +--> some_type%some_routine -> some_routine -> some_type%routine -> module_routine  <-+
                   | | +------> header_type%member_routine -> header_member_routine
                   | +--------> header_type%routine -> header_type%routine_real -> header_routine_real
                   |                           |
                   |                           +---> header_type%routine_integer -> routine_integer
                   +---------->other_type%member -> other_member -> header_member_routine

    ``some_type%some_routine`` and ``header_member_routine`` are disabled globally.
    ``other_member`` additionally disables ``member_routine``. The affected items
    are removed from the fixture's expected graph before comparing.
    """
    proj = testdir/'sources/projTypeBound'

    config['default']['disable'] += [
        'some_type%some_routine',
        'header_member_routine'
    ]
    config['routines'] = {
        'other_member': {
            'disable': config['default']['disable'] + ['member_routine']
        }
    }

    # Items that must disappear from the graph because of the disable lists above
    items_to_remove = (
        'typebound_item#some_type%some_routine',
        'typebound_item#some_routine',
        'typebound_item#some_type%routine',
        'typebound_header#header_member_routine',
    )

    proj_typebound_dependencies = {
        name: tuple(dep for dep in deps if dep not in items_to_remove)
        for name, deps in proj_typebound_dependencies.items()
        if name not in items_to_remove
    }

    scheduler = Scheduler(
        paths=proj, seed_routines=['driver'], config=config,
        full_parse=False, frontend=frontend, xmods=[tmp_path]
    )

    assert set(scheduler.items) == set(proj_typebound_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in proj_typebound_dependencies.items() for b in deps
    }

    # Testing of callgraph visualisation
    cg_path = tmp_path/'callgraph_typebound'
    scheduler.callgraph(cg_path)

    vgraph = VisGraphWrapper(cg_path)
    assert set(vgraph.nodes) == {n.upper() for n in proj_typebound_dependencies}
    assert set(vgraph.edges) == {
        (a.upper(), b.upper()) for a, deps in proj_typebound_dependencies.items() for b in deps
    }

    # Only the callgraph file itself is inspected above, so a missing rendered
    # PDF must not turn clean-up into a FileNotFoundError
    cg_path.unlink()
    pdf_path = cg_path.with_suffix('.pdf')
    if pdf_path.exists():
        pdf_path.unlink()


@pytest.mark.parametrize('use_file_graph', [False, True])
@pytest.mark.parametrize('reverse', [False, True])
def test_scheduler_traversal_order(tmp_path, testdir, config, frontend, use_file_graph, reverse):
    """
    Check the order in which the scheduler applies a transformation to items.

    A recording transformation logs ``<item name>::<node name>`` for every
    file, module and subroutine it receives. The log must match the call-graph
    order, or the file-graph order when ``use_file_graph`` is set, and must be
    exactly reversed when ``reverse_traversal`` is enabled.
    """
    proj_dir = testdir/'sources/projHoist'

    scheduler = Scheduler(
        paths=proj_dir, seed_routines=['driver'], config=config,
        full_parse=True, frontend=frontend, xmods=[tmp_path]
    )

    if not use_file_graph:
        expected = [
            'transformation_module_hoist#driver::driver', 'subroutines_mod#kernel1::kernel1',
            'subroutines_mod#kernel2::kernel2', 'subroutines_mod#device1::device1',
            'subroutines_mod#device2::device2'
        ]
    else:
        # File items are named after their lower-cased path
        expected = [
            str(proj_dir/'module'/fname).lower() + '::' + fname
            for fname in ('driver_mod.f90', 'subroutines_mod.f90')
        ]

    class RecordingTransformation(Transformation):

        reverse_traversal = reverse

        traverse_file_graph = use_file_graph

        def __init__(self):
            self.record = []

        def transform_file(self, sourcefile, **kwargs):
            self.record.append(f"{kwargs['item'].name}::{sourcefile.path.name}")

        def transform_module(self, module, **kwargs):
            self.record.append(f"{kwargs['item'].name}::{module.name}")

        def transform_subroutine(self, routine, **kwargs):
            self.record.append(f"{kwargs['item'].name}::{routine.name}")

    recorder = RecordingTransformation()
    scheduler.process(transformation=recorder)

    assert recorder.record == (expected[::-1] if reverse else expected)


@pytest.mark.parametrize('use_file_graph', [False, True])
@pytest.mark.parametrize('reverse', [False, True])
@pytest.mark.parametrize('ignore_internal_procedures', [True, False])
@pytest.mark.parametrize('ignore_internal_procedures_driver', [None, True, False])
def test_scheduler_member_routines(tmp_path, config, frontend, use_file_graph, reverse,
                                   ignore_internal_procedures, ignore_internal_procedures_driver):
    """
    Make sure that transformation processing works also for contained member routines
    if enabled in the config. This includes internal procedures in module routines as well
    as free routines, and selective config overwrites to allow for fine-grained control of this behaviour.

    ``ignore_internal_procedures`` sets the default for all items.
    ``ignore_internal_procedures_driver`` is a per-routine override for
    ``member_mod#driver``, where ``None`` means no override is set. A recording
    transformation logs ``<item name>::<node name>`` for every file, module and
    subroutine it visits. The log is compared with the expected traversal order,
    reversed when ``reverse`` is set. In call-graph mode the dependencies of the
    driver and the kernel are also checked.
    """
    # Module whose ``driver`` contains two internal procedures: the subroutine
    # ``my_member`` (which calls the module procedure ``my_routine``) and the
    # function ``my_func``. ``driver`` also calls the free subroutine ``kernel``.
    fcode_mod = """
module member_mod
    implicit none
contains
    subroutine my_routine(ret)
        integer, intent(out) :: ret
        ret = 1
    end subroutine my_routine

    subroutine driver
        integer :: val
        call my_member
        write(*,*) val
        val = my_func(val)
        call kernel
    contains
        subroutine my_member
            call my_routine(val)
        end subroutine my_member
        integer function my_func(val0)
           integer, intent(in) :: val0
           my_func = val0 + 1
        end function
    end subroutine driver
end module member_mod
    """.strip()

    # Free ``kernel`` with its own internal ``my_member``, which shares its name
    # with the driver's internal procedure but lives in a different scope
    fcode_kernel = """
subroutine kernel
    implicit none
    integer :: val
    call my_member
    write(*,*) val
contains
    subroutine my_member
        val = 1
    end subroutine my_member
end subroutine kernel
    """.strip()

    (tmp_path/'member_mod.F90').write_text(fcode_mod)
    (tmp_path/'kernel.F90').write_text(fcode_kernel)

    # Global default for internal procedures, plus an optional driver-only override
    config['default']['ignore_internal_procedures'] = ignore_internal_procedures
    if ignore_internal_procedures_driver is not None:
        config['routines']['member_mod#driver'] = {'ignore_internal_procedures': ignore_internal_procedures_driver}

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=['member_mod#driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # Records "<item name>::<node name>" for each file/module/subroutine visited
    class LoggingTransformation(Transformation):

        reverse_traversal = reverse

        traverse_file_graph = use_file_graph

        def __init__(self):
            self.record = []

        def transform_file(self, sourcefile, **kwargs):
            self.record += [kwargs['item'].name + '::' + sourcefile.path.name]

        def transform_module(self, module, **kwargs):
            self.record += [kwargs['item'].name + '::' + module.name]

        def transform_subroutine(self, routine, **kwargs):
            self.record += [kwargs['item'].name + '::' + routine.name]

    transformation = LoggingTransformation()
    scheduler.process(transformation=transformation)

    if use_file_graph:
        # File-graph traversal yields one entry per source file, independent of the
        # internal-procedure settings; file items are named by their lower-cased path
        expected = [
            f'{tmp_path/"member_mod.F90"!s}'.lower() + '::member_mod.F90',
            f'{tmp_path/"kernel.F90"!s}'.lower() + '::kernel.F90'
        ]
    else:
        expected = ['member_mod#driver::driver']
        expected_dependencies_driver = ['kernel']

        # Slightly awkward logic to capture the cases that
        #   1) we include internal procedures without any special config overrides for the driver
        #   2) we override the config for the driver to include internal procedures
        #   3) we include internal procedures but disable it for the driver
        # and mark the corresponding expected dependencies in case 1 and 2
        include_driver_internals = (
            (ignore_internal_procedures_driver is None and not ignore_internal_procedures) or
            ignore_internal_procedures_driver is False
        )

        if include_driver_internals:
            expected += ['member_mod#driver#my_member::my_member', '#kernel::kernel',
                         'member_mod#driver#my_func::my_func']
            expected_dependencies_driver = ['my_member', *expected_dependencies_driver, 'my_func']
        else:
            expected += ['#kernel::kernel']
        expected_dependencies_kernel = []

        # my_routine is only called from the driver's internal my_member, so it is
        # only reached when the driver's internal procedures are included
        if include_driver_internals:
            expected += ['member_mod#my_routine::my_routine']

        # The kernel has no per-routine override, so its internal my_member
        # follows the global default only
        if not ignore_internal_procedures:
            expected += ['#kernel#my_member::my_member']
            expected_dependencies_kernel += ['my_member']

        assert [dep.name for dep in scheduler['member_mod#driver'].dependencies] == expected_dependencies_driver
        assert [dep.name for dep in scheduler['#kernel'].dependencies] == expected_dependencies_kernel

    if reverse:
        expected = expected[::-1]

    assert transformation.record == flatten(expected)


@pytest.mark.parametrize('frontend', available_frontends())
def test_scheduler_nested_type_enrichment(tmp_path, frontend, config):
    """
    Check that type information is attached along chains of nested derived
    types whose definitions are spread over several source files.

    ``driver`` imports only ``third_type``. ``other_type`` and ``my_type`` are
    reached through components such as ``data%stuff(i)%arr(j)``. Once the
    scheduler has been built over the three files, every symbol in such a chain
    must carry a ``DerivedType`` with an attached ``TypeDef``, and every call
    must be linked to its ``Subroutine``.
    """
    fcode_base_types = """
module typebound_procedure_calls_mod
    implicit none

    type my_type
        integer :: val
    contains
        procedure :: reset
        procedure :: add => add_my_type
    end type my_type

    type other_type
        type(my_type) :: arr(3)
    contains
        procedure :: add => add_other_type
        procedure :: total_sum
    end type other_type

contains

    subroutine reset(this)
        class(my_type), intent(inout) :: this
        this%val = 0
    end subroutine reset

    subroutine add_my_type(this, val)
        class(my_type), intent(inout) :: this
        integer, intent(in) :: val
        this%val = this%val + val
    end subroutine add_my_type

    subroutine add_other_type(this, other)
        class(other_type) :: this
        type(other_type) :: other
        integer :: i
        do i=1,3
            call this%arr(i)%add(other%arr(i)%val)
        end do
    end subroutine add_other_type

    function total_sum(this) result(result)
        class(other_type), intent(in) :: this
        integer :: result
        integer :: i
        result = 0
        do i=1,3
            result = result + this%arr(i)%val
        end do
    end function total_sum

end module typebound_procedure_calls_mod
    """.strip()

    fcode_outer_type = """
module other_typebound_procedure_calls_mod
    use typebound_procedure_calls_mod, only: other_type
    implicit none

    type third_type
        type(other_type) :: stuff(2)
    contains
        procedure :: init
        procedure :: print => print_content
    end type third_type

contains

    subroutine init(this)
        class(third_type), intent(inout) :: this
        integer :: i, j
        do i=1,2
            do j=1,3
                call this%stuff(i)%arr(j)%reset()
                call this%stuff(i)%arr(j)%add(i+j)
            end do
        end do
    end subroutine init

    subroutine print_content(this)
        class(third_type), intent(inout) :: this
        call this%stuff(1)%add(this%stuff(2))
        print *, this%stuff(1)%total_sum()
    end subroutine print_content
end module other_typebound_procedure_calls_mod
    """.strip()

    fcode_driver = """
subroutine driver
    use other_typebound_procedure_calls_mod, only: third_type
    implicit none
    type(third_type) :: data
    integer :: mysum

    call data%init()
    call data%stuff(1)%arr(1)%add(1)
    mysum = data%stuff(1)%total_sum() + data%stuff(2)%total_sum()
    call data%print
end subroutine driver
    """.strip()

    for fname, fcode in (
        ('typebound_procedure_calls_mod.F90', fcode_base_types),
        ('other_typebound_procedure_calls_mod.F90', fcode_outer_type),
        ('driver.F90', fcode_driver),
    ):
        (tmp_path/fname).write_text(fcode)

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=['driver'],
        frontend=frontend, xmods=[tmp_path]
    )
    driver = scheduler['#driver'].source['driver']

    def check_derived(symbol, symbol_cls, type_name):
        """Assert ``symbol`` is a ``symbol_cls`` typed as derived type ``type_name`` with a typedef."""
        assert isinstance(symbol, symbol_cls)
        assert isinstance(symbol.type.dtype, DerivedType)
        assert symbol.type.dtype.name == type_name
        assert isinstance(symbol.type.dtype.typedef, ir.TypeDef)

    # Every call statement is a type-bound call resolved to a Subroutine
    calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(calls) == 3
    for call in calls:
        assert isinstance(call.name, ProcedureSymbol)
        assert isinstance(call.name.type.dtype, ProcedureType)
        assert call.name.parent
        assert isinstance(call.name.parent.type.dtype, DerivedType)
        assert isinstance(call.routine, Subroutine)
        assert isinstance(call.name.type.dtype.procedure, Subroutine)

    # First call: parent is the scalar of the outermost type
    check_derived(calls[0].name.parent, Scalar, 'third_type')

    # Second call: walk the nested parent chain outwards
    add_parent = calls[1].name.parent
    check_derived(add_parent, Array, 'my_type')
    check_derived(add_parent.parent, Array, 'other_type')
    check_derived(add_parent.parent.parent, Scalar, 'third_type')

    # Inline function calls reach their procedure through an array component
    inline_calls = FindInlineCalls().visit(driver.body)
    assert len(inline_calls) == 2
    for call in inline_calls:
        assert isinstance(call.function, ProcedureSymbol)
        assert isinstance(call.function.type.dtype, ProcedureType)

        assert call.function.parent
        check_derived(call.function.parent, Array, 'other_type')

        assert call.function.parent.parent
        check_derived(call.function.parent.parent, Scalar, 'third_type')


@pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed')
@pytest.mark.parametrize('frontend', available_frontends())
def test_scheduler_interface_inline_call(tmp_path, testdir, config, frontend):
    """
    Test that inline function calls declared via an explicit interface are added as dependencies.

    The test also renders and reads back the callgraph, so it needs Graphviz.
    Like the other callgraph tests it is skipped when Graphviz is unavailable,
    rather than failing in ``VisGraphWrapper``.
    """

    my_config = config.copy()
    my_config['routines'] = {
        'driver': {
            'role': 'driver',
        }
    }

    scheduler = Scheduler(
        paths=testdir/'sources/projInlineCalls', config=my_config, frontend=frontend, xmods=[tmp_path]
    )

    expected_dependencies = {
        '#driver': (
            '#double_real', 'some_module', 'some_module#add_args', 'some_module#return_one',
            'some_module#some_type', 'some_module#some_type%do_something', 'vars_module',
        ),
        '#double_real': ('vars_module',),
        'some_module': (),
        'some_module#add_args': ('some_module#add_two_args', 'some_module#add_three_args'),
        'some_module#add_two_args': (),
        'some_module#add_three_args': (),
        'some_module#return_one': (),
        'some_module#some_type': (),
        'some_module#some_type%do_something': ('some_module#add_const',),
        'some_module#add_const': ('some_module#some_type',),
        'vars_module': (),
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items() for b in deps
    }

    # The generic interface becomes an InterfaceItem whose targets are procedures
    assert isinstance(scheduler['some_module#add_args'], InterfaceItem)
    assert isinstance(scheduler['#double_real'], ProcedureItem)
    assert isinstance(scheduler['some_module#some_type'], TypeDefItem)
    assert isinstance(scheduler['some_module#add_two_args'], ProcedureItem)
    assert isinstance(scheduler['some_module#add_three_args'], ProcedureItem)

    # Testing of callgraph visualisation with imports
    cg_path = tmp_path/'callgraph'
    scheduler.callgraph(cg_path)

    vgraph = VisGraphWrapper(cg_path)
    assert set(vgraph.nodes) == {i.upper() for i in expected_dependencies}
    assert set(vgraph.edges) == {
        (a.upper(), b.upper()) for a, deps in expected_dependencies.items() for b in deps
    }


@pytest.mark.parametrize('frontend', available_frontends())
def test_scheduler_interface_dependencies(tmp_path, frontend, config):
    """
    Check that a generic interface is an intermediate node of the dependency
    graph. The driver depends on the interface ``my_intf``, and the interface
    depends on each specific procedure it names (``proc1`` and ``proc2``).
    """
    fcode_module = """
module test_scheduler_interface_dependencies_mod
    implicit none
    interface my_intf
        procedure proc1
        procedure proc2
    end interface my_intf
contains
    subroutine proc1(arg)
        integer, intent(inout) :: arg
        arg = arg + 1
    end subroutine proc1
    subroutine proc2(arg)
        real, intent(inout) :: arg
        arg = arg + 1.0
    end subroutine proc2
end module test_scheduler_interface_dependencies_mod
    """
    fcode_driver = """
subroutine test_scheduler_interface_dependencies_driver
    use test_scheduler_interface_dependencies_mod, only: my_intf
    implicit none
    integer i
    real a
    i = 0
    a = 0.0
    call my_intf(i)
    call my_intf(a)
end subroutine test_scheduler_interface_dependencies_driver
    """

    mod_name = 'test_scheduler_interface_dependencies_mod'
    driver_name = 'test_scheduler_interface_dependencies_driver'

    config['routines'][driver_name] = {
        'role': 'driver'
    }

    (tmp_path/f'{mod_name}.F90').write_text(fcode_module)
    (tmp_path/f'{driver_name}.F90').write_text(fcode_driver)

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=[driver_name],
        frontend=frontend, xmods=[tmp_path]
    )

    intf_item = f'{mod_name}#my_intf'
    proc_items = (f'{mod_name}#proc1', f'{mod_name}#proc2')

    expected_graph = {
        f'#{driver_name}': {intf_item},
        intf_item: set(proc_items),
        **{proc: set() for proc in proc_items},
    }

    assert set(scheduler.items) == set(expected_graph)
    assert set(scheduler.dependencies) == {
        (src, dst) for src, targets in expected_graph.items() for dst in targets
    }

    assert isinstance(scheduler[intf_item], InterfaceItem)
    for proc in proc_items:
        assert isinstance(scheduler[proc], ProcedureItem)


def test_scheduler_item_successors(testdir, config, frontend, tmp_path):
    """
    Verify that successors reported by ``scheduler.sgraph`` are the very same
    :any:`Item` objects that the scheduler hands out, i.e. the module item
    ``vars_module`` is not duplicated when it is reached from different routines.
    """

    routine_config = config.copy()
    routine_config['routines'] = {'driver': {'role': 'driver'}}

    scheduler = Scheduler(
        paths=testdir/'sources/projInlineCalls', config=routine_config,
        frontend=frontend, xmods=[tmp_path]
    )
    module_item = scheduler['vars_module']

    # Whenever the module shows up among the successors of the driver or of the
    # kernel, it must be the identical object (not merely an equal one)
    for parent in (scheduler['#driver'], scheduler['#double_real']):
        matching = [s for s in scheduler.sgraph.successors(parent) if s.name == module_item.name]
        for successor in matching:
            assert successor is module_item


@pytest.mark.parametrize('trafo_item_filter', [
    Item,
    ProcedureItem,
    (ProcedureItem, InterfaceItem, ProcedureBindingItem),
    (ProcedureItem, TypeDefItem),
])
def test_scheduler_successors(tmp_path, config, trafo_item_filter):
    """
    Check that the ``sub_sgraph`` handed to a transformation reports, for each
    processed item, the transitive closure of its dependencies restricted to
    the item types selected by the transformation's ``item_filter``.
    """
    fcode_mod = """
module some_mod
    implicit none
    type some_type
        real :: a
    contains
        procedure :: procedure => some_procedure
        procedure :: routine
        procedure :: other
        generic :: do => procedure, routine
    end type some_type
contains
    subroutine some_procedure(t, i)
        class(some_type), intent(inout) :: t
        integer, intent(in) :: i
        t%a = t%a + real(i)
    end subroutine some_procedure

    subroutine routine(t, v)
        class(some_type), intent(inout) :: t
        real, intent(in) :: v
        t%a = t%a + v
        call t%other
    end subroutine routine

    subroutine other(t)
        class(some_type), intent(in) :: t
        print *,t%a
    end subroutine other
end module some_mod
    """.strip()

    fcode = """
subroutine caller(val)
    use some_mod, only: some_type
    implicit none
    real, intent(inout) :: val
    type(some_type) :: t
    t%a = val
    call t%routine(1)
    call t%routine(2.0)
    call t%do(10)
    call t%do(20.0)
    call t%other
    val = t%a
end subroutine caller
    """.strip()

    # Direct (unfiltered) dependency edges of every item in the graph
    expected_deps = {
        '#caller': (
            'some_mod#some_type',
            'some_mod#some_type%routine',
            'some_mod#some_type%do',
            'some_mod#some_type%other',
        ),
        'some_mod#some_type': (),
        'some_mod#some_type%routine': ('some_mod#routine',),
        'some_mod#some_type%do': (
            'some_mod#some_type%procedure',
            'some_mod#some_type%routine',
        ),
        'some_mod#some_type%other': ('some_mod#other',),
        'some_mod#routine': (
            'some_mod#some_type',
            'some_mod#some_type%other',
        ),
        'some_mod#other': (
            'some_mod#some_type',
        ),
        'some_mod#some_type%procedure': (
            'some_mod#some_procedure',
        ),
        'some_mod#some_procedure': (
            'some_mod#some_type',
        )
    }

    class SuccessorTransformation(Transformation):

        item_filter = trafo_item_filter

        def __init__(self, expected_successors, **kwargs):
            super().__init__(**kwargs)
            self.counter = {}
            self.expected_successors = expected_successors

        def transform_subroutine(self, routine, **kwargs):
            item = kwargs.get('item')
            local_name = item.local_name
            assert local_name in ('caller', 'routine', 'some_procedure', 'other')
            self.counter[local_name] = self.counter.get(local_name, 0) + 1

            graph = kwargs.get('sub_sgraph')
            found = as_tuple(graph.successors(item))
            assert set(found) == set(self.expected_successors[item.name])

    (tmp_path/'some_mod.F90').write_text(fcode_mod)
    (tmp_path/'caller.F90').write_text(fcode)

    scheduler = Scheduler(paths=[tmp_path], config=config, seed_routines=['caller'], xmods=[tmp_path])

    assert set(scheduler.items) == set(expected_deps)
    assert set(scheduler.dependencies) == {
        (src, dst) for src, targets in expected_deps.items() for dst in targets
    }

    # Items of non-selected types are hidden from the transformation; procedure
    # bindings and interfaces are treated like procedures by the filter
    if trafo_item_filter != Item:
        accepted = as_tuple(trafo_item_filter)
        if ProcedureItem in accepted:
            accepted = accepted + (ProcedureBindingItem, InterfaceItem)
        filtered = {}
        for name, targets in expected_deps.items():
            filtered[name] = tuple(dep for dep in targets if isinstance(scheduler[dep], accepted))
        expected_deps = filtered

    def reachable(start):
        # Breadth-first transitive closure of ``start`` over ``expected_deps``
        seen = set()
        pending = deque(expected_deps[start])
        while pending:
            current = pending.popleft()
            if current in seen:
                continue
            seen.add(current)
            pending.extend(expected_deps[current])
        return seen

    transformation = SuccessorTransformation({name: reachable(name) for name in expected_deps})
    scheduler.process(transformation=transformation)

    # Every routine must have been visited exactly once
    assert transformation.counter == {
        'caller': 1,
        'routine': 1,
        'some_procedure': 1,
        'other': 1,
    }


@pytest.mark.parametrize('full_parse', [True, False])
def test_scheduler_typebound_inline_call(tmp_path, config, full_parse):
    """
    Check the dependency graph and the rendered callgraph for type-bound
    procedures used both via a ``call`` statement and via an inline function
    call. The inline call is only expected to be resolved with a full parse.
    """
    fcode_mod = """
module some_mod
    implicit none
    type some_type
        integer :: a
    contains
        procedure :: some_routine
        procedure :: some_function
    end type some_type
contains
    subroutine some_routine(t)
        class(some_type), intent(inout) :: t
        t%a = 5
    end subroutine some_routine

    integer function some_function(t)
        class(some_type), intent(in) :: t
        some_function = t%a
    end function some_function
end module some_mod
    """.strip()

    fcode_caller = """
subroutine caller(b)
    use some_mod, only: some_type
    implicit none
    integer, intent(inout) :: b
    type(some_type) :: t
    t%a = b
    call t%some_routine()
    b = t%some_function()
end subroutine caller
    """.strip()

    (tmp_path/'some_mod.F90').write_text(fcode_mod)
    (tmp_path/'caller.F90').write_text(fcode_caller)

    def verify_graph(scheduler, expected_dependencies):
        edges = {(src, dst) for src, targets in expected_dependencies.items() for dst in targets}
        assert set(scheduler.items) == set(expected_dependencies)
        assert set(scheduler.dependencies) == edges

        # Sources are flagged incomplete exactly when no full parse was requested
        assert all(item.source._incomplete is not full_parse for item in scheduler.items)

        # The rendered callgraph must mirror the dependency graph (upper-cased names)
        cg_path = tmp_path/'callgraph'
        scheduler.callgraph(cg_path)
        vgraph = VisGraphWrapper(cg_path)
        assert set(vgraph.nodes) == {name.upper() for name in expected_dependencies}
        assert set(vgraph.edges) == {(src.upper(), dst.upper()) for src, dst in edges}

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=['caller'], full_parse=full_parse, xmods=[tmp_path]
    )

    expected = {
        '#caller': ('some_mod#some_type', 'some_mod#some_type%some_routine'),
        'some_mod#some_type': (),
        'some_mod#some_type%some_routine': ('some_mod#some_routine',),
        'some_mod#some_routine': ('some_mod#some_type',),
    }

    if scheduler.full_parse:
        # Inline calls can only be resolved with a full parse
        expected['#caller'] = expected['#caller'] + ('some_mod#some_type%some_function',)
        expected.update({
            'some_mod#some_type%some_function': ('some_mod#some_function',),
            'some_mod#some_function': ('some_mod#some_type',),
        })

    verify_graph(scheduler, expected)

    # TODO: test adding a nested derived type dependency!


@pytest.mark.parametrize('full_parse', [False, True])
def test_scheduler_cycle(tmp_path, config, full_parse):
    """
    Check that dependency edges closing a cycle through a recursive routine
    are pruned.

    ``some_proc`` calls itself directly and indirectly via the type-bound
    procedure ``some_type%proc``; both of these outgoing edges must be removed
    from the graph, while its edge to ``some_type%other`` must be kept.
    """
    # NB: ``fallback`` must be declared as an optional dummy argument: the
    # module uses ``implicit none`` and ``present()`` is only valid on
    # optional dummy arguments
    fcode_mod = """
module some_mod
    implicit none
    type some_type
        integer :: a
    contains
        procedure :: proc => some_proc
        procedure :: other => some_other
    end type some_type
contains
    recursive subroutine some_proc(this, val, recurse, fallback)
        class(some_type), intent(inout) :: this
        integer, intent(in) :: val
        logical, intent(in), optional :: recurse
        logical, intent(in), optional :: fallback

        if (present(recurse)) then
            if (present(fallback)) then
                call this%other(val)
            else
                call some_proc(this, val, .true., .true.)
            end if
        else
            call this%proc(val, .true.)
        end if
    end subroutine some_proc

    subroutine some_other(this, val)
        class(some_type), intent(inout) :: this
        integer, intent(in) :: val
        this%a = val
    end subroutine some_other
end module some_mod
    """.strip()

    fcode_caller = """
subroutine caller
    use some_mod, only: some_type
    implicit none
    type(some_type) :: t

    call t%proc(1)
end subroutine caller
    """.strip()

    (tmp_path/'some_mod.F90').write_text(fcode_mod)
    (tmp_path/'caller.F90').write_text(fcode_caller)

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=['caller'], full_parse=full_parse, xmods=[tmp_path]
    )

    # Make sure that the outgoing edges from the recursive routine to the procedure
    # binding and to itself are removed, but the edge to ``other`` still exists
    assert (scheduler['#caller'], scheduler['some_mod#some_type%proc']) in scheduler.dependencies
    assert (scheduler['some_mod#some_type%proc'], scheduler['some_mod#some_proc']) in scheduler.dependencies
    assert (scheduler['some_mod#some_proc'], scheduler['some_mod#some_type%proc']) not in scheduler.dependencies
    assert (scheduler['some_mod#some_proc'], scheduler['some_mod#some_proc']) not in scheduler.dependencies
    assert (scheduler['some_mod#some_proc'], scheduler['some_mod#some_type%other']) in scheduler.dependencies
    assert (scheduler['some_mod#some_type%other'], scheduler['some_mod#some_other']) in scheduler.dependencies


def test_scheduler_unqualified_imports(config):
    """
    Check the dependencies of a kernel with one unqualified and one qualified
    ``use`` statement: the unqualified import maps to the bare module name,
    the qualified import to ``module#symbol``, and the call to the routine name.
    """

    kernel = """
    subroutine kernel()
       use some_mod
       use other_mod, only: other_routine

       call other_routine
    end subroutine kernel
    """

    source = Sourcefile.from_source(kernel, frontend=REGEX)
    item = ProcedureItem(name='#kernel', source=source, config=config['default'])

    assert len(item.dependencies) == 3

    children = set()
    for dep in item.dependencies:
        if isinstance(dep, ir.CallStatement):
            children.add(str(dep.name).lower())
            continue
        # Anything that is neither a call nor an import is unexpected here
        assert isinstance(dep, ir.Import), 'Unexpected dependency type'
        if dep.symbols:
            children.update(f'{dep.module}#{str(s)}'.lower() for s in dep.symbols)
        else:
            children.add(dep.module.lower())

    assert children == {'some_mod', 'other_mod#other_routine', 'other_routine'}


def test_scheduler_depths(testdir, config, frontend, tmp_path):
    """
    Check the depth that ``scheduler.sgraph`` assigns to every item of the
    projA call tree, with the seed routine ``driverA`` at depth 0.
    """
    proj_dir = testdir/'sources/projA'

    scheduler = Scheduler(
        paths=proj_dir, includes=proj_dir/'include', config=config,
        seed_routines=['driverA'], frontend=frontend, xmods=[tmp_path]
    )

    assert scheduler.sgraph.depths == {
        'drivera_mod#drivera': 0,
        'header_mod#header_type': 1,
        'kernela_mod#kernela': 1,
        'compute_l1_mod#compute_l1': 2,
        '#another_l1': 2,
        'compute_l2_mod#compute_l2': 3,
        '#another_l2': 3,
        'header_mod': 4,
    }


def test_scheduler_disable_wildcard(testdir, config, tmp_path):  # pylint: disable=unused-argument
    """
    Check that a wildcard pattern in the ``disable`` list (``*%init``) removes
    every matching type-bound procedure item, while a direct call to the
    underlying module procedure is still tracked as a dependency.

    The test sources are written below ``tmp_path`` rather than into the shared
    ``testdir`` tree, so nothing is left behind if an assertion fails.
    ``testdir`` is kept in the signature only for backward compatibility.
    """

    fcode_mod = """
module field_mod
  type field2d
    contains
    procedure :: init => field_init
  end type

  type field3d
    contains
    procedure :: init => field_init
  end type

  contains
    subroutine field_init()

    end subroutine
end module
"""

    fcode_driver = """
subroutine my_driver
  use field_mod, only: field2d, field3d, field_init
implicit none

  type(field2d) :: a, b
  type(field3d) :: c, d

  call a%init()
  call b%init()
  call c%init()
  call field_init(d)
end subroutine my_driver
"""

    # Set up the test files in a per-test temporary directory (cleaned up by pytest)
    dirname = tmp_path/'test_scheduler_disable_wildcard'
    dirname.mkdir(exist_ok=True)
    (dirname/'field_mod.F90').write_text(fcode_mod)
    (dirname/'test.F90').write_text(fcode_driver)

    config['default']['disable'] = ['*%init']

    scheduler = Scheduler(paths=dirname, seed_routines=['my_driver'], config=config, xmods=[tmp_path])

    expected_items = [
        '#my_driver', 'field_mod#field_init',
    ]
    expected_dependencies = [
        ('#my_driver', 'field_mod#field_init'),
    ]

    assert all(n in scheduler.items for n in expected_items)
    assert all(e in scheduler.dependencies for e in expected_dependencies)

    # The wildcard must have removed the type-bound procedures of both types
    assert 'field_mod#field2d%init' not in scheduler.items
    assert 'field_mod#field3d%init' not in scheduler.items


def test_transformation_config(config):
    """
    Test the correct instantiation of :any:`Transformation` objects from config,
    and that misconfigured entries (unknown class name, unknown module,
    unknown constructor option) are rejected with an error.
    """
    my_config = config.copy()
    my_config['transformations'] = {
        'DependencyTransformation': {
            'module': 'loki.transformations.build_system',
            'options':
            {
                'suffix': '_rick',
                'module_suffix': '_roll',
                'replace_ignore_items': False,
            }
        },
        # Instantiate IdemTransformation entry without options
        'IdemTransformation': {
            'module': 'loki.transformations',
        }
    }
    cfg = SchedulerConfig.from_dict(my_config)
    assert cfg.transformations['DependencyTransformation']

    transformation = cfg.transformations['DependencyTransformation']
    assert isinstance(transformation, DependencyTransformation)
    assert transformation.suffix == '_rick'
    assert transformation.module_suffix == '_roll'
    assert not transformation.replace_ignore_items

    # Test for errors when failing to instantiate a transformation
    bad_config = config.copy()
    bad_config['transformations'] = {
        'DependencyTrafo': {  # <= typo
            'module': 'loki.transformations.build_system',
            'options': {}
        }
    }
    with pytest.raises(RuntimeError):
        cfg = SchedulerConfig.from_dict(bad_config)

    worse_config = config.copy()
    worse_config['transformations'] = {
        'DependencyTransform': {
            'module': 'loki.transformats.build_system',  # <= typo
            'options': {}
        }
    }
    with pytest.raises(ModuleNotFoundError):
        cfg = SchedulerConfig.from_dict(worse_config)

    # An unknown option for an existing class must be rejected as well. The
    # class name is spelled correctly here, so the failure can only come from
    # the bogus option; an unknown class name is already covered above.
    worst_config = config.copy()
    worst_config['transformations'] = {
        'DependencyTransformation': {
            'module': 'loki.transformations.build_system',
            'options': {'hello': 'Dave'}  # <= invalid constructor option
        }
    }
    # NOTE(review): depending on whether the config loader wraps constructor
    # errors, the invalid keyword surfaces as RuntimeError or plain TypeError
    with pytest.raises((RuntimeError, TypeError)):
        cfg = SchedulerConfig.from_dict(worst_config)


def test_transformation_config_external_with_dimension(testdir, config):
    """
    Check that a transformation class loaded from an external module path
    (``classname``/``module``/``path``) receives a :any:`Dimension` object
    when one of its options references an entry of the ``dimensions`` section.
    """
    trafo_options = {'name': 'Rick', 'horizontal': '%dimensions.ij%'}

    my_config = config.copy()
    my_config['dimensions'] = {'ij': {'size': 'n', 'index': 'i'}}
    my_config['transformations'] = {
        'CallMeRick': {
            'classname': 'CallMeMaybeTrafo',
            'module': 'call_me_trafo',
            'path': str(testdir/'sources'),
            'options': trafo_options,
        }
    }

    cfg = SchedulerConfig.from_dict(my_config)
    assert cfg.transformations['CallMeRick']
    transformation = cfg.transformations['CallMeRick']

    # The class itself is not importable here, so identify it by name only
    assert type(transformation).__name__ == 'CallMeMaybeTrafo'
    assert transformation.name == 'Rick'

    horizontal = transformation.horizontal
    assert isinstance(horizontal, Dimension)
    assert horizontal.size == 'n'
    assert horizontal.index == 'i'


@pytest.mark.parametrize('item_name,keys,use_pattern_matching,match_item_parents,expected', [
    ('comp2', 'comp2', True, True, ('comp2',)),
    ('#comp2', 'comp2', True, True, ('comp2',)),
    # Key case: a config key given with an explicit scope must not match an unscoped item name
    ('comp2', '#comp2', True, True, ()),
    ('#comp2', '#comp2', True, True, ('#comp2',))
])
def test_scheduler_config_match_item_keys(item_name, keys, use_pattern_matching, match_item_parents, expected):
    """
    Check which config keys :meth:`SchedulerConfig.match_item_keys` reports as
    matching an item name, with and without an explicit ``#`` scope prefix.
    """
    assert SchedulerConfig.match_item_keys(
        item_name, keys, use_pattern_matching, match_item_parents
    ) == expected


@pytest.mark.parametrize('frontend', available_frontends())
def test_scheduler_filter_items_file_graph(tmp_path, frontend, config):
    """
    Ensure that the ``items`` list given to a transformation in
    a file graph traversal is filtered to include only used items
    (and the modules containing them), even though the single source
    file also defines unused program units.
    """
    fcode = """
module test_scheduler_filter_program_units_file_graph_mod1
implicit none
contains
subroutine proc1(arg)
    integer, intent(inout) :: arg
    arg = arg + 1
end subroutine proc1

subroutine unused_proc(arg)
    integer, intent(inout) :: arg
    arg = arg - 1
end subroutine unused_proc
end module test_scheduler_filter_program_units_file_graph_mod1

module test_scheduler_filter_program_units_file_graph_mod2
implicit none
contains
subroutine proc2(arg)
    integer, intent(inout) :: arg
    arg = arg + 2
end subroutine proc2
end module test_scheduler_filter_program_units_file_graph_mod2

module test_scheduler_filter_program_units_file_graph_mod3
implicit none
integer, parameter :: param3 = 3
contains
subroutine proc3(arg)
    integer, intent(inout) :: arg
    arg = arg + 3
end subroutine proc3
end module test_scheduler_filter_program_units_file_graph_mod3

subroutine test_scheduler_filter_program_units_file_graph_driver
use test_scheduler_filter_program_units_file_graph_mod1, only: proc1
use test_scheduler_filter_program_units_file_graph_mod3, only: param3
implicit none
integer :: i
i = param3
call proc1(i)
end subroutine test_scheduler_filter_program_units_file_graph_driver
    """

    config['routines']['test_scheduler_filter_program_units_file_graph_driver'] = {
        'role': 'driver'
    }

    filepath = tmp_path/'test_scheduler_filter_program_units_file_graph.F90'
    filepath.write_text(fcode)

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=['test_scheduler_filter_program_units_file_graph_driver'],
        frontend=frontend, xmods=[tmp_path]
    )

    # Only the driver, mod1#proc1 and mod3 (from which only a parameter is
    # imported) are in the SGraph
    expected_dependencies = {
        '#test_scheduler_filter_program_units_file_graph_driver': {
            'test_scheduler_filter_program_units_file_graph_mod1#proc1',
            'test_scheduler_filter_program_units_file_graph_mod3'
        },
        'test_scheduler_filter_program_units_file_graph_mod1#proc1': set(),
        'test_scheduler_filter_program_units_file_graph_mod3': set()
    }

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items() for b in deps
    }

    # The other module and procedure are in the item_factory's cache
    assert 'test_scheduler_filter_program_units_file_graph_mod2' in scheduler.item_factory.item_cache
    assert 'test_scheduler_filter_program_units_file_graph_mod1#unused_proc' in scheduler.item_factory.item_cache

    # The filegraph consists of the single file
    filegraph = scheduler.file_graph
    assert filegraph.items == (str(filepath).lower(),)

    # Record every file the transformation visits, so that a traversal which
    # never invokes ``transform_file`` cannot make this test pass vacuously
    visited_files = []

    class MyFileTrafo(Transformation):
        traverse_file_graph = True

        def transform_file(self, sourcefile, **kwargs):
            visited_files.append(sourcefile)
            # Only active items should be passed to the transformation
            assert 'items' in kwargs
            assert set(kwargs['items']) == {
                'test_scheduler_filter_program_units_file_graph_mod1',
                'test_scheduler_filter_program_units_file_graph_mod1#proc1',
                'test_scheduler_filter_program_units_file_graph_mod3',
                '#test_scheduler_filter_program_units_file_graph_driver'
            }

    scheduler.process(transformation=MyFileTrafo())

    # The file graph holds exactly one file, which must have been processed
    assert len(visited_files) == 1


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('frontend_args,defines,preprocess,has_cpp_directives,additional_dependencies', [
    # Each case: (per-file ``frontend_args`` overrides keyed by file name, global ``defines``,
    # global ``preprocess`` flag, items expected to retain CPP directives after parsing,
    # dependencies added on top of the base graph 1 -> 2 -> 4)
    # No preprocessing, thus all call dependencies are included
    (None, None, False, [
        '#test_scheduler_frontend_args1', '#test_scheduler_frontend_args2', '#test_scheduler_frontend_args4'
    ], {
        '#test_scheduler_frontend_args2': ('#test_scheduler_frontend_args3',),
        '#test_scheduler_frontend_args3': (),
        '#test_scheduler_frontend_args4': ('#test_scheduler_frontend_args3',),
    }),
    # Global preprocessing setting SOME_DEFINITION, removing dependency on 3
    (None, ['SOME_DEFINITION'], True, [], {}),
    # Global preprocessing with local definition for one file, re-adding a dependency on 3
    (
        {'file3_4.F90': {'defines': ['SOME_DEFINITION','LOCAL_DEFINITION']}},
        ['SOME_DEFINITION'],
        True,
        [],
        {
            '#test_scheduler_frontend_args3': (),
            '#test_scheduler_frontend_args4': ('#test_scheduler_frontend_args3',),
        }
    ),
    # Global preprocessing with preprocessing switched off for 2
    (
        {'file2.F90': {'preprocess': False}},
        ['SOME_DEFINITION'],
        True,
        ['#test_scheduler_frontend_args2'],
        {
            '#test_scheduler_frontend_args2': ('#test_scheduler_frontend_args3',),
            '#test_scheduler_frontend_args3': (),
        }
    ),
    # No preprocessing except for 2
    (
        {'file2.F90': {'preprocess': True, 'defines': ['SOME_DEFINITION']}},
        None,
        False,
        ['#test_scheduler_frontend_args1', '#test_scheduler_frontend_args4'],
        {
            '#test_scheduler_frontend_args3': (),
            '#test_scheduler_frontend_args4': ('#test_scheduler_frontend_args3',),
        }
    ),
])
def test_scheduler_frontend_args(tmp_path, frontend, frontend_args, defines, preprocess,
                                 has_cpp_directives, additional_dependencies, config):
    """
    Test overwriting frontend options via Scheduler config

    The three source files contain ``#ifdef``/``#ifndef`` guarded calls. The
    global ``defines``/``preprocess`` settings are passed to the Scheduler, and
    per-file overrides are given via ``config['frontend_args']``. The test checks
    the resulting dependency graph and which items still contain CPP directives.
    """

    fcode1 = """
subroutine test_scheduler_frontend_args1
    implicit none
#ifdef SOME_DEFINITION
    call test_scheduler_frontend_args2
#endif
end subroutine test_scheduler_frontend_args1
    """.strip()

    fcode2 = """
subroutine test_scheduler_frontend_args2
    implicit none
#ifndef SOME_DEFINITION
    call test_scheduler_frontend_args3
#endif
    call test_scheduler_frontend_args4
end subroutine test_scheduler_frontend_args2
    """.strip()

    fcode3_4 = """
subroutine test_scheduler_frontend_args3
implicit none
end subroutine test_scheduler_frontend_args3

subroutine test_scheduler_frontend_args4
implicit none
#ifdef LOCAL_DEFINITION
    call test_scheduler_frontend_args3
#endif
end subroutine test_scheduler_frontend_args4
    """.strip()

    (tmp_path/'file1.F90').write_text(fcode1)
    (tmp_path/'file2.F90').write_text(fcode2)
    (tmp_path/'file3_4.F90').write_text(fcode3_4)

    # Base graph shared by all cases: 1 -> 2 -> 4
    expected_dependencies = {
        '#test_scheduler_frontend_args1': ('#test_scheduler_frontend_args2',),
        '#test_scheduler_frontend_args2': ('#test_scheduler_frontend_args4',),
        '#test_scheduler_frontend_args4': (),
    }

    # Merge the case-specific edges (and possibly new items) into the base graph
    for key, value in additional_dependencies.items():
        expected_dependencies[key] = expected_dependencies.get(key, ()) + value

    # Per-file frontend overrides (None means no overrides)
    config['frontend_args'] = frontend_args

    scheduler = Scheduler(
        paths=[tmp_path], config=config, seed_routines=['test_scheduler_frontend_args1'],
        frontend=frontend, defines=defines, preprocess=preprocess, xmods=[tmp_path]
    )

    assert set(scheduler.items) == set(expected_dependencies)
    assert set(scheduler.dependencies) == {
        (a, b) for a, deps in expected_dependencies.items() for b in deps
    }

    for item in scheduler.items:
        cpp_directives = FindNodes(ir.PreprocessorDirective).visit(item.ir.ir)
        assert bool(cpp_directives) == (item in has_cpp_directives and frontend != OMNI)
        # NB: OMNI always does preprocessing, therefore we won't find the CPP directives
        #     after the full parse


@pytest.mark.skipif(not (HAVE_OMNI and HAVE_FP), reason="OMNI or FP not available")
def test_scheduler_frontend_overwrite(tmp_path, config):
    """
    Check that a per-file ``frontend`` entry in ``config['frontend_args']`` lets
    the Scheduler parse a file with FP even though OMNI, the globally selected
    frontend, fails on that file.
    """
    fcode_header = """
module test_scheduler_frontend_overwrite_header
    implicit none
    type some_type
        ! We have a comment
        real, dimension(:,:), pointer :: arr
    end type some_type
end module test_scheduler_frontend_overwrite_header
    """.strip()
    fcode_kernel = """
subroutine test_scheduler_frontend_overwrite_kernel
    use test_scheduler_frontend_overwrite_header, only: some_type
    implicit none
    type(some_type) :: var
end subroutine test_scheduler_frontend_overwrite_kernel
    """.strip()

    (tmp_path/'test_scheduler_frontend_overwrite_header.F90').write_text(fcode_header)
    (tmp_path/'test_scheduler_frontend_overwrite_kernel.F90').write_text(fcode_kernel)

    def make_scheduler():
        # ``config`` is looked up at call time, so later edits to it are honoured
        return Scheduler(
            paths=[tmp_path], config=config, seed_routines=['test_scheduler_frontend_overwrite_kernel'],
            frontend=OMNI, xmods=[tmp_path]
        )

    # OMNI fails on the header file as-is...
    with pytest.raises(CalledProcessError):
        Sourcefile.from_source(fcode_header, frontend=OMNI, xmods=[tmp_path])

    # ...and it fails the same way when the Scheduler traverses the sources
    with pytest.raises(CalledProcessError):
        make_scheduler()

    # Produce the header's xmod by parsing a copy without the comment line
    header_lines = fcode_header.split('\n')
    del header_lines[3]
    Sourcefile.from_source('\n'.join(header_lines), frontend=OMNI, xmods=[tmp_path])

    # Route only the header file through FP
    config['frontend_args'] = {
        'test_scheduler_frontend_overwrite_header.F90': {'frontend': 'FP'}
    }

    # With the override in place, the traversal succeeds
    scheduler = make_scheduler()

    assert set(scheduler.items) == {
        '#test_scheduler_frontend_overwrite_kernel', 'test_scheduler_frontend_overwrite_header#some_type'
    }
    assert set(scheduler.dependencies) == {
       ('#test_scheduler_frontend_overwrite_kernel', 'test_scheduler_frontend_overwrite_header#some_type')
    }

    # The FP-parsed derived type keeps its comment
    typedef_body = scheduler['test_scheduler_frontend_overwrite_header#some_type'].ir.body
    comments = FindNodes(ir.Comment).visit(typedef_body)
    assert len(comments) == 1
    assert comments[0].text == '! We have a comment'


def test_scheduler_pipeline_simple(testdir, config, frontend, tmp_path):
    """
    Check that processing a :any:`Pipeline` over a simple call tree gives the
    same result as applying its transformations one after another.

    projA: driverA -> kernelA -> compute_l1 -> compute_l2
                           |
                           | --> another_l1 -> another_l2
    """
    projA = testdir/'sources/projA'

    def build_scheduler():
        return Scheduler(
            paths=projA, includes=projA/'include', config=config,
            seed_routines='driverA', frontend=frontend, xmods=[tmp_path]
        )

    class ZeroMyStuffTrafo(Transformation):
        """ Append a ``= 0.0`` assignment for every array in ``routine.variables`` """

        def transform_subroutine(self, routine, **kwargs):
            arrays = [var for var in routine.variables if isinstance(var, Array)]
            for arr in arrays:
                routine.body.append(ir.Assignment(lhs=arr, rhs=Literal(0.0)))

    class AddSnarkTrafo(Transformation):
        """ Append an empty comment line and a named snarky comment """

        def __init__(self, name='Rick'):
            self.name = name

        def transform_subroutine(self, routine, **kwargs):
            routine.body.append(ir.Comment(text=''))  # Add a newline
            routine.body.append(ir.Comment(text=f'! Sorry {self.name}, no values for you!'))

    def has_correct_assigns(routine, num_assign, values=None):
        assigns = FindNodes(ir.Assignment).visit(routine.body)
        allowed = values or [0.0]
        return len(assigns) == num_assign and all(a.rhs in allowed for a in assigns)

    def has_correct_comments(routine, name='Dave'):
        expected_text = f'! Sorry {name}, no values for you!'
        comments = FindNodes(ir.Comment).visit(routine.body)
        return len(comments) > 2 and comments[-1].text == expected_text

    # Expected assignment count per item, with the admissible right-hand sides
    # (None falls back to the default of ``[0.0]`` in ``has_correct_assigns``)
    expected_assigns = {
        'drivera_mod#drivera': (0, None),
        'kernela_mod#kernela': (2, None),
        'compute_l1_mod#compute_l1': (1, None),
        'compute_l2_mod#compute_l2': (2, [66.0, 0]),
        '#another_l1': (1, None),
        '#another_l2': (2, [77.0, 0]),
    }

    def check_assigns(sched):
        for item_name, (count, values) in expected_assigns.items():
            assert has_correct_assigns(sched[item_name].ir, count, values=values)

    def check_comments(sched, who):
        for item_name in expected_assigns:
            assert has_correct_comments(sched[item_name].ir, name=who)

    # First apply the transformations in sequence and check the effect
    scheduler = build_scheduler()
    scheduler.process(transformation=ZeroMyStuffTrafo())
    check_assigns(scheduler)

    scheduler.process(transformation=AddSnarkTrafo(name='Dave'))
    check_comments(scheduler, 'Dave')

    # Then start from a freshly built scheduler and apply both as one pipeline
    scheduler = build_scheduler()
    MyPipeline = partial(Pipeline, classes=(ZeroMyStuffTrafo, AddSnarkTrafo))
    scheduler.process(transformation=MyPipeline(name='Chad'))
    check_assigns(scheduler)
    check_comments(scheduler, 'Chad')


def test_pipeline_config_compose(config):
    """
    Check that a custom :any:`Pipeline` declared in the scheduler config is
    instantiated from its named transformations, in declaration order.
    """
    my_config = config.copy()
    my_config['dimensions'] = {
        'horizontal': {'size': 'KLON', 'index': 'JL', 'bounds': ['KIDIA', 'KFDIA']},
        'block_dim': {'size': 'NGPBLKS', 'index': 'IBL'},
    }

    vector_options = {
        'horizontal': '%dimensions.horizontal%',
        'block_dim': '%dimensions.block_dim%',
        'directive': 'openacc',
        'trim_vector_sections': True,
    }
    my_config['transformations'] = {
        'VectorWithTrim': {
            'classname': 'SCCVectorPipeline',
            'module': 'loki.transformations.single_column',
            'options': vector_options,
        },
        'preprocess': {
            'classname': 'RemoveCodeTransformation',
            'module': 'loki.transformations',
            'options': {'call_names': 'dr_hook', 'remove_imports': False},
        },
        'postprocess': {
            'classname': 'ModuleWrapTransformation',
            'module': 'loki.transformations.build_system',
            'options': {'module_suffix': '_module'},
        },
    }
    my_config['pipelines'] = {
        'MyVectorPipeline': {
            'transformations': ['preprocess', 'VectorWithTrim', 'postprocess'],
        }
    }
    cfg = SchedulerConfig.from_dict(my_config)

    # All named transformations and the pipeline itself must be registered
    for trafo_name in ('VectorWithTrim', 'preprocess', 'postprocess'):
        assert cfg.transformations[trafo_name]
    assert cfg.pipelines['MyVectorPipeline']

    pipeline = cfg.pipelines['MyVectorPipeline']
    assert isinstance(pipeline, Pipeline)

    # The nested SCC vector pipeline is expanded between pre- and post-processing
    expected_names = [
        'RemoveCodeTransformation',
        'SCCFuseVerticalLoops',
        'SCCBaseTransformation',
        'SCCDevectorTransformation',
        'SCCDemoteTransformation',
        'SCCVecRevectorTransformation',
        'SCCAnnotateTransformation',
        'PragmaModelTransformation',
        'ModuleWrapTransformation',
    ]
    assert len(pipeline.transformations) == len(expected_names)
    for trafo, expected_name in zip(pipeline.transformations, expected_names):
        assert type(trafo).__name__ == expected_name

    # Spot-check explicitly configured and default constructor flags
    remove_code = pipeline.transformations[0]
    scc_base = pipeline.transformations[2]
    devector = pipeline.transformations[3]
    module_wrap = pipeline.transformations[8]

    assert remove_code.call_names == ('dr_hook',)
    assert remove_code.remove_imports is False
    assert isinstance(scc_base.horizontal, Dimension)
    assert scc_base.horizontal.size == 'KLON'
    assert scc_base.horizontal.index == 'JL'
    assert devector.trim_vector_sections is True
    assert module_wrap.replace_ignore_items is True


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('enable_imports', [False, True])
@pytest.mark.parametrize('import_level', ['module', 'subroutine'])
def test_scheduler_indirect_import(frontend, tmp_path, enable_imports, import_level):
    """
    Test that a symbol imported indirectly via an intermediate module
    (``global_a`` reaches ``c_mod`` through ``b_mod``) is discovered, and that
    its type is only enriched when ``enable_imports`` is set.
    """
    fcode_mod_a = """
module a_mod
    implicit none
    public
    integer :: global_a = 1
end module a_mod
"""

    fcode_mod_b = """
module b_mod
    use a_mod
    implicit none
    public
    type type_b
        integer :: val
    end type type_b
end module b_mod
"""

    # Place the same USE statement either at module or at subroutine level
    use_stmt = "use b_mod, only: type_b, global_a"
    module_import_stmt = use_stmt if import_level == 'module' else ""
    routine_import_stmt = use_stmt if import_level == 'subroutine' else ""

    fcode_mod_c = f"""
module c_mod
    {module_import_stmt}
    implicit none
contains
    subroutine c(b)
        {routine_import_stmt}
        implicit none
        type(type_b), intent(inout) :: b
        b%val = global_a
    end subroutine c
end module c_mod
"""

    # Set-up paths and write sources
    src_path = tmp_path/'src'
    out_path = tmp_path/'build'
    for path in (src_path, out_path):
        path.mkdir()

    sources = {'a.F90': fcode_mod_a, 'b.F90': fcode_mod_b, 'c.F90': fcode_mod_c}
    for filename, fcode in sources.items():
        (src_path/filename).write_text(fcode)

    # Create the Scheduler
    config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'enable_imports': enable_imports
        },
        'routines': {'c': {'role': 'driver'}}
    })
    try:
        scheduler = Scheduler(
            paths=[src_path], config=config, frontend=frontend,
            output_dir=out_path, xmods=[out_path]
        )
    except CalledProcessError:
        if frontend == OMNI and not enable_imports:
            # Without processing imports, OMNI lacks the xmod files of the
            # imported modules and therefore cannot parse the sources
            pytest.xfail('Without parsing imports, OMNI does not have the xmod for imported modules')
        raise

    # Every module, the derived type and the routine must be in the dependency graph
    assert {item.name for item in scheduler.items} == {'a_mod', 'b_mod', 'b_mod#type_b', 'c_mod#c'}

    # Imported symbols carry enriched type information only if imports are
    # enabled; otherwise their type remains deferred
    type_b = scheduler['b_mod#type_b'].ir
    routine_c = scheduler['c_mod#c'].ir
    var_map = CaseInsensitiveDict(
        (var.name, var) for var in FindVariables().visit(routine_c.body)
    )
    global_a = var_map['global_a']
    b_dtype = var_map['b'].type.dtype

    if not enable_imports:
        assert global_a.type.dtype is BasicType.DEFERRED
        assert global_a.type.initial is None
        assert b_dtype.typedef is BasicType.DEFERRED
        return

    assert global_a.type.dtype is BasicType.INTEGER
    assert global_a.type.initial == '1'
    assert b_dtype.typedef is type_b


@pytest.mark.parametrize('frontend', available_frontends(skip={OMNI: "OMNI fails on missing module"}))
@pytest.mark.parametrize('external_kernel', [None, 'module', 'intfb'])
def test_scheduler_ignore_external_item(frontend, tmp_path, external_kernel):
    """
    Test that ignored items and items marked as ``generated`` are represented
    as :any:`ExternalItem` nodes, and that a generated (source-less) kernel
    only fails once an actual transformation pass is requested.
    """
    # Optional source lines that pull in ``kernel2`` via a module or an include
    kernel2_use = 'use kernel2_mod, only: kernel2' if external_kernel == 'module' else ''
    kernel2_include = '#include "kernel2.intfb.h"' if external_kernel == 'intfb' else ''
    kernel2_call = 'call kernel2()' if external_kernel else ''

    fcode_driver = f"""
module driver_mod
  contains
  subroutine driver(nlon, klev, nb, ydml_phy_mf)
    use parkind1, only: jpim, jprb
    use kernel1_mod, only: kernel1
    {kernel2_use}
    implicit none
    type(model_physics_mf_type), intent(in) :: ydml_phy_mf
    integer(kind=jpim), intent(in) :: nlon
    integer(kind=jpim), intent(in) :: klev
    integer(kind=jpim), intent(in) :: nb
    integer(kind=jpim) :: jstart
    integer(kind=jpim) :: jend
    integer(kind=jpim) :: b
{kernel2_include}
    jstart = 1
    jend = nlon
    do b = 1, nb
        call kernel1()
        {kernel2_call}
    enddo
  end subroutine driver
end module driver_mod
    """.strip()
    fcode_kernel1 = """
module kernel1_mod
  contains
  subroutine kernel1()
    use parkind1, only: jpim, jprb
  end subroutine kernel1
end module kernel1_mod
    """.strip()

    (tmp_path/'driver.F90').write_text(fcode_driver)
    (tmp_path/'kernel1_mod.F90').write_text(fcode_kernel1)

    default_config = {
        'mode': 'idem',
        'role': 'kernel',
        'expand': True,
        'strict': True,
        'ignore': ['parkind1'],
    }
    if external_kernel:
        # kernel2 has no source here; declare it as build-time generated
        default_config['generated'] = ['kernel2*']
    config = {
        'default': default_config,
        'routines': {'driver': {'role': 'driver'}},
    }

    class Trafo(Transformation):
        """No-op transformation that visits procedures and modules."""

        item_filter = (ProcedureItem, ModuleItem)

        def transform_module(self, module, **kwargs):
            pass

        def transform_subroutine(self, routine, **kwargs):
            pass

    scheduler = Scheduler(paths=[tmp_path], config=SchedulerConfig.from_dict(config),
                          definitions=(), xmods=[tmp_path], frontend=frontend)

    expected_items = {'driver_mod#driver', 'parkind1', 'kernel1_mod#kernel1'}
    expected_items |= {
        'module': {'kernel2_mod#kernel2', 'kernel2_mod'},
        'intfb': {'#kernel2'},
    }.get(external_kernel, set())
    assert expected_items == {item.name for item in scheduler.items}

    for item in scheduler.items:
        is_parkind = item.name == 'parkind1'
        is_kernel2 = item.name.endswith('#kernel2')
        if is_parkind:
            # Explicitly ignored dependency
            assert item.is_ignored
            assert isinstance(item, ExternalItem)
        if is_kernel2:
            # Generated dependency: external, but not ignored
            assert not item.is_ignored
            assert isinstance(item, ExternalItem)
            assert item.is_generated

    if not external_kernel:
        # Without any generated items, processing must succeed
        scheduler.process(transformation=Trafo())
        return

    # Planning succeeds because kernel2 is marked as build-time generated...
    scheduler.process(transformation=Trafo(), proc_strategy=ProcessingStrategy.PLAN)
    # ...but a full transformation pass is expected to fail
    with pytest.raises(RuntimeError):
        scheduler.process(transformation=Trafo())


@pytest.mark.parametrize('proc_strategy', [ProcessingStrategy.PLAN, ProcessingStrategy.DEFAULT])
@pytest.mark.parametrize('with_filegraph', [True, False])
def test_scheduler_exception_handling(tmp_path, testdir, config, frontend, proc_strategy, with_filegraph):
    """
    Create a simple task graph from a single sub-project:

    projA: driverA -> kernelA -> compute_l1 -> compute_l2
                           |
                           | --> another_l1 -> another_l2

    and check that an exception raised inside a transformation surfaces as a
    :any:`TransformationError` whose message names the transformation, the
    type of the failing IR object and its name.
    """

    # Combine directory globbing and explicit file paths for lookup
    projA = testdir/'sources/projA'
    paths = [projA/'module', projA/'source/another_l1.F90', projA/'source/another_l2.F90']

    scheduler = Scheduler(
        paths=paths, includes=projA/'include', config=config,
        seed_routines='driverA', frontend=frontend, xmods=[tmp_path]
    )

    class RuntimeErrorTransformation(Transformation):
        """Raise a ``RuntimeError`` when reaching the configured IR object."""

        traverse_file_graph = with_filegraph
        recurse_to_modules = with_filegraph
        recurse_to_procedures = with_filegraph

        item_filter = (ProcedureItem, ModuleItem)

        def __init__(self, fail_cls, fail_name):
            self.fail_cls = fail_cls
            self.fail_name = fail_name

        def transform_file(self, sourcefile, **kwargs):
            # Source files fail regardless of their name
            if self.fail_cls != Sourcefile:
                return
            raise RuntimeError

        def transform_module(self, module, **kwargs):
            if self.fail_cls != Module or self.fail_name != module.name.lower():
                return
            raise RuntimeError

        def transform_subroutine(self, routine, **kwargs):
            if self.fail_cls not in (Subroutine, Function) or self.fail_name != routine.name.lower():
                return
            raise RuntimeError

        # Planning fails at exactly the same points as transforming
        plan_file = transform_file
        plan_module = transform_module
        plan_subroutine = transform_subroutine

    # Failure points; a file-level failure is only reachable when traversing files
    fail_list = [(Subroutine, 'compute_l1'), (Module, 'header_mod')]
    if with_filegraph:
        fail_list.append((Sourcefile, ''))

    for fail_cls, fail_name in fail_list:
        pattern = f'RuntimeErrorTransformation.*?{fail_cls.__name__}.*?{fail_name.lower()}'
        trafo = RuntimeErrorTransformation(fail_cls=fail_cls, fail_name=fail_name)
        with pytest.raises(TransformationError, match=pattern):
            scheduler.process(trafo, proc_strategy=proc_strategy)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('qualified_import', ['driver', 'module', 'both', 'none'])
def test_scheduler_transient_typedef_imports(frontend, qualified_import, tmp_path):
    """
    Test that use of transiently imported typedefs succeeds.

    The driver uses ``my_type`` (defined in ``mod_b``) via ``mod_a``, which
    imports it from ``mod_b``. ``qualified_import`` selects which of the two
    USE statements carries an explicit ``only: my_type`` list. In strict mode,
    the scheduler can only resolve ``my_type`` when the driver's import is
    qualified; otherwise creating the scheduler raises a ``RuntimeError``.
    """

    fcode_mod_a = f"""
module mod_a
    use mod_b{', only: my_type' if qualified_import in ('module', 'both') else ''}
    implicit none
end module mod_a
"""

    fcode_mod_b = """
module mod_b
    implicit none
    type my_type
        real(kind=4) :: a, b, x

        contains
        procedure :: add_a_b => my_type_add_a_b
    end type my_type

contains
    subroutine my_type_add_a_b(obj)
        type(my_type), intent(inout) :: obj

        obj%x = obj%a + obj%b
    end subroutine my_type_add_a_b
end module mod_b
"""

    # Initialise both operands of ``add_a_b`` before the type-bound call
    fcode_driver = f"""
subroutine test_scheduler()
    use mod_a {', only: my_type' if qualified_import in ('driver', 'both') else ''}
    implicit none

    type(my_type) :: d

    d%a = 42.0
    d%b = 66.6
    call d%add_a_b()
end subroutine test_scheduler
"""
    src_path = tmp_path/'src'
    src_path.mkdir()
    (src_path/'mod_a.F90').write_text(fcode_mod_a)
    (src_path/'mod_b.F90').write_text(fcode_mod_b)
    (src_path/'driver.F90').write_text(fcode_driver)

    # Create the Scheduler
    config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel', 'expand': True, 'strict': True, 'enable_imports': True
        },
        'routines': {'test_scheduler': {'role': 'driver'}}
    })
    if qualified_import not in ('driver', 'both'):
        with pytest.raises(RuntimeError):
            # NB: This raises a runtime error because we run in strict mode and cannot determine
            #     that `my_type` is defined via the import due to the absence of a list of
            #     symbols on the import
            _ = Scheduler(
                paths=[src_path], config=config, seed_routines='test_scheduler',
                frontend=frontend, xmods=[tmp_path], full_parse=False
            )
    else:
        scheduler = Scheduler(
            paths=[src_path], config=config, seed_routines='test_scheduler',
            frontend=frontend, xmods=[tmp_path], full_parse=False
        )

        # With a qualified import in mod_a, the dependency points at the typedef
        # itself; otherwise it points at the whole module mod_b
        if qualified_import == 'both':
            assert scheduler.items == (
                '#test_scheduler', 'mod_a', 'mod_a#my_type%add_a_b', 'mod_b#my_type'
            )
            assert scheduler.dependencies == (
                ('#test_scheduler', 'mod_a'),
                ('#test_scheduler', 'mod_a#my_type%add_a_b'),
                ('mod_a', 'mod_b#my_type')
            )
        else:
            assert scheduler.items == (
                '#test_scheduler', 'mod_a', 'mod_a#my_type%add_a_b', 'mod_b'
            )
            assert scheduler.dependencies == (
                ('#test_scheduler', 'mod_a'),
                ('#test_scheduler', 'mod_a#my_type%add_a_b'),
                ('mod_a', 'mod_b')
            )

        if frontend == OMNI:
            with pytest.raises(ParseError):
                # OMNI fails to read due to missing mod_b xmods
                scheduler._parse_items()
        else:
            scheduler._parse_items()

        call = FindNodes(ir.CallStatement).visit(scheduler['#test_scheduler'].ir.ir)[0]
        assert call.name == 'd%add_a_b'
        assert isinstance(call.name.parent.type.dtype, DerivedType)
        assert call.name.parent.type.dtype.name == 'my_type'

        if qualified_import in ('driver', 'both') and frontend == FP:
            # Enrichment does work correctly in this situation, providing the link to the typedef
            assert call.name.parent.type.dtype.typedef is (
                getattr(scheduler['mod_b#my_type'], 'ir', None) or scheduler['mod_b'].ir['my_type']
            )
        else:
            # The interprocedural annotations do _not_ currently enrich the type in this situation
            assert call.name.parent.type.dtype.typedef == BasicType.DEFERRED


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('qualified_import', ['driver', 'module', 'both', 'none'])
def test_scheduler_transient_procedure_imports(frontend, qualified_import, tmp_path):
    """
    Test that use of transiently imported procedures succeeds.

    The driver calls ``my_proc`` (defined in ``mod_b``) via ``mod_a``, which
    imports it from ``mod_b``. ``qualified_import`` selects which of the two
    USE statements carries an explicit ``only: my_proc`` list. In strict mode,
    creating the scheduler raises a ``RuntimeError`` unless the driver's
    import is qualified.
    """

    fcode_mod_a = f"""
module mod_a
    use mod_b{', only: my_proc' if qualified_import in ('module', 'both') else ''}
    implicit none
end module mod_a
"""

    fcode_mod_b = """
module mod_b
    implicit none
contains
    subroutine my_proc
        print *,'hello world'
    end subroutine
end module mod_b
"""

    fcode_driver = f"""
subroutine test_scheduler()
    use mod_a {', only: my_proc' if qualified_import in ('driver', 'both') else ''}
    implicit none

    call my_proc
end subroutine test_scheduler
"""
    src_path = tmp_path/'src'
    src_path.mkdir()
    (src_path/'mod_a.F90').write_text(fcode_mod_a)
    (src_path/'mod_b.F90').write_text(fcode_mod_b)
    (src_path/'driver.F90').write_text(fcode_driver)

    # Create the Scheduler
    config = SchedulerConfig.from_dict({
        'default': {
            'role': 'kernel', 'expand': True, 'strict': True, 'enable_imports': True
        },
        'routines': {'test_scheduler': {'role': 'driver'}}
    })

    if qualified_import in ('module', 'none'):
        with pytest.raises(RuntimeError):
            # NB: This raises a runtime error because we run in strict mode and cannot determine
            #     that `my_proc` is defined via the import due to the absence of a list of
            #     symbols on the import
            _ = Scheduler(
                paths=[src_path], config=config, seed_routines='test_scheduler',
                frontend=frontend, xmods=[tmp_path], full_parse=False
            )
    else:
        scheduler = Scheduler(
            paths=[src_path], config=config, seed_routines='test_scheduler',
            frontend=frontend, xmods=[tmp_path], full_parse=False
        )

        # NB: With a qualified import in mod_a, mod_b is missing from the graph
        #     because the transient import does not incur a dependency. See #630
        if qualified_import == 'both':
            assert scheduler.items == ('#test_scheduler', 'mod_a', 'mod_a#my_proc')
            assert scheduler.dependencies == (
                ('#test_scheduler', 'mod_a'), ('#test_scheduler', 'mod_a#my_proc')
            )
        else:
            assert scheduler.items == ('#test_scheduler', 'mod_a', 'mod_a#my_proc', 'mod_b')
            assert scheduler.dependencies == (
                ('#test_scheduler', 'mod_a'), ('#test_scheduler', 'mod_a#my_proc'),
                ('mod_a', 'mod_b')
            )

        if frontend == OMNI and qualified_import == 'both':
            with pytest.raises(CalledProcessError):
                # OMNI fails to read due to missing mod_b xmods
                scheduler._parse_items()
        else:
            scheduler._parse_items()

        # The interprocedural annotations do _not_ currently enrich the call in this situation
        call = FindNodes(ir.CallStatement).visit(scheduler['#test_scheduler'].ir.ir)[0]
        assert call.name == 'my_proc'
        assert call.name.type.imported
        assert not call.name.type.procedure


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('qualified_import', ['driver', 'module', 'both', 'none'])
def test_scheduler_transient_variable_imports(frontend, qualified_import, tmp_path):
    """
    Test that use of transiently imported variables succeeds.

    ``my_var`` is declared in ``mod_b`` and reaches the driver via ``mod_a``;
    ``qualified_import`` selects which USE statements list it explicitly.
    """
    mod_a_only = ', only: my_var' if qualified_import in ('module', 'both') else ''
    driver_only = ', only: my_var' if qualified_import in ('driver', 'both') else ''

    fcode_mod_a = f"""
module mod_a
    use mod_b{mod_a_only}
    implicit none
end module mod_a
"""

    fcode_mod_b = """
module mod_b
    implicit none
    integer :: my_var
end module mod_b
"""

    fcode_driver = f"""
subroutine test_scheduler()
    use mod_a {driver_only}
    implicit none
    my_var = 1
end subroutine test_scheduler
"""
    src_path = tmp_path/'src'
    src_path.mkdir()
    sources = (
        ('mod_a.F90', fcode_mod_a),
        ('mod_b.F90', fcode_mod_b),
        ('driver.F90', fcode_driver),
    )
    for filename, fcode in sources:
        (src_path/filename).write_text(fcode)

    # Create the Scheduler
    config = SchedulerConfig.from_dict({
        'default': {'role': 'kernel', 'expand': True, 'strict': True, 'enable_imports': True},
        'routines': {'test_scheduler': {'role': 'driver'}},
    })
    scheduler = Scheduler(
        paths=[src_path], config=config, seed_routines='test_scheduler',
        frontend=frontend, xmods=[tmp_path], full_parse=False
    )

    # Global variable imports establish dependencies directly on the import
    # statements, so the transient dependency on mod_b is captured
    assert scheduler.items == ('#test_scheduler', 'mod_a', 'mod_b')
    assert scheduler.dependencies == (('#test_scheduler', 'mod_a'), ('mod_a', 'mod_b'))

    scheduler._parse_items()

    # Type information is propagated correctly through multiple import layers
    mod_a_ir = scheduler['mod_a'].ir
    my_var = scheduler['#test_scheduler'].ir.body.body[0].lhs
    assert isinstance(my_var, Scalar)
    assert my_var.type.dtype == BasicType.INTEGER
    assert my_var.type.imported
    # The type links to the module the symbol is imported from, not where it is defined
    assert my_var.type.module is mod_a_ir

    if qualified_import not in ('module', 'both'):
        # mod_a only has a symbol entry for my_var if it is listed on the import
        return

    mod_a_var = mod_a_ir.imported_symbol_map['my_var']
    assert isinstance(mod_a_var, Scalar)
    assert mod_a_var.type.dtype == BasicType.INTEGER
    assert mod_a_var.type.imported
    assert mod_a_var.type.module is scheduler['mod_b'].ir


def test_scheduler_module_interface_import(frontend, tmp_path):
    """
    Test module-level imports of interface routines.

    ``mod_a`` imports both a derived type and a generic interface from
    ``mod_b``; the dependency graph should contain only the driver and the
    two derived types, not the interface.
    """

    fcode_mod_a = """
module mod_a
    use mod_b, only: inner_type, intfb_routine
    implicit none
    type outer_type
        type(inner_type) :: da
    end type outer_type
end module mod_a
"""

    fcode_mod_b = """
module mod_b
    implicit none
    type inner_type
        integer :: da, db
    end type inner_type

    interface intfb_routine
    module procedure routine_a, routine_b
    end interface intfb_routine

    contains
    subroutine routine_a
    end subroutine routine_a

    subroutine routine_b
    end subroutine routine_b
end module mod_b
"""

    fcode_driver = """
subroutine test_scheduler()
    use mod_a, only: outer_type
    implicit none
    type(outer_type) :: da

    da%da%da = 42.0
end subroutine test_scheduler
"""
    src_path = tmp_path/'src'
    src_path.mkdir()
    for filename, fcode in (('mod_a.F90', fcode_mod_a), ('mod_b.F90', fcode_mod_b), ('driver.F90', fcode_driver)):
        (src_path/filename).write_text(fcode)

    # Create the Scheduler
    config = SchedulerConfig.from_dict({
        'default': {'role': 'kernel', 'expand': True, 'strict': True, 'enable_imports': True},
        'routines': {'test_scheduler': {'role': 'driver'}},
    })
    scheduler = Scheduler(
        paths=[src_path], config=config, seed_routines='test_scheduler',
        frontend=frontend, xmods=[tmp_path], full_parse=False
    )

    expected_items = ('#test_scheduler', 'mod_a#outer_type', 'mod_b#inner_type')
    expected_dependencies = (
        ('#test_scheduler', 'mod_a#outer_type'),
        ('mod_a#outer_type', 'mod_b#inner_type'),
    )
    assert scheduler.items == expected_items
    assert scheduler.dependencies == (
        expected_dependencies
    )
loki-ecmwf-0.3.6/loki/batch/transformation.py0000664000175000017500000005646415167130205021417 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Base class definition for :ref:`transformations`.
"""
from pprint import pformat

from loki.module import Module
from loki.sourcefile import Sourcefile
from loki.subroutine import Subroutine
from loki.batch.item import ProcedureItem, ModuleItem


__all__ = ['Transformation', 'TransformationError']


class TransformationError(Exception):
    """
    Exception raised when a :any:`Transformation` encounters an
    error while processing an IR node

    Parameters
    ----------
    message : str
        Description of the error
    transformation : subclass of :any:`Transformation`
        The class of the transformation in which the error occurred. An
        instance is also accepted, in which case its class name is reported.
    source : :any:`Sourcefile` or :any:`ProgramUnit`
        The object that was processed when the error occurred
    """

    def __init__(self, message, transformation, source):
        # Forward all constructor arguments to ``Exception`` so that ``args`` is
        # populated: this gives a meaningful ``repr`` and allows the exception
        # to be pickled and unpickled (e.g., when sent between processes),
        # which otherwise fails because ``__init__`` requires three arguments
        super().__init__(message, transformation, source)
        self.message = message
        self.transformation = transformation
        self.source = source

    def __str__(self):
        # Report the class name even if a transformation instance was given
        trafo_cls = self.transformation
        if not isinstance(trafo_cls, type):
            trafo_cls = type(trafo_cls)
        return f"Applying {trafo_cls.__name__} to {self.source} failed: {self.message}"


class Transformation:
    """
    Base class for source code transformations that manipulate source
    items like :any:`Subroutine` or :any:`Module` in place via the
    ``item.apply(transform)`` method.

    The source transformations to be applied should be defined in the
    following class-specific methods:

    * :meth:`transform_subroutine`
    * :meth:`transform_module`
    * :meth:`transform_file`

    The generic dispatch mechanism behind the :meth:`apply` method will ensure
    that all hierarchies of the data model are traversed and the specific
    method for each level is applied, if the relevant recursion mode is enabled
    in the transformation's manifest (:attr:`recurse_to_modules` and/or
    :attr:`recurse_to_procedures`). Note that in :any:`Sourcefile` objects,
    :any:`Module` members will be traversed before standalone :any:`Subroutine` objects.

    Classes inheriting from :any:`Transformation` may configure the
    invocation and behaviour during batch processing via a predefined
    set of class attributes. These flags determine the underlying
    graph traversal when processing complex call trees and determine
    how the transformations are invoked for a given type of :any:`Item` in the
    :any:`Scheduler`.

    Attributes
    ----------
    reverse_traversal : bool
        Forces scheduler traversal in reverse order from the leaf
        nodes upwards (default: ``False``).
    traverse_file_graph : bool
        Apply :any:`Transformation` to the :any:`Sourcefile` object
        corresponding to the :any:`Item` being processed, instead of
        the program unit in question (default: ``False``).
    item_filter : type or tuple of types
        Filter by graph node types (subclasses of :any:`Item`) to prune the graph
        and change connectivity. By default (:any:`ProcedureItem`), only procedure
        items are used to construct the graph.
    recurse_to_modules : bool
        Apply transformation to all :any:`Module` objects when processing
        a :any:`Sourcefile` (default ``False``)
    recurse_to_procedures : bool
        Apply transformation to all :any:`Subroutine` objects when processing
        :any:`Sourcefile` or :any:`Module` objects (default ``False``)
    recurse_to_internal_procedures : bool
        Apply transformation to all internal :any:`Subroutine` objects
        when processing :any:`Subroutine` objects (default ``False``)
    process_ignored_items : bool
        Apply transformation to "ignored" :any:`Item` objects for analysis.
        This might be needed if IPO-information needs to be passed across
        library boundaries.
    renames_items : bool
        Indicates to the :any:`Scheduler` that a transformation may change the name of
        the IR node corresponding to the processed :any:`Item` (e.g., by renaming
        a module or subroutine). The transformation has to take care of renaming
        the processed :any:`Item` itself, but the :any:`Scheduler` will update its
        internal cache after the transformation has been applied (default ``False``).
    creates_items : bool
        Indicates to the :any:`Scheduler` that a transformation may create new
        scopes or other dependency nodes (e.g., by adding new routines to a
        module). The scheduler will run a discovery step after the transformation has
        been applied to include these new items in the dependency graph (default ``False``).
    """

    # Forces scheduler traversal in reverse order from the leaf nodes upwards
    reverse_traversal = False

    # Traverse a graph of Sourcefile objects corresponding to scheduler items
    traverse_file_graph = False

    # Filter certain graph nodes to prune the graph and change connectivity
    item_filter = ProcedureItem  # This can also be a tuple of types

    # Recursion behaviour when invoking transformations via ``trafo.apply()``
    recurse_to_modules = False  # Recurse from Sourcefile to Module
    recurse_to_procedures = False  # Recurse from Sourcefile/Module to subroutines and functions
    recurse_to_internal_procedures = False  # Recurse to subroutines in ``contains`` clause

    # Option to process "ignored" items for analysis
    process_ignored_items = False

    # Control Scheduler cache update requirements after applying the transformation
    renames_items = False
    creates_items = False

    def __str__(self):
        """
        Render the transformation's class name and defining module, followed
        by its pretty-printed instance attributes, one indented line each.
        """
        cls = type(self)
        sep = '\n    '
        body = sep.join(pformat(vars(self)).splitlines())
        return f'<{cls.__name__}  [{cls.__module__}]{sep}{body}>'

    def transform_subroutine(self, routine, **kwargs):
        """
        Hook for transforming a single :any:`Subroutine`.

        The base implementation does nothing. Subclasses that modify
        :any:`Subroutine` objects override this method; it is invoked through
        the :meth:`apply` dispatch mechanism.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to transform.
        **kwargs : optional
            Keyword arguments passed on to the transformation.
        """

    def plan_subroutine(self, routine, **kwargs):
        """
        Hook for the planning step of a single :any:`Subroutine`.

        The base implementation does nothing. Subclasses whose transformation
        alters the dependencies of :data:`routine` (e.g., adding new procedure
        calls, inlining calls, renaming the interface) override this method.
        It is invoked through the :meth:`apply` dispatch mechanism when the
        optional ``plan_mode`` argument is ``True``.

        Parameters
        ----------
        routine : :any:`Subroutine`
            The subroutine to plan the transformation for.
        **kwargs : optional
            Keyword arguments passed on to the transformation.
        """

    def transform_module(self, module, **kwargs):
        """
        Hook for transforming a single :any:`Module`.

        The base implementation does nothing. Subclasses that modify
        :any:`Module` objects override this method; it is invoked through
        the :meth:`apply` dispatch mechanism.

        Parameters
        ----------
        module : :any:`Module`
            The module to transform.
        **kwargs : optional
            Keyword arguments passed on to the transformation.
        """

    def plan_module(self, module, **kwargs):
        """
        Hook for the planning step of a single :any:`Module`.

        The base implementation does nothing. Subclasses whose transformation
        alters the dependencies or definitions of :data:`module` (e.g., renaming
        the module, adding new subroutines, adding or removing imports) override
        this method. It is invoked through the :meth:`apply` dispatch mechanism
        when the optional ``plan_mode`` argument is ``True``.

        Parameters
        ----------
        module : :any:`Module`
            The module to plan the transformation for.
        **kwargs : optional
            Keyword arguments passed on to the transformation.
        """

    def transform_file(self, sourcefile, **kwargs):
        """
        Transformation hook for :any:`Sourcefile` items; does nothing by default.

        Subclasses that need to modify :any:`Sourcefile` objects override
        this method. It is reached through the dispatch method :meth:`apply`
        whenever planning mode is not active.

        Parameters
        ----------
        sourcefile : :any:`Sourcefile`
            The sourcefile being transformed.
        **kwargs : optional
            Keyword arguments for the transformation.
        """

    def plan_file(self, sourcefile, **kwargs):
        """
        Planning hook for :any:`Sourcefile` items; does nothing by default.

        Override this in transformations that alter the definitions or
        dependencies of :data:`sourcefile`. The dispatch method :meth:`apply`
        routes here instead of :meth:`transform_file` when it is invoked with
        ``plan_mode=True``.

        Parameters
        ----------
        sourcefile : :any:`Sourcefile`
            The sourcefile being planned.
        **kwargs : optional
            Keyword arguments for the transformation.
        """

    def apply(self, source, post_apply_rescope_symbols=False, plan_mode=False, **kwargs):
        """
        Dispatch method to apply transformation to :data:`source`.

        It dispatches to one of the type-specific dispatch methods
        :meth:`apply_file`, :meth:`apply_module`, or :meth:`apply_subroutine`.
        Sources of any other type are not transformed.

        After the transformation (but not in planning mode), :meth:`post_apply`
        is invoked on :data:`source`.

        Parameters
        ----------
        source : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
            The source item to transform.
        post_apply_rescope_symbols : bool, optional
            Call ``rescope_symbols`` on :data:`source` after applying the
            transformation to clean up any scoping issues.
        plan_mode : bool, optional
            If enabled, the ``plan_*`` hooks are invoked instead of the
            ``transform_*`` hooks, and :meth:`post_apply` is skipped.
        **kwargs : optional
            Keyword arguments that are passed on to the methods defining the
            actual transformation.
        """
        if isinstance(source, Sourcefile):
            self.apply_file(source, plan_mode=plan_mode, **kwargs)

        if isinstance(source, Subroutine):
            self.apply_subroutine(source, plan_mode=plan_mode, **kwargs)

        if isinstance(source, Module):
            self.apply_module(source, plan_mode=plan_mode, **kwargs)

        # Planning only records intended changes, so there is nothing to clean up
        if not plan_mode:
            self.post_apply(source, rescope_symbols=post_apply_rescope_symbols)

    def apply_file(self, sourcefile, plan_mode=False, **kwargs):
        """
        Apply transformation to all items in :data:`sourcefile`.

        This calls :meth:`transform_file` or, if :data:`plan_mode` is enabled,
        calls :meth:`plan_file`.

        If the :attr:`recurse_to_modules` class property is set, it
        will also invoke :meth:`transform_module` (or :meth:`plan_module`)
        on the :any:`Module` objects in this :any:`Sourcefile`. Likewise, if
        :attr:`recurse_to_procedures` is set, it will invoke
        :meth:`transform_subroutine` (or :meth:`plan_subroutine`) on the
        :any:`Subroutine` objects in this :any:`Sourcefile`.

        When a non-empty ``items`` keyword argument is given, recursion is
        restricted to the :any:`ModuleItem` and :any:`ProcedureItem` entries
        in it, and each call receives that item's own ``role`` and ``targets``.
        Otherwise, all modules in ``sourcefile.modules`` and all routines in
        ``sourcefile.all_subroutines`` are visited with the file-level
        ``item``/``role``/``targets`` values.

        Parameters
        ----------
        sourcefile : :any:`Sourcefile`
            The file to transform.
        plan_mode : bool, optional
            If enabled, apply planning mode.
        **kwargs : optional
            Keyword arguments that are passed on to transformation methods.
            The entries ``item``, ``items``, ``role`` and ``targets`` are
            removed from :data:`kwargs` and forwarded explicitly.

        Raises
        ------
        TypeError
            If :data:`sourcefile` is not a :any:`Sourcefile`.
        TransformationError
            Wrapping any exception raised by the file-, module- or
            procedure-level hooks, including the completeness check.
        """
        if not isinstance(sourcefile, Sourcefile):
            raise TypeError('Transformation.apply_file can only be applied to Sourcefile object')

        # Extract the scheduler-provided context so it can be forwarded explicitly
        # (and, for per-item recursion, replaced by each item's own values)
        item = kwargs.pop('item', None)
        items = kwargs.pop('items', None)
        role = kwargs.pop('role', None)
        targets = kwargs.pop('targets', None)

        # Apply file-level transformations
        try:
            if plan_mode:
                self.plan_file(sourcefile, item=item, role=role, targets=targets, items=items, **kwargs)
            else:
                # Planning may operate on incomplete IR, but actual transformation may not
                if sourcefile._incomplete:
                    raise RuntimeError('Transformation.apply_file requires Sourcefile to be complete')

                self.transform_file(sourcefile, item=item, role=role, targets=targets, items=items, **kwargs)
        except Exception as e:
            raise TransformationError(
                message=f'Error in Sourcefile {sourcefile.path!s} -- {e!s}',
                transformation=type(self), source=sourcefile
            ) from e

        # Recurse to modules, if configured
        if self.recurse_to_modules:
            if items:
                # Recursion into all module items in the current file
                # NOTE(review): this loop rebinds the ``item`` popped from kwargs above.
                # The non-``items`` branches below only run when ``items`` is empty,
                # so they still see the original value.
                for item in items:
                    if isinstance(item, ModuleItem):
                        # Currently, we don't get the role for modules correct as 'driver'
                        # if the role overwrite in the config marks only specific procedures
                        # as driver, but everything else as kernel by default. This is in particular the
                        # case, if the ModuleWrapTransformation is applied to a driver routine.
                        # For that reason, we set the role as unspecified (None) if not the role is
                        # universally equal throughout the module
                        # (this also yields None if no item in ``items`` is scoped in this module)
                        item_role = item.role
                        definitions_roles = {_it.role for _it in items if _it.scope_name == item.name}
                        if definitions_roles != {item_role}:
                            item_role = None

                        # Provide the list of items that belong to this module
                        item_items = tuple(_it for _it in items if _it.scope is item.ir)

                        try:
                            if plan_mode:
                                self.plan_module(
                                    item.ir, item=item, role=item_role, targets=item.targets, items=item_items, **kwargs
                                )
                            else:
                                self.transform_module(
                                    item.ir, item=item, role=item_role, targets=item.targets, items=item_items, **kwargs
                                )
                        except Exception as e:
                            raise TransformationError(
                                message=f'Error in Module {item.ir.name} -- {e!s}',
                                transformation=type(self), source=item.ir
                            ) from e
            else:
                # No item list: visit every module with the file-level context
                for module in sourcefile.modules:
                    try:
                        if plan_mode:
                            self.plan_module(module, item=item, role=role, targets=targets, items=items, **kwargs)
                        else:
                            self.transform_module(module, item=item, role=role, targets=targets, items=items, **kwargs)
                    except Exception as e:
                        raise TransformationError(
                            message=f'Error in Module {module.name} -- {e!s}',
                            transformation=type(self), source=module
                        ) from e

        # Recurse into procedures, if configured
        if self.recurse_to_procedures:
            if items:
                # Recursion into all subroutine items in the current file
                # (``items`` is not forwarded to procedure-level hooks)
                for item in items:
                    if isinstance(item, ProcedureItem):
                        try:
                            if plan_mode:
                                self.plan_subroutine(
                                    item.ir, item=item, role=item.role, targets=item.targets, **kwargs
                                )
                            else:
                                self.transform_subroutine(
                                    item.ir, item=item, role=item.role, targets=item.targets, **kwargs
                                )
                        except Exception as e:
                            raise TransformationError(
                                message=f'Error in Procedure {item.ir.name} -- {e!s}',
                                transformation=type(self), source=item.ir
                            ) from e
            else:
                # No item list: visit every routine known to the file with the file-level context
                for routine in sourcefile.all_subroutines:
                    try:
                        if plan_mode:
                            self.plan_subroutine(routine, item=item, role=role, targets=targets, **kwargs)
                        else:
                            self.transform_subroutine(routine, item=item, role=role, targets=targets, **kwargs)
                    except Exception as e:
                        raise TransformationError(
                            message=f'Error in Procedure {routine.name} -- {e!s}',
                            transformation=type(self), source=routine
                        ) from e

    def apply_subroutine(self, subroutine, plan_mode=False, **kwargs):
        """
        Apply transformation to a given :any:`Subroutine` object and its members.

        This calls :meth:`transform_subroutine` or, if :data:`plan_mode` is enabled,
        calls :meth:`plan_subroutine`.

        If the :attr:`recurse_to_internal_procedures` class property is
        set, it will also invoke :meth:`apply_subroutine` recursively on all
        :any:`Subroutine` objects in the ``contains`` clause of this
        :any:`Subroutine`, forwarding :data:`plan_mode` and all keyword arguments.

        Parameters
        ----------
        subroutine : :any:`Subroutine`
            The subroutine to transform.
        plan_mode : bool, optional
            If enabled, apply planning mode.
        **kwargs : optional
            Keyword arguments that are passed on to transformation methods.

        Raises
        ------
        TypeError
            If :data:`subroutine` is not a :any:`Subroutine`.
        TransformationError
            Wrapping any exception raised by the transformation or planning
            hook, including the check that :data:`subroutine` is complete.
        """
        if not isinstance(subroutine, Subroutine):
            raise TypeError('Transformation.apply_subroutine can only be applied to Subroutine object')

        # Apply the actual transformation for subroutines
        try:
            if plan_mode:
                self.plan_subroutine(subroutine, **kwargs)
            else:
                # Planning may operate on incomplete IR, but actual transformation may not
                if subroutine._incomplete:
                    raise RuntimeError('Transformation.apply_subroutine requires Subroutine to be complete')

                self.transform_subroutine(subroutine, **kwargs)
        except Exception as e:
            raise TransformationError(
                message=f'Error in Procedure {subroutine.name} -- {e!s}',
                transformation=type(self), source=subroutine
            ) from e

        # Recurse to internal procedures
        if self.recurse_to_internal_procedures:
            for routine in subroutine.subroutines:
                self.apply_subroutine(routine, plan_mode=plan_mode, **kwargs)

    def apply_module(self, module, plan_mode=False, **kwargs):
        """
        Apply transformation to a given :any:`Module` object and its members.

        This calls :meth:`transform_module` or, if :data:`plan_mode` is enabled,
        calls :meth:`plan_module`.

        If the :attr:`recurse_to_procedures` class property is set,
        it will also invoke :meth:`apply_subroutine` (not :meth:`apply`, so no
        :meth:`post_apply` step is run per routine) on all :any:`Subroutine`
        objects in the ``contains`` clause of this :any:`Module`, forwarding
        :data:`plan_mode` and all keyword arguments.

        Parameters
        ----------
        module : :any:`Module`
            The module to transform.
        plan_mode : bool, optional
            If enabled, apply planning mode.
        **kwargs : optional
            Keyword arguments that are passed on to transformation methods.

        Raises
        ------
        TypeError
            If :data:`module` is not a :any:`Module`.
        TransformationError
            Wrapping any exception raised by the transformation or planning
            hook, including the check that :data:`module` is complete.
        """
        if not isinstance(module, Module):
            raise TypeError('Transformation.apply_module can only be applied to Module object')

        # Apply the actual transformation for modules
        try:
            if plan_mode:
                self.plan_module(module, **kwargs)
            else:
                # Planning may operate on incomplete IR, but actual transformation may not
                if module._incomplete:
                    raise RuntimeError('Transformation.apply_module requires Module to be complete')

                self.transform_module(module, **kwargs)
        except Exception as e:
            raise TransformationError(
                message=f'Error in Module {module.name} -- {e!s}',
                transformation=type(self), source=module
            ) from e

        # Recurse to procedures contained in this module
        if self.recurse_to_procedures:
            for routine in module.subroutines:
                self.apply_subroutine(routine, plan_mode=plan_mode, **kwargs)

    def post_apply(self, source, rescope_symbols=False):
        """
        Run the post-transformation actions appropriate for :data:`source`.

        The type of :data:`source` selects the handler: :meth:`post_apply_file`,
        :meth:`post_apply_subroutine` or :meth:`post_apply_module`. Sources of
        any other type are left untouched.

        Parameters
        ----------
        source : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
            The source item that was transformed.
        rescope_symbols : bool, optional
            Forwarded to the handler; requests a ``rescope_symbols`` call on :data:`source`
        """
        # Type -> handler table, checked in the same order as before and without
        # short-circuiting, so every matching handler runs
        handlers = (
            (Sourcefile, self.post_apply_file),
            (Subroutine, self.post_apply_subroutine),
            (Module, self.post_apply_module),
        )
        for source_type, handler in handlers:
            if isinstance(source, source_type):
                handler(source, rescope_symbols)

    def post_apply_file(self, sourcefile, rescope_symbols):
        """
        Run post-transformation actions on every program unit of :data:`sourcefile`.

        Modules are handled first via :meth:`post_apply_module`, followed by
        the free-standing routines via :meth:`post_apply_subroutine`.

        Parameters
        ----------
        sourcefile : :any:`Sourcefile`
            The file that was transformed.
        rescope_symbols : bool
            Forwarded unchanged to each module and subroutine handler
        """
        if not isinstance(sourcefile, Sourcefile):
            raise TypeError('Transformation.post_apply_file can only be applied to Sourcefile object')

        for mod in sourcefile.modules:
            self.post_apply_module(mod, rescope_symbols)

        for proc in sourcefile.subroutines:
            self.post_apply_subroutine(proc, rescope_symbols)

    def post_apply_subroutine(self, subroutine, rescope_symbols):
        """
        Apply actions after applying a transformation to :data:`subroutine`.

        Member procedures are post-processed first, always with
        ``rescope_symbols=False``; afterwards :data:`subroutine` itself is
        rescoped if requested.

        Parameters
        ----------
        subroutine : :any:`Subroutine`
            The subroutine that was transformed.
        rescope_symbols : bool
            Call ``rescope_symbols`` on :data:`subroutine`

        Raises
        ------
        TypeError
            If :data:`subroutine` is not a :any:`Subroutine`.
        """
        if not isinstance(subroutine, Subroutine):
            raise TypeError('Transformation.post_apply_subroutine can only be applied to Subroutine object')

        # NOTE(review): members are never rescoped individually here; presumably the
        # parent's rescope below also covers them -- TODO confirm
        for routine in subroutine.members:
            self.post_apply_subroutine(routine, False)

        # Ensure all objects in the IR are in the subroutine's or a parent scope.
        if rescope_symbols:
            subroutine.rescope_symbols()

    def post_apply_module(self, module, rescope_symbols):
        """
        Apply actions after applying a transformation to :data:`module`.

        Contained procedures are post-processed first, always with
        ``rescope_symbols=False``; afterwards :data:`module` itself is
        rescoped if requested.

        Parameters
        ----------
        module : :any:`Module`
            The module that was transformed.
        rescope_symbols : bool
            Call ``rescope_symbols`` on :data:`module`

        Raises
        ------
        TypeError
            If :data:`module` is not a :any:`Module`.
        """
        if not isinstance(module, Module):
            raise TypeError('Transformation.post_apply_module can only be applied to Module object')

        # NOTE(review): contained routines are never rescoped individually here;
        # presumably the module's rescope below also covers them -- TODO confirm
        for routine in module.subroutines:
            self.post_apply_subroutine(routine, False)

        # Ensure all objects in the IR are in the module's scope.
        if rescope_symbols:
            module.rescope_symbols()
loki-ecmwf-0.3.6/loki/batch/configure.py0000664000175000017500000005151615167130205020323 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import fnmatch
from itertools import accumulate
from pathlib import Path
import re

from loki.dimension import Dimension
from loki.tools import as_tuple, CaseInsensitiveDict, load_module
from loki.types import ProcedureType, DerivedType
from loki.logging import error, warning


# Public configuration classes exported by this module
__all__ = [
    'SchedulerConfig',
    'TransformationConfig',
    'PipelineConfig',
    'ItemConfig',
]


class SchedulerConfig:
    """
    Configuration object for the :any:`Scheduler`

    It encapsulates config options for scheduler behaviour, with default options
    and item-specific overrides, as well as transformation-specific parameterisations.

    The :any:`SchedulerConfig` can be created either from a raw dictionary or configuration file.

    Parameters
    ----------
    default : dict
        Default options for each item
    routines : dict of dicts
        Mapping of routine names to dicts with routine-specific options.
    disable : list of str, optional
        Subroutine names that are entirely disabled and will not be
        added to either the callgraph that we traverse, nor the
        visualisation. These are intended for utility routines that
        pop up in many routines but can be ignored in terms of program
        control flow, like ``flush`` or ``abort``.
    dimensions : dict, optional
        Mapping of names to :any:`Dimension` objects, used to resolve
        ``%dimensions.<name>%`` placeholders in transformation options.
    transformation_configs : dict, optional
        Mapping of names to :any:`TransformationConfig` objects
    pipeline_configs : dict, optional
        Mapping of names to :any:`PipelineConfig` objects
    enable_imports : bool, optional
       Enable the inclusion of module imports as scheduler dependencies.
    frontend_args : dict, optional
        Dicts with file-specific frontend options
    """

    def __init__(
            self, default, routines, disable=None, dimensions=None,
            transformation_configs=None, pipeline_configs=None,
            enable_imports=False, frontend_args=None
    ):
        self.default = default
        self.disable = as_tuple(disable)
        self.dimensions = dimensions
        self.enable_imports = enable_imports

        self.routines = CaseInsensitiveDict(routines)
        # Normalise the optional config maps so that their ``None`` defaults
        # do not break the iteration below
        self.transformation_configs = transformation_configs or {}
        self.pipeline_configs = pipeline_configs or {}
        self.frontend_args = frontend_args

        # Resolve the dimensions for trafo configurations
        for cfg in self.transformation_configs.values():
            cfg.resolve_dimensions(dimensions or {})

        # Instantiate Transformation objects
        self.transformations = {
            name: config.instantiate() for name, config in self.transformation_configs.items()
        }

        # Instantiate Pipeline objects (these may reference the transformations above)
        self.pipelines = {
            name: config.instantiate(transformation_map=self.transformations)
            for name, config in self.pipeline_configs.items()
        }

    @classmethod
    def from_dict(cls, config):
        """
        Create a :any:`SchedulerConfig` from a raw configuration dict.

        The keys ``default``, ``routines``, ``dimensions``, ``transformations``,
        ``pipelines`` and ``frontend_args`` are read from :data:`config`; the
        ``disable`` and ``enable_imports`` options are taken from ``default``.
        """
        default = config.get('default', {})
        routines = config.get('routines', [])
        disable = default.get('disable', None)
        enable_imports = default.get('enable_imports', False)

        # Add any dimension definitions contained in the config dict
        dimensions = config.get('dimensions', {})
        dimensions = {k: Dimension(**d) for k, d in dimensions.items()}

        # Create config objects for Transformation configurations
        transformation_configs = config.get('transformations', {})
        transformation_configs = {
            name: TransformationConfig(name=name, **cfg)
            for name, cfg in transformation_configs.items()
        }
        frontend_args = config.get('frontend_args', {})

        pipeline_configs = config.get('pipelines', {})
        pipeline_configs = {
            name: PipelineConfig(name=name, **cfg)
            for name, cfg in pipeline_configs.items()
        }

        return cls(
            default=default, routines=routines, disable=disable, dimensions=dimensions,
            transformation_configs=transformation_configs, pipeline_configs=pipeline_configs,
            frontend_args=frontend_args, enable_imports=enable_imports
        )

    @classmethod
    def from_file(cls, path):
        """
        Create a :any:`SchedulerConfig` from a TOML configuration file at :data:`path`.
        """
        try:
            import tomllib as toml  # pylint: disable=import-outside-toplevel
        except ModuleNotFoundError:
            import tomli as toml  # pylint: disable=import-outside-toplevel

        # Load configuration file and process options
        with Path(path).open('rb') as f:
            config = toml.load(f)

        return cls.from_dict(config)

    @staticmethod
    def match_item_keys(item_name, keys, use_pattern_matching=False, match_item_parents=False):
        """
        Helper routine to match an item name against config keys.

        The :data:`item_name` may be a fully-qualified name of an :any:`Item`, which may
        include a scope, or only a partial, e.g., local name part. This is then compared
        against a provided list of keys as they may appear in a config property (for
        example an ``ignore`` or ``disable`` list).

        By default, the fully qualified name and the local name are matched.
        Optionally, the matching can be extended to parent scopes in the item name,
        which is particularly useful if, e.g., the item name of a module member is checked
        against an exclusion list, which lists the module name. This is enabled via
        :data:`match_item_parents`.

        The matching can take patterns in the :data:`keys` into account, allowing for the
        pattern syntax supported by :any:`fnmatch`.
        This requires enabling :data:`use_pattern_matching`.

        Parameters
        ----------
        item_name : str
            The item name to check for matches
        keys : list of str
            The config key values to check for matches
        use_pattern_matching : bool, optional
            Allow patterns in :data:`keys` when matching (default ``False``)
        match_item_parents : bool, optional
            Match also name parts of parent scopes in :data:`item_name`

        Returns
        -------
        tuple of str
            The entries in :data:`keys` (lower-cased) that :data:`item_name` matched

        Raises
        ------
        ValueError
            If :data:`item_name` contains more than two ``#`` separators.
        """
        # Sanitize the item name
        item_name = item_name.lower()
        name_parts = item_name.split('#')
        if len(name_parts) == 1:
            scope_name, local_name = '', name_parts[0]
        elif len(name_parts) == 2:
            scope_name, local_name = name_parts
        elif len(name_parts) == 3:
            scope_name, local_name = name_parts[0], f'{name_parts[1]}#{name_parts[2]}'
        else:
            raise ValueError(f'Invalid item name {item_name}: Only one or two `#` are allowed in the name.')

        # Build the variations of item name to match
        item_names = {item_name, local_name}
        if match_item_parents:
            if scope_name:
                item_names.add(scope_name)
            if '%' in local_name:
                # Add every prefix of a derived-type member path, e.g. ``t``, ``t%a``, ``t%a%b``,
                # both with and without the scope prefix
                type_name, *member_names = local_name.split('%')
                item_names |= {
                    name
                    for partial_name in accumulate(member_names, lambda l, r: f'{l}%{r}', initial=type_name)
                    for name in (f'{scope_name}#{partial_name}', partial_name)
                }

        # Match against keys
        keys = tuple(key.lower() for key in as_tuple(keys))
        if use_pattern_matching:
            return tuple(key for key in keys if fnmatch.filter(item_names, key))
        return tuple(key for key in keys if key in item_names)

    def create_item_config(self, name):
        """
        Create the bespoke config `dict` for an :any:`Item`

        The resulting config object contains the :attr:`default`
        values and any item-specific overwrites and additions.

        Raises
        ------
        RuntimeError
            If :data:`name` matches multiple ``routines`` entries and the
            default config enables ``strict`` (otherwise only a warning is issued).
        """
        keys = self.match_item_keys(name, self.routines)
        if len(keys) > 1:
            if self.default.get('strict'):
                raise RuntimeError(f'{name} matches multiple config entries: {", ".join(keys)}')
            warning(f'{name} matches multiple config entries: {", ".join(keys)}')
        item_conf = self.default.copy()
        for key in keys:
            item_conf.update(self.routines[key])
        return item_conf

    def create_frontend_args(self, path, default_args):
        """
        Create bespoke ``frontend_args`` to pass to the constructor
        or ``make_complete`` method for a file

        The resulting `dict` contains overwrites that have been provided
        in the :attr:`frontend_args` of the config.

        The keys of :attr:`frontend_args` are matched case-insensitively with
        :any:`fnmatch` against :data:`path`: keys starting with ``/`` must match
        the full path, all other keys are matched as a path suffix. Only the
        first matching entry is applied.

        Parameters
        ----------
        path : str or pathlib.Path
            The file path for which to create the frontend arguments.
        default_args : dict
            The default options to use. Only keys that are explicitly overridden
            for the file in the scheduler config are updated.

        Returns
        -------
        dict
            The frontend arguments, with file-specific overrides of
            :data:`default_args` if specified in the Scheduler config.
        """
        path = str(path).lower()
        frontend_args = default_args.copy()
        for key, args in (self.frontend_args or {}).items():
            # ``startswith`` rather than ``key[0]`` so that an empty key does not raise
            pattern = key.lower() if key.startswith('/') else f'*{key}'.lower()
            if fnmatch.fnmatch(path, pattern):
                frontend_args.update(args)
                break
        return frontend_args

    def is_disabled(self, name):
        """
        Check if the item with the given :data:`name` is marked as `disabled`
        """
        return bool(self.match_item_keys(name, self.disable, use_pattern_matching=True, match_item_parents=True))


class TransformationConfig:
    """
    Configuration object for :any:`Transformation` instances that can
    be used to create :any:`Transformation` objects from dictionaries
    or a config file.

    Parameters
    ----------
    name : str
        Name of the transformation object
    module : str
        Python module from which to load the transformation class
    classname : str, optional
        Name of the class to look for when instantiating the transformation.
        If not provided, ``name`` will be used instead.
    path : str or Path, optional
        Path to add to the sys.path before attempting to load the ``module``
    options : dict, optional
        Dicts of options that define the transformation behaviour.
        These options will be passed as constructor arguments using
        keyword-argument notation.
    """

    # Matches ``%dimensions.<name>%`` placeholders and captures ``<name>``
    _re_dimension = re.compile(r'\%dimensions\.(.*?)\%')

    def __init__(self, name, module, classname=None, path=None, options=None):
        self.name = name
        self.module = module
        self.classname = classname or self.name
        self.path = path
        self.options = dict(options) if options else {}

    def resolve_dimensions(self, dimensions):
        """
        Substitute :any:`Dimension` objects for placeholder strings.

        The format of the string replacement matches the TOML
        configuration. It will attempt to replace ``%dimensions.dim_name%``
        with a :any:`Dimension` found in :data:`dimensions`. An option value
        containing a single placeholder is replaced by that object; one
        containing several placeholders is replaced by a tuple of them.

        Parameters
        ----------
        dimensions : dict or None
            Dict matching string to pre-configured :any:`Dimension` objects.

        Raises
        ------
        KeyError
            If a placeholder refers to a name that is not in :data:`dimensions`.
        """
        dimensions = dimensions or {}
        for key, val in self.options.items():
            if not isinstance(val, str):
                continue

            names = self._re_dimension.findall(val)
            missing = [name for name in names if name not in dimensions]
            if missing:
                raise KeyError(
                    f'[Loki::TransformationConfig] Unknown dimension(s) {", ".join(missing)} '
                    f'referenced in option "{key}" of {self.name}'
                )

            matches = tuple(dimensions[name] for name in names)
            if matches:
                self.options[key] = matches[0] if len(matches) == 1 else matches

    def instantiate(self):
        """
        Creates instantiated :any:`Transformation` object from stored config options.

        Raises
        ------
        RuntimeError
            If :attr:`classname` is not found in the loaded module.
        TypeError
            If the stored options do not match the class constructor; the
            offending options are logged before re-raising.
        """
        # Load the module that contains the transformations
        mod = load_module(self.module, path=self.path)

        # Look up the Transformation class in the loaded module
        if not hasattr(mod, self.classname):
            raise RuntimeError(f'Failed to load Transformation class: {self.classname}')

        # Attempt to instantiate transformation from config
        try:
            transformation = getattr(mod, self.classname)(**self.options)
        except TypeError:
            error(f'[Loki::Transformation] Failed to instantiate {self.classname} from configuration')
            error(f'    Options passed: {self.options}')
            raise

        return transformation


class PipelineConfig:
    """
    Configuration object for custom :any:`Pipeline` instances that can
    be used to create pipelines from other transformations stored in
    the config.

    Parameters
    ----------
    name : str
        Name of the pipeline object
    transformations : list of str, optional
        Names of the transformations (or pipelines) to look up when
        instantiating this pipeline, in the order in which they are
        appended to it.
    """

    def __init__(self, name, transformations=None):
        self.name = name
        self.transformations = transformations or []

    def instantiate(self, transformation_map=None):
        """
        Creates a custom :any:`Pipeline` object from instantiated
        :any:`Transformation` or :any:`Pipeline` objects in the given map.

        Parameters
        ----------
        transformation_map : dict, optional
            Mapping of names to already instantiated :any:`Transformation`
            or :any:`Pipeline` objects.

        Raises
        ------
        RuntimeError
            If a name in :attr:`transformations` is not in :data:`transformation_map`.
        """
        from loki.batch.pipeline import Pipeline  # pylint: disable=import-outside-toplevel,cyclic-import

        # Treat a missing map as empty so unknown names produce the intended error
        if transformation_map is None:
            transformation_map = {}

        # Create an empty pipeline and add from the map
        pipeline = Pipeline(classes=())
        for name in self.transformations:
            if name not in transformation_map:
                error(f'[Loki::Pipeline] Failed to find {name} in transformation config!')
                raise RuntimeError(f'[Loki::Pipeline] Transformation {name} not found!')

            # Use native notation to append transformation/pipeline,
            # so that we may use them interchangeably in config
            pipeline += transformation_map[name]

        return pipeline


class ItemConfig:
    """
    :any:`Item`-specific configuration settings.

    This is filled by inheriting values from :any:`SchedulerConfig.default`
    and applying explicit specialisations provided for an item in the config
    file or dictionary.

    Attributes
    ----------
    role : str or None
        Role in the transformation chain, typically ``'driver'`` or ``'kernel'``
    mode : str or None
        Transformation "mode" to pass to transformations applied to the item
    expand : bool (default: False)
        Flag to enable/disable expansion of children under this node
    lib : str or None
        Compile unit/library this item belongs to
    strict : bool (default: True)
        Flag controlling whether to fail if dependency items cannot be found
    replicate : bool (default: False)
        Flag indicating whether to mark item as "replicated" in call graphs
    disable : tuple
        List of dependency names that are completely ignored and not reported as
        dependencies by the item. Useful to exclude entire call trees or utility
        routines.
    block : tuple
        List of dependency names that should not be added to the scheduler graph
        as dependencies and are not processed as targets. Note that these might still
        be shown in the graph visualisation.
    ignore : tuple
        List of dependency names that should not be added to the scheduler graph
        as dependencies (and are therefore not processed by transformations)
        but are treated in the current item as targets. This facilitates processing
        across build targets, where, e.g., caller and callee-side are transformed in
        separate Loki passes.
    enrich : tuple
        List of program units that should still be looked up and used to "enrich"
        IR nodes (e.g., :any:`ProcedureSymbol` in :any:`CallStatement`) in this item
        for inter-procedural transformation passes.
    is_ignored : bool (default: False)
        Flag controlling whether the item itself is ignored during processing
    ignore_internal_procedures : bool (default: True)
        Flag controlling the inclusion of internal procedures as dependencies

    Parameters
    ----------
    config : dict or None
        The config values for the :any:`Item`. Typically generated by
        :any:`SchedulerConfig.create_item_config`. ``None`` is treated as an
        empty configuration.
    """

    def __init__(self, config):
        self.config = config or {}
        super().__init__()

    @property
    def role(self):
        """
        Role in the transformation chain, for example ``'driver'`` or ``'kernel'``
        """
        return self.config.get('role')

    @property
    def mode(self):
        """
        Transformation "mode" to pass to the transformation
        """
        return self.config.get('mode')

    @property
    def expand(self):
        """
        Flag to trigger expansion of children under this node
        """
        return self.config.get('expand', False)

    @property
    def lib(self):
        """
        Compile unit/library this item belongs to
        """
        return self.config.get('lib')

    @property
    def strict(self):
        """
        Flag controlling whether to fail if dependency items cannot be found
        """
        return self.config.get('strict', True)

    @property
    def replicate(self):
        """
        Flag indicating whether to mark item as "replicated" in call graphs
        """
        return self.config.get('replicate', False)

    @property
    def disable(self):
        """
        List of sources to completely exclude from expansion and the source tree.
        """
        return self.config.get('disable', ())

    @property
    def block(self):
        """
        List of sources to block from processing, but add to the
        source tree for visualisation.
        """
        return self.config.get('block', ())

    @property
    def ignore(self):
        """
        List of sources to expand but ignore during processing
        """
        return self.config.get('ignore', ())

    @property
    def enrich(self):
        """
        List of sources to use for IPA enrichment
        """
        return self.config.get('enrich', ())

    @property
    def is_ignored(self):
        """
        Flag controlling whether the item is ignored during processing
        """
        return self.config.get('is_ignored', False)

    @property
    def ignore_internal_procedures(self):
        """
        Flag controlling the inclusion of internal procedures as dependencies
        """
        return self.config.get('ignore_internal_procedures', True)

    @classmethod
    def match_symbol_or_name(cls, symbol_or_name, keys, scope=None):
        """
        Match a :any:`TypedSymbol`, :any:`MetaSymbol` or name against
        a list of config values given as :data:`keys`

        This checks whether :data:`symbol_or_name` matches any of the given entries,
        which would typically be something like the :attr:`disable`, :attr:`ignore`, or
        :attr:`block` config entries.

        Optionally, :data:`scope` provides the name of the scope in which
        :data:`symbol_or_name` is defined.
        For derived type members, this takes care of resolving to the type name
        and matching that. This will also match successfully, if only parent components
        match, e.g., the scope name or the type name of the symbol.
        The use of simple patterns is allowed, see :any:`SchedulerConfig.match_item_keys`
        for more information.

        Parameters
        ----------
        symbol_or_name : :any:`TypedSymbol` or :any:`MetaSymbol` or str
            The symbol or name to match
        keys : list of str
            The list of candidate names to match against. This can be fully qualified
            names (e.g., ``'my_scope#my_routine'``), plain scope or routine names
            (e.g., ``'my_scope'`` or ``'my_routine'``), or use simple patterns (e.g., ``'my_*'``).
        scope : str, optional
            The name of the scope, in which :data:`symbol_or_name` is defined, if available.
            Providing this allows to match a larger range of name combinations

        Returns
        -------
        bool
            ``True`` if matched successfully, otherwise ``False``
        """
        if isinstance(symbol_or_name, str):
            # Base case: qualify the name with the lower-cased scope (if given) and
            # delegate to the config's key matching with patterns and parent matching
            scope_prefix = f'{scope!s}#'.lower() if scope is not None else ''
            return len(SchedulerConfig.match_item_keys(
                f'{scope_prefix}{symbol_or_name}', keys, use_pattern_matching=True, match_item_parents=True
            )) > 0

        if parents := getattr(symbol_or_name, 'parents', None):
            # Derived type member: replace the outermost parent by its type name,
            # e.g. ``var%member%proc`` -> ``my_type%member%proc``
            type_name = parents[0].type.dtype.name
            parents = [parent.basename for parent in parents[1:]]
            return cls.match_symbol_or_name(
                '%'.join([type_name, *parents, symbol_or_name.basename]), keys, scope=scope
            )

        if type_ := getattr(symbol_or_name, 'type', None):
            # Procedure or derived type symbols are matched by their type's name
            if isinstance(type_.dtype, (ProcedureType, DerivedType)):
                return cls.match_symbol_or_name(type_.dtype.name, keys, scope=scope)

        return cls.match_symbol_or_name(str(symbol_or_name), keys, scope=scope)
loki-ecmwf-0.3.6/loki/batch/item_factory.py0000664000175000017500000007517315167130205021034 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.batch.configure import SchedulerConfig
from loki.batch.item import (
    get_all_import_map, ExternalItem, FileItem, InterfaceItem, ModuleItem,
    ProcedureBindingItem, ProcedureItem, TypeDefItem
)
from loki.expression import ProcedureSymbol
from loki.ir import nodes as ir
from loki.logging import warning
from loki.module import Module
from loki.subroutine import Subroutine
from loki.sourcefile import Sourcefile
from loki.tools import CaseInsensitiveDict, as_tuple
from loki.types import BasicType


__all__ = ['ItemFactory']


class ItemFactory:
    """
    Utility class to instantiate instances of :any:`Item`

    It maintains a :attr:`item_cache` for all created items. Most
    important factory method is :meth:`create_from_ir` to create (or
    return from the cache) a :any:`Item` object corresponding to an
    IR node. Other factory methods exist for more bespoke use cases.

    Attributes
    ----------
    item_cache : :any:`CaseInsensitiveDict`
        This maps item names to corresponding :any:`Item` objects
    """

    def __init__(self):
        """
        Start with an empty item cache, keyed case-insensitively by item name
        """
        cache = CaseInsensitiveDict()
        self.item_cache = cache

    def __contains__(self, key):
        """
        Check if an item under the given name exists in the :attr:`item_cache`
        """
        return key in self.item_cache

    def create_from_ir(self, node, scope_ir, config=None, ignore=None):
        """
        Helper method to create items for definitions or dependency

        This is a helper method to determine the fully-qualified item names
        and item type for a given IR :any:`Node`, e.g., when creating the items
        for definitions (see :any:`Item.create_definition_items`) or dependencies
        (see :any:`Item.create_dependency_items`).

        This routine's responsibility is to determine the item name, and then call
        :meth:`get_or_create_item` to look-up an existing item or create it.

        Parameters
        ----------
        node : :any:`Node` or :any:`pymbolic.primitives.Expression`
            The Loki IR node for which to create a corresponding :any:`Item`
        scope_ir : :any:`Scope`
            The scope node in which the IR node is declared or used. Note that this
            is not necessarily the same as the scope of the created :any:`Item` but
            serves as the entry point for the lookup mechanism that underpins the
            creation procedure.
        config : :any:`SchedulerConfig`, optional
            The config object from which a bespoke item configuration will be derived.
        ignore : list of str, optional
            A list of item names that should be ignored, i.e., not be created as an item.

        Returns
        -------
        tuple of :any:`Item` or None
            The created (or cached) items, or ``None`` if the node is ignored or
            does not give rise to an item (intrinsic module imports, CPP includes,
            or imports whose symbols all map to filtered-out items). For procedure
            calls, symbols and declarations, individual tuple entries may be
            ``None`` when the corresponding item is ignored or cannot be resolved.

        Raises
        ------
        ValueError
            If :data:`node` is of an unsupported type
        """
        if isinstance(node, Module):
            # Modules are named by their (lower-case) name only, without '#'
            item_name = node.name.lower()
            if self._is_ignored(item_name, config, ignore):
                return None
            return as_tuple(self.get_or_create_item(ModuleItem, item_name, item_name, config))

        if isinstance(node, Subroutine):
            # Procedures are named '<parent name>#<name>'; without a named parent the
            # scope part is empty, giving '#<name>'
            scope_name = getattr(node.parent, 'name', '').lower()
            item_name = f'{scope_name}#{node.name}'.lower()
            if self._is_ignored(item_name, config, ignore):
                return None
            return as_tuple(
                self.get_or_create_item(ProcedureItem, item_name, scope_name, config)
            )

        if isinstance(node, ir.TypeDef):
            # A typedef always lives in a Module
            scope_name = node.parent.name.lower()
            item_name = f'{scope_name}#{node.name}'.lower()
            if self._is_ignored(item_name, config, ignore):
                return None
            return as_tuple(self.get_or_create_item(TypeDefItem, item_name, scope_name, config))

        if isinstance(node, ir.Import):
            # Skip intrinsic modules
            if node.nature == 'intrinsic':
                return None

            # Skip CPP includes
            if node.c_import:
                return None

            # If we have a fully-qualified import (which we hopefully have),
            # we create a dependency for every imported symbol, otherwise we
            # depend only on the imported module
            scope_name = node.module.lower()
            if self._is_ignored(scope_name, config, ignore):
                return None
            if scope_name not in self.item_cache:
                # This will instantiate an ExternalItem
                return as_tuple(self.get_or_create_item(ModuleItem, scope_name, scope_name, config))

            scope_item = self.item_cache[scope_name]

            if node.symbols:
                # Map local names of all definitions in the imported module to their items
                scope_definitions = {
                    item.local_name: item
                    for item in scope_item.create_definition_items(item_factory=self, config=config)
                }
                # Use the original name for renamed imports (``use_name``)
                symbol_names = tuple(str(smbl.type.use_name or smbl).lower() for smbl in node.symbols)
                non_ignored_symbol_names = tuple(
                    smbl for smbl in symbol_names
                    if not self._is_ignored(f'{scope_name}#{smbl}', config, ignore)
                )
                imported_items = tuple(
                    it for smbl in non_ignored_symbol_names
                    if (it := scope_definitions.get(smbl)) is not None
                )

                # Global variable imports are filtered out in the previous statement because they
                # are not represented by an Item. For these, we introduce a dependency on the
                # module instead
                has_globalvar_import = len(imported_items) != len(non_ignored_symbol_names)

                # Filter out ProcedureItems corresponding to a subroutine:
                # dependencies on subroutines are introduced via the call statements, as this avoids
                # depending on imported but not called subroutines
                imported_items = tuple(
                    it for it in imported_items
                    if not isinstance(it, ProcedureItem) or it.ir.is_function
                )

                if has_globalvar_import:
                    return (scope_item,) + imported_items
                if not imported_items:
                    return None
                return imported_items

            # Unqualified import (no ONLY list): depend on the module as a whole
            return (scope_item,)

        # Remaining supported node types all resolve to one or more procedure symbols
        if isinstance(node, ir.CallStatement):
            procedure_symbols = as_tuple(node.name)
        elif isinstance(node, ProcedureSymbol):
            procedure_symbols = as_tuple(node)
        elif isinstance(node, (ir.ProcedureDeclaration, ir.Interface)):
            procedure_symbols = as_tuple(node.symbols)
        else:
            raise ValueError(f'{node} has an unsupported node type {type(node)}')

        # Symbols containing '%' are type-bound procedures; entries may be None
        # if the respective helper decides to skip the symbol
        return tuple(
            self._get_procedure_binding_item(symbol, scope_ir, config, ignore=ignore) if '%' in symbol.name
            else self._get_procedure_item(symbol, scope_ir, config, ignore=ignore)
            for symbol in procedure_symbols
        )

    def get_or_create_item(self, item_cls, item_name, scope_name, config=None):
        """
        Helper method to instantiate an :any:`Item` of class :data:`item_cls`
        with name :data:`item_name`.

        This helper method checks for the presence of :data:`item_name` in the
        :attr:`item_cache` and returns that instance. If none is found, an instance
        of :data:`item_cls` is created and stored in the item cache.

        The :data:`scope_name` denotes the name of the parent scope, under which a
        parent :any:`Item` has to exist in :data:`self.item_cache` to find the source
        object to use.

        Item names matching one of the entries in the :data:`config` disable list
        are skipped. If `strict` mode is enabled, this raises a :any:`RuntimeError`
        if no matching parent item can be found in the item cache.

        Parameters
        ----------
        item_cls : subclass of :any:`Item`
            The class of the item to create
        item_name : str
            The name of the item to create
        scope_name : str
            The name under which a parent item can be found in the :attr:`item_cache`
            to find the corresponding source
        config : :any:`SchedulerConfig`, optional
            The config object to use to determine disabled items, and to use when
            instantiating the new item

        Returns
        -------
        :any:`Item` or None
            The item object or `None` if disabled or impossible to create
        """
        if item_name in self.item_cache:
            return self.item_cache[item_name]

        if not scope_name and item_name.count('#') == 2:
            # For an internal procedure that is contained in a procedure that
            # is not a module procedure, we use the parent procedure as scope
            # in the lookup
            scope_name = item_name.rsplit('#', maxsplit=1)[0]
        scope_item = self.item_cache.get(scope_name)

        item_conf = config.create_item_config(item_name) if config else None
        if scope_item is None or isinstance(scope_item, ExternalItem):
            warning(f'Module {scope_name} not found in self.item_cache. Marking {item_name} as an external dependency')
            item = ExternalItem(item_name, source=None, config=item_conf, origin_cls=item_cls,
                                is_generated=self._is_generated(item_name, item_conf))
        else:
            source = scope_item.source
            item = item_cls(item_name, source=source, config=item_conf)
        self.item_cache[item_name] = item
        return item

    def get_or_create_item_from_item(self, name, item, config=None):
        """
        Helper method to instantiate an :any:`Item` as a clone of a given :data:`item`
        with the given new :data:`name`.

        This helper method checks for the presence of :data:`name` in the
        :attr:`item_cache` and returns that instance. If none is in the cache, it tries
        a lookup via the scope, if applicable. Otherwise, a new item is created as
        a duplicate.

        This duplication is performed by replicating the corresponding :any:`FileItem`
        and any enclosing scope items, applying name changes for scopes as implied by
        :data:`name`.

        Parameters
        ----------
        name : str
            The name of the item to create
        item : :any:`Item`
            The item to duplicate to create the new item
        config : :any:`SchedulerConfig`, optional
            The config object to use when instantiating the new item

        Returns
        -------
        :any:`Item`
            The new item object
        """
        # Sanity checks and early return if an item by that name exists
        if name in self.item_cache:
            return self.item_cache[name]

        if not isinstance(item, ProcedureItem):
            raise NotImplementedError(f'Cloning of Items is not supported for {type(item)}')

        # Derive name components for the new item
        pos = name.find('#')
        local_name = name[pos+1:]
        if pos == -1:
            scope_name = None
            if local_name == item.local_name:
                raise RuntimeError(f'Cloning item {item.name} with the same name in global scope')
            if item.scope_name:
                raise RuntimeError(f'Cloning item {item.name} from local scope to global scope is not supported')
        else:
            scope_name = name[:pos]
            if scope_name and scope_name == item.scope_name:
                raise RuntimeError(f'Cloning item {item.name} as {name} creates name conflict for scope {scope_name}')
            if scope_name and not item.scope_name:
                raise RuntimeError(f'Cloning item {item.name} from global scope to local scope is not supported')

        # We may need to create a new item as a clone of the given item
        # For this, we start with replicating the source and updating the
        if not scope_name or scope_name not in self.item_cache:
            # Clone the source and update names
            new_source = item.source.clone()
            if scope_name:
                scope = new_source[item.scope_name]
                scope.name = scope_name
                item_ir = scope[item.local_name]
            else:
                item_ir = new_source[item.local_name]
            item_ir.name = local_name

            # Create a new FileItem for the new source
            new_source.path = item.path.with_name(f'{scope_name or local_name}{item.path.suffix}')
            file_item = self.get_or_create_file_item_from_source(new_source, config=config)

            # Get the definition items for the FileItem and return the new item
            definition_items = {
                it.name: it for it in file_item.create_definition_items(item_factory=self, config=config)
            }
            self.item_cache.update(definition_items)

            if name in definition_items:
                return definition_items[name]

        # Check for existing scope item
        if scope_name and scope_name in self.item_cache:
            scope = self.item_cache[scope_name].ir
            if local_name not in scope:
                raise RuntimeError((
                    f'Cloning item {item.name} as {name} failed, '
                    f'{local_name} not found in existing scope {scope_name}'
                ))
            return self.create_from_ir(scope[local_name], scope, config=config)

        raise RuntimeError(f'Failed to clone item {item.name} as {name}')

    def get_or_create_file_item_from_path(self, path, config, frontend_args=None):
        """
        Utility method to create a :any:`FileItem` for a given path

        This is used to instantiate items for the first time during the scheduler's
        discovery phase. It will use a cached item if it exists, or parse the source
        file using the given :data:`frontend_args`.

        Parameters
        ----------
        path : str or pathlib.Path
            The file path of the source file
        config : :any:`SchedulerConfig`
            The config object from which the item configuration will be derived
        frontend_args : dict, optional
            Frontend arguments that are given to :any:`Sourcefile.from_file` when
            parsing the file
        """
        item_name = str(path).lower()
        if file_item := self.item_cache.get(item_name):
            return file_item

        if not frontend_args:
            frontend_args = {}
        if config:
            frontend_args = config.create_frontend_args(path, frontend_args)

        source = Sourcefile.from_file(path, **frontend_args)
        item_conf = config.create_item_config(item_name) if config else None
        file_item = FileItem(item_name, source=source, config=item_conf)
        self.item_cache[item_name] = file_item
        return file_item

    def get_or_create_file_item_from_source(self, source, config):
        """
        Utility method to create a :any:`FileItem` corresponding to a given source object

        This can be used to create a :any:`FileItem` for an already parsed :any:`Sourcefile`,
        or when looking up the file item corresponding to a :any:`Item` by providing the
        item's ``source`` object.

        Lookup is not performed via the ``path`` property in :data:`source` but by
        searching for an existing :any:`FileItem` in the cache that has the same source
        object. This allows creating clones of source files during transformations without
        having to ensure their path property is always updated. Only if no item is found
        in the cache, a new one is created.

        Parameters
        ----------
        source : :any:`Sourcefile`
            The existing sourcefile object for which to create the file item
        config : :any:`SchedulerConfig`
            The config object from which the item configuration will be derived
        """
        # Check for file item with the same source object
        for item in self.item_cache.values():
            if isinstance(item, FileItem) and item.source is source:
                return item

        if not source.path:
            raise RuntimeError('Cannot create FileItem from source: Sourcefile has no path')

        # Create a new file item
        item_name = str(source.path).lower()
        item_conf = config.create_item_config(item_name) if config else None
        file_item = FileItem(item_name, source=source, config=item_conf)
        self.item_cache[item_name] = file_item
        return file_item

    def _get_procedure_binding_item(self, proc_symbol, scope_ir, config, ignore=None):
        """
        Resolve a type-bound procedure symbol to a :any:`ProcedureBindingItem`

        Only the outermost derived type is resolved here; nested derived types
        are handled one type at a time by subsequent binding items, e.g.
        ``my_var%member%procedure -> my_type%member%procedure ->
        member_type%procedure -> procedure``. The module declaring the type of
        the symbol's first parent is searched in this order:

        1. fully qualified imports visible from :data:`scope_ir`
        2. the typedefs of the enclosing module
        3. :any:`TypeDefItem` definitions of modules imported without ``ONLY`` list

        Parameters
        ----------
        proc_symbol : :any:`ProcedureSymbol`
            The procedure symbol of the type binding
        scope_ir : :any:`Scope`
            The scope node in which the procedure binding is declared or used. Note that this
            is not necessarily the same as the scope of the created :any:`Item` but
            serves as the entry point for the lookup mechanism that underpins the
            creation procedure.
        config : :any:`SchedulerConfig`
            The config object from which the item configuration will be derived
        ignore : list of str, optional
            A list of item names that should be ignored, i.e., not be created as an item.

        Returns
        -------
        :any:`Item` or None
            The binding item, or ``None`` if type information is missing, the
            declaring module cannot be found (non-strict mode), or the item is ignored

        Raises
        ------
        RuntimeError
            In strict mode, if no candidate or several candidate modules declare the type
        """
        strict_mode = not config or config.default.get('strict', True)

        # Type of the outermost parent, e.g. the type of ``my_var`` in ``my_var%member%proc``
        parent_dtype = proc_symbol.parents[0].type.dtype
        if parent_dtype is BasicType.DEFERRED:
            warning(f'Missing type information for symbol {proc_symbol.parents[0]}')
            return None
        type_name = parent_dtype.name

        # 1. Imported in the current or a parent scope?
        module_name = None
        type_import = get_all_import_map(scope_ir).get(type_name)
        if type_import:
            module_name = type_import.module
            type_name = self._get_imported_symbol_name(type_import, type_name)

        # 2. Otherwise it must be declared in the enclosing module
        if not module_name:
            enclosing = scope_ir
            while enclosing and not isinstance(enclosing, Module):
                enclosing = enclosing.parent
            if enclosing and type_name in enclosing.typedef_map:
                module_name = enclosing.name

        # 3. Still unknown: likely imported via `USE` without `ONLY` list, so search
        #    the TypeDef definitions of all modules imported in that way
        if not module_name:
            wildcard_modules = [imp.module for imp in scope_ir.all_imports if not imp.symbols]
            candidates = self.get_or_create_module_definitions_from_candidates(
                type_name, config, module_names=wildcard_modules, only=TypeDefItem
            )
            if not candidates:
                msg = f'Unable to find the module declaring {type_name}.'
                if strict_mode:
                    raise RuntimeError(msg)
                warning(msg)
                return None
            if len(candidates) > 1:
                msg = f'Multiple definitions for {type_name}: ' + ','.join(cand.name for cand in candidates)
                if strict_mode:
                    raise RuntimeError(msg)
                warning(msg)
            module_name = candidates[0].scope_name

        member_path = '%'.join(proc_symbol.name_parts[1:])
        item_name = f'{module_name}#{type_name}%{member_path}'.lower()
        if self._is_ignored(item_name, config, ignore):
            return None
        return self.get_or_create_item(ProcedureBindingItem, item_name, module_name, config)

    def _get_procedure_item(self, proc_symbol, scope_ir, config, ignore=None):
        """
        Utility method to create a :any:`ProcedureItem`, :any:`ProcedureBindingItem`,
        or :any:`InterfaceItem` for a given :any:`ProcedureSymbol`

        The procedure is resolved by trying, in this order:

        1. a procedure binding, if :data:`scope_ir` is a :any:`TypeDef`
        2. an internal procedure of :data:`scope_ir`
        3. a procedure of the enclosing module
        4. an interface declared in the enclosing module
        5. a fully qualified import in the current or a parent scope
        6. a procedure definition in a module imported without ``ONLY`` list
        7. a procedure in global scope (e.g., via a header-included interface),
           falling back to an :any:`ExternalItem` in non-strict mode

        Parameters
        ----------
        proc_symbol : :any:`ProcedureSymbol`
            The procedure symbol for which the corresponding item is created
        scope_ir : :any:`Scope`
            The scope node in which the procedure symbol is declared or used. Note that this
            is not necessarily the same as the scope of the created :any:`Item` but
            serves as the entry point for the lookup mechanism that underpins the
            creation procedure.
        config : :any:`SchedulerConfig`
            The config object from which the item configuration will be derived
        ignore : list of str, optional
            A list of item names that should be ignored, i.e., not be created as an item.

        Returns
        -------
        :any:`Item` or None
            The resolved item, or ``None`` if the item is ignored or disabled

        Raises
        ------
        RuntimeError
            If the procedure is defined in several modules imported without ``ONLY``
            list, or if it cannot be found at all in strict mode
        """
        proc_name = proc_symbol.name

        if proc_name in scope_ir:
            if isinstance(scope_ir, ir.TypeDef):
                # This is a procedure binding item
                scope_name = scope_ir.parent.name.lower()
                item_name = f'{scope_name}#{scope_ir.name}%{proc_name}'.lower()
                if self._is_ignored(item_name, config, ignore):
                    return None
                return self.get_or_create_item(ProcedureBindingItem, item_name, scope_name, config)

            if isinstance(scope_ir, Subroutine) and proc_name in scope_ir.subroutine_map:
                # This is an internal procedure
                current_module = scope_ir.parent
                scope_name = current_module.name.lower() if current_module else ''
                item_name = f'{scope_name}#{scope_ir.name}#{proc_name}'.lower()
                if self._is_ignored(item_name, config, ignore):
                    return None
                return self.get_or_create_item(ProcedureItem, item_name, scope_name, config)

        # Recursively search for the enclosing module
        current_module = None
        scope = scope_ir
        while scope:
            if isinstance(scope, Module):
                current_module = scope
                break
            scope = scope.parent

        if current_module and any(proc_name.lower() == r.name.lower() for r in current_module.subroutines):
            # This is a call to a procedure in the same module
            scope_name = current_module.name
            item_name = f'{scope_name}#{proc_name}'.lower()
            if self._is_ignored(item_name, config, ignore):
                return None
            return self.get_or_create_item(ProcedureItem, item_name, scope_name, config)

        if current_module and proc_name in current_module.interface_symbols:
            # This procedure is declared in an interface in the current module.
            # The item lives under the module's name: scope_ir may be a procedure
            # nested in that module, whose own name must not be used as scope
            scope_name = current_module.name
            item_name = f'{scope_name}#{proc_name}'.lower()
            if self._is_ignored(item_name, config, ignore):
                return None
            return self.get_or_create_item(InterfaceItem, item_name, scope_name, config)

        if imprt := get_all_import_map(scope_ir).get(proc_name):
            # This is a call to a module procedure which has been imported via
            # a fully qualified import in the current or parent scope
            scope_name = imprt.module
            proc_name = self._get_imported_symbol_name(imprt, proc_name)
            item_name = f'{scope_name}#{proc_name}'.lower()
            if self._is_ignored(item_name, config, ignore):
                return None
            return self.get_or_create_item(ProcedureItem, item_name, scope_name, config)

        # This may come from an unqualified import
        unqualified_imports = [imprt for imprt in scope_ir.all_imports if not imprt.symbols]
        if unqualified_imports:
            # We try to find the ProcedureItem in the unqualified module imports
            module_names = [imprt.module for imprt in unqualified_imports]
            candidates = self.get_or_create_module_definitions_from_candidates(
                proc_name, config, module_names=module_names, only=ProcedureItem
            )
            if candidates:
                if len(candidates) > 1:
                    # NB: item_name is not bound on this code path, so report proc_name
                    candidate_modules = [it.scope_name for it in candidates]
                    raise RuntimeError(
                        f'Procedure {proc_name} defined in multiple imported modules: {", ".join(candidate_modules)}'
                    )
                return candidates[0]

        # This is a call to a subroutine declared via header-included interface
        item_name = f'#{proc_name}'.lower()
        if self._is_ignored(item_name, config, ignore):
            return None
        if config and config.is_disabled(item_name):
            return None
        if item_name in self.item_cache:
            return self.item_cache[item_name]

        # We definitely can't find this item, which may be fine, depending on config...
        item_conf = config.create_item_config(item_name) if config else None
        is_generated = self._is_generated(item_name, item_conf)

        if not is_generated and (not config or config.default.get('strict', True)):
            raise RuntimeError(f'Procedure {item_name} not found in self.item_cache.')

        warning(f'Procedure {item_name} not found in self.item_cache, marking as external dependency')
        return ExternalItem(
            item_name, source=None, config=item_conf,
            origin_cls=ProcedureItem, is_generated=is_generated
        )

    def get_or_create_module_definitions_from_candidates(self, name, config, module_names=None, only=None):
        """
        Look up (and create on demand) the definition items called :data:`name`
        in a set of candidate modules

        This is intended for resolving a dependency that entered the use-site
        through an unqualified import statement: the local name of the dependency
        is known, and the unqualified imports provide the list of modules that
        may define it.

        Parameters
        ----------
        name : str
            Local name of the item(s) to look for in the candidate modules
        config : :any:`SchedulerConfig`
            The config object from which the item configuration will be derived
        module_names : list of str, optional
            Names of the candidate modules. When omitted (or empty), every
            :any:`ModuleItem` currently in the item cache is a candidate.
        only : list of :any:`Item` classes, optional
            Restrict the created definition items to the given item types

        Returns
        -------
        tuple of :any:`Item`
            All definition items whose local name (the part after ``#``) equals
            ``name.lower()``, in candidate-module order. Normally this contains at
            most one item; more than one indicates a name conflict.
        """
        if not module_names:
            module_names = [
                cached for cached in
                (entry.name for entry in self.item_cache.values() if isinstance(entry, ModuleItem))
            ]

        matches = []
        for module_name in module_names:
            module_item = self.item_cache.get(module_name)
            if not module_item:
                # Candidate module is not known to the cache: nothing to search
                continue
            created = module_item.create_definition_items(
                item_factory=self, config=config, only=only
            )
            for definition in created:
                # Compare only the local part of the fully-qualified "<scope>#<local>" name
                local_part = definition.name[definition.name.index('#') + 1:]
                if local_part == name.lower():
                    matches.append(definition)
        return tuple(matches)

    @staticmethod
    def _get_imported_symbol_name(imprt, symbol_name):
        """
        For a :data:`symbol_name` and its corresponding :any:`Import` node :data:`imprt`,
        determine the symbol in the defining module.

        This resolves renaming upon import but, in most cases, will simply return the
        original :data:`symbol_name`.

        Returns
        -------
        :any:`MetaSymbol` or :any:`TypedSymbol` :
            The symbol in the defining scope
        """
        if imprt.symbols:
            imprt_symbol = imprt.symbols[imprt.symbols.index(symbol_name)]
        else:
            rename_dic = CaseInsensitiveDict({local.name: use for (use, local) in imprt.rename_list})
            symbol_name = rename_dic[symbol_name]
            imprt_symbol = None
        if imprt_symbol and imprt_symbol.type.use_name:
            symbol_name = imprt_symbol.type.use_name
        return symbol_name

    @staticmethod
    def _is_ignored(name, config, ignore):
        """
        Check whether :data:`name` should be skipped because it is disabled or ignored

        Parameters
        ----------
        name : str
            The (fully-qualified) item name to check
        config : :any:`SchedulerConfig`, optional
            If given, its ``disable`` entries are included in the list of
            patterns that :data:`name` is matched against
        ignore : list of str, optional
            Additional names or patterns, as typically provided in a config value

        All collected patterns are matched with :any:`SchedulerConfig.match_item_keys`,
        using pattern matching and parent matching.

        Returns
        -------
        bool
            ``True`` if :data:`name` matches any of the collected patterns,
            ``False`` otherwise (including when there are no patterns at all)
        """
        disabled = as_tuple(config.disable) if config else ()
        patterns = disabled + as_tuple(ignore)
        if not patterns:
            return False
        matched = SchedulerConfig.match_item_keys(
            name, patterns, use_pattern_matching=True, match_item_parents=True
        )
        return bool(matched)

    @staticmethod
    def _is_generated(name, config=None):
        """
        Utility method to check if a given :data:`name` is build-time generated.

        Parameters
        ----------
        name : str
            The name to check
        config : dict, optional
            A config dictionary for a given item, as created by
            :any:`SchedulerConfig.create_item_config`. This is queried for the `"generated"`
            key to obtain a list of names, as typically provided in a scheduler config
            or injected by a :any:`Transformation`. These are matched via
            :any:`SchedulerConfig.match_item_keys` with pattern matching enabled.

        Returns
        -------
        bool
            ``True`` if matched successfully via :data:`config` otherwise ``False``
        """
        keys = config.get('generated', []) if config else []
        return bool(keys and SchedulerConfig.match_item_keys(
            name, keys, use_pattern_matching=True, match_item_parents=True
        ))
loki-ecmwf-0.3.6/loki/batch/sgraph.py0000664000175000017500000004735415167130205017633 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import deque, defaultdict
from pathlib import Path
from codetiming import Timer
import networkx as nx

from loki.batch.configure import SchedulerConfig
from loki.batch.item import (
    InterfaceItem, ProcedureItem, ProcedureBindingItem, TypeDefItem
)
from loki.batch.sfilter import SFilter
from loki.logging import debug, perf, warning
from loki.tools import as_tuple


__all__ = ['SGraph']


class SGraph:
    """
    The dependency graph underpinning the :any:`Scheduler`

    It is built upon a :any:`networkx.DiGraph` to expose the dependencies
    between :any:`Item` nodes. It is typically created from one or multiple
    `seed` items via :meth:`from_seed` by recursively chasing dependencies.

    Cyclic dependencies are broken for procedures that are marked as
    ``RECURSIVE``, which would otherwise constitute a dependency on itself.
    See :meth:`_break_cycles`.

    Parameters
    ----------
    graph : :any:`networkx.DiGraph`, optional
        Optionally, directed graph as instance of
        :any:`networkx.DiGraph` as initial graph. The given object is used
        as-is (not copied), even if it is empty.
    """

    def __init__(self, graph=None):
        # Compare against None explicitly: a DiGraph without nodes is falsy
        # (its length is zero), so ``graph or nx.DiGraph()`` would silently
        # discard an empty graph supplied by the caller.
        self._graph = graph if graph is not None else nx.DiGraph()

    @classmethod
    @Timer(logger=perf, text='[Loki::Scheduler] Built SGraph from seed in {:.2f}s')
    def from_seed(cls, seed, item_factory, config=None):
        """
        Create a new :any:`SGraph` using :data:`seed` as starting point.

        Parameters
        ----------
        seed : (list of) str
            The names of the root nodes
        item_factory : :any:`ItemFactory`
            The item factory to use when creating graph nodes
        config : :any:`SchedulerConfig`, optional
            The config object to use when creating items
        """
        _graph = cls()
        _graph._populate(seed, item_factory, config)
        _graph._break_cycles()
        return _graph

    @classmethod
    def from_dict(cls, graph_dict):
        """
        Create a new :any:`SGraph` using :data:`graph_dict` as starting point.

        Parameters
        ----------
        graph_dict : dict
            Representation of the underlying graph
            represented as dict.
        """
        graph = nx.DiGraph(graph_dict)
        return cls(graph)

    def as_filegraph(self, item_factory, config=None, item_filter=None, exclude_ignored=False):
        """
        Convert the :any:`SGraph` to a dependency graph that only contains
        :any:`FileItem` nodes.

        Parameters
        ----------
        item_factory : :any:`ItemFactory`
            The item factory to use when creating graph nodes
        config : :any:`SchedulerConfig`, optional
            The config object to use when creating items
        item_filter : list of :any:`Item` subclasses, optional
            Only include files that include at least one dependency item of the
            given type. By default, all items are included.
        exclude_ignored : bool, optional
            Exclude :any:`Item`s that have the ``is_ignored`` property

        Returns
        -------
        :any:`SGraph`
            A new graph object
        """
        _graph = type(self)()
        _graph._populate_filegraph(self, item_factory, config, item_filter, exclude_ignored=exclude_ignored)
        return _graph

    def _create_item(self, name, item_factory, config):
        """
        Utility method to create a new item node with the given :data:`name`

        This may trigger on-demand creation of definition items in
        the enclosing scope.

        Returns
        -------
        :any:`Item` or tuple of :any:`Item` or None
            The cached item for :data:`name`, or, if :data:`name` is not fully
            qualified and no type-bound procedure is requested, the tuple of
            matching candidates found across all modules; ``None`` if nothing matched.
        """
        # Names without a scope part refer to the global scope: "<name>" -> "#<name>"
        if '#' not in name:
            name = f'#{name}'
        item = item_factory.item_cache.get(name)

        if not item:
            # We may have to create the corresponding module's definitions first to make
            # the item available in the cache
            scope_name = name[:name.index('#')]
            module_item = item_factory.item_cache.get(scope_name)
            if module_item:
                module_item.create_definition_items(item_factory=item_factory, config=config)
                item = item_factory.item_cache.get(name)

        if not item:
            # The name may be a module procedure or type that is not fully qualified,
            # so we need to search all modules for any matching routines
            if '%' in name:
                module_member_name = name[name.index('#')+1:name.index('%')]
            else:
                module_member_name = name[name.index('#')+1:]
            item = item_factory.get_or_create_module_definitions_from_candidates(
                module_member_name, config, only=(ProcedureItem, TypeDefItem)
            ) or None

            if item and '%' in name:
                # If this is a type-bound procedure, we may have to create its definitions
                for _item in item:
                    _item.create_definition_items(item_factory=item_factory, config=config)
                item = item_factory.item_cache.get(name)

        return item

    def _add_children(self, item, item_factory, config, dependencies=None):
        """
        Create items for dependencies of the :data:`item` and add them to
        the graph as a dependency of :data:`item`

        Parameters
        ----------
        item : :any:`Item`
            Create the dependencies for this item
        item_factory : :any:`ItemFactory`
            The item factory to use when creating graph nodes
        config : :any:`SchedulerConfig`, optional
            The config object to use when creating items
        dependencies : list, optional
            An initial list of already created dependencies

        Returns
        -------
        list of :any:`Item`
            The list of new items that have been added to the graph
        """
        dependencies = as_tuple(dependencies)
        for dependency in item.create_dependency_items(item_factory=item_factory, config=config):
            # Skip duplicates and dependencies that are blocked by the parent item
            if not (dependency in dependencies or SchedulerConfig.match_item_keys(dependency.name, item.block)):
                # A dependency is ignored if its parent is, or if the parent's ignore list matches it
                dependency.config['is_ignored'] = (
                    item.is_ignored or
                    bool(SchedulerConfig.match_item_keys(dependency.name, item.ignore, match_item_parents=True))
                )
                dependencies += (dependency,)

        new_items = tuple(item_ for item_ in dependencies if item_ not in self._graph)
        if new_items:
            self.add_nodes(new_items)

        # propagate 'lib' attribute (the compile unit the item belongs to)
        for new_item in new_items:
            new_item.config['lib'] = item.config.get('lib', None)

        # Careful not to include cycles (from recursive TypeDefs)
        self.add_edges((item, item_) for item_ in dependencies if item != item_)
        return new_items

    def _populate(self, seed, item_factory, config):
        """
        Build the dependency graph, initialised from :data:`seed` using :data:`item_factory`
        to create the node items

        Parameters
        ----------
        seed : (list of) str
            The names of the seed items
        item_factory : :any:`ItemFactory`
            The item factory to use when creating graph nodes
        config : :any:`SchedulerConfig`, optional
            The config object to use when creating items
        """
        queue = deque()

        # Insert the seed objects
        for name in as_tuple(seed):
            item = as_tuple(self._create_item(name, item_factory, config))
            if item:
                self.add_nodes(item)
                queue.extend(item)
            else:
                debug('No item found for seed "%s"', name)

        # Populate the graph breadth-first; only items with ``expand`` set are explored further
        while queue:
            item = queue.popleft()
            if item.expand:
                children = self._add_children(item, item_factory, config)
                if children:
                    queue.extend(children)

    def _populate_filegraph(self, sgraph, item_factory, config=None, item_filter=None, exclude_ignored=False):
        """
        Derive a dependency graph with :any:`FileItem` nodes from a given :data:`sgraph`

        Parameters
        ----------
        sgraph : :any:`SGraph`
            The dependency graph from which to derive the file graph
        item_factory : :any:`ItemFactory`
            The item factory to use when creating graph nodes
        config : :any:`SchedulerConfig`, optional
            The config object to use when creating items
        item_filter : list of :any:`Item` subclasses, optional
            Only include files that include at least one dependency item of the
            given type. By default, all items are included.
        exclude_ignored : bool, optional
            Exclude :any:`Item`s that have the ``is_ignored`` property
        """
        # Both maps are keyed by name: item name -> file item, file item name -> items
        item_2_file_item_map = {}
        file_item_2_item_map = defaultdict(list)

        # Add the file nodes for each of the items matching the filter criterion
        for item in SFilter(sgraph, item_filter, exclude_ignored=exclude_ignored):
            file_item = item_factory.get_or_create_file_item_from_source(item.source, config)
            item_2_file_item_map[item.name] = file_item
            file_item_2_item_map[file_item.name].append(item)
            if file_item not in self._graph:
                self.add_node(file_item)

        # Update the "is_ignored", "replicate" and "lib" attributes for file items
        for items in file_item_2_item_map.values():
            # Look up by name, consistent with how ``item_2_file_item_map`` is keyed
            file_item = item_2_file_item_map[items[0].name]

            # A file is only ignored if all of its items are ignored
            file_item.config['is_ignored'] = all(item.is_ignored for item in items)

            # A file is replicated if any of its items is replicated
            replicate = any(item.replicate for item in items)
            if replicate:
                non_replicate_items = [item for item in items if not item.replicate]
                if non_replicate_items:
                    warning(
                        f'File {file_item.name} will be replicated but contains items '
                        f'that are marked as non-replicated: {", ".join(item.name for item in non_replicate_items)}'
                    )
            file_item.config['replicate'] = replicate

            # propagate 'lib' attribute to the parent file item (the compile unit the item belongs to):
            # the first item lib that is not None wins; if no item has one, the file item's
            # config is left untouched
            use_lib = next((lib for lib in (item.lib for item in items) if lib is not None), None)
            if use_lib is not None:
                file_item.config['lib'] = use_lib

        # Insert edges to the file items corresponding to the successors of the items
        for item in SFilter(sgraph, item_filter, exclude_ignored=exclude_ignored):
            file_item = item_2_file_item_map[item.name]
            for child in sgraph._graph.successors(item):
                child_file_item = item_2_file_item_map.get(child.name)
                if not child_file_item or child_file_item == file_item:
                    # Skip 2 situations:
                    # 1) The child_file_item is None, i.e., not in item_2_file_item_map, if
                    #    the child does not match the item_filter
                    # 2) The child may be the same as the file if there is a dependency to
                    #    another item in the same file
                    continue
                self.add_edge((file_item, child_file_item))

    def _break_cycles(self):
        """
        Remove cyclic dependencies by deleting the first outgoing edge of
        each cyclic dependency for all procedure items with a ``RECURSIVE`` prefix
        """
        for item in self.items:  # We cannot iterate over the graph itself as we plan on changing it
            if (
                isinstance(item, ProcedureItem) and
                any('recursive' in prefix.lower() for prefix in getattr(item.ir, 'prefix', []) or [])
            ):
                try:
                    # Repeatedly remove the first edge of any cycle reachable from this item
                    # until networkx reports that no cycle is left
                    while True:
                        cycle_path = nx.find_cycle(self._graph, item)
                        debug(f'Removed edge {cycle_path[0]!s} to break cyclic dependency {cycle_path!s}')
                        self._graph.remove_edge(*cycle_path[0])
                except nx.NetworkXNoCycle:
                    pass

    def __iter__(self):
        """
        Iterate over the items in the dependency graph
        """
        return iter(SFilter(self))

    @property
    def items(self):
        """
        Return all :any:`Item` nodes in the dependency graph
        """
        return tuple(self._graph.nodes)

    @property
    def dependencies(self):
        """
        Return all dependencies, i.e., edges of the dependency graph
        """
        return tuple(self._graph.edges)

    @staticmethod
    def _get_item_filter(item_filter):
        """
        If :any:`ProcedureItem` is part of ``item_filter``, add :any:`ProcedureBindingItem` and
        :any:`InterfaceItem` as well, since these are intermediate nodes. Their
        dependencies will also be included until they eventually resolve to a
        :any:`ProcedureItem`.

        This returns the updated ``item_filter``.

        Parameters
        ----------
        item_filter : list of :any:`Item` subclasses, optional
            Filter successor items to only include items of the provided type. By default,
            all items are considered. Note that including :any:`ProcedureItem` in the
            ``item_filter`` automatically adds :any:`ProcedureBindingItem` and
            :any:`InterfaceItem` as well, since these are intermediate nodes. Their
            dependencies will also be included until they eventually resolve to a
            :any:`ProcedureItem`.

        Returns
        -------
        tuple of :any:`Item` subclasses or None
            The (possibly extended) filter, or ``None`` if no filter was given
        """
        item_filter = as_tuple(item_filter)
        if ProcedureItem in item_filter:
            # ProcedureBindingItem and InterfaceItem are intermediate nodes that take
            # essentially the role of an edge to ProcedureItems. Therefore
            # we need to make sure these are included if ProcedureItems are included
            if ProcedureBindingItem not in item_filter:
                item_filter = item_filter + (ProcedureBindingItem,)
            if InterfaceItem not in item_filter:
                item_filter = item_filter + (InterfaceItem,)
        return item_filter or None

    def successors(self, item, item_filter=None):
        """
        Return the list of successor nodes in the dependency tree below :any:`Item`

        This returns all immediate successors (but can be filtered accordingly using
        the item's ``targets`` property) of the item in the dependency graph

        The list of successors is provided to transformations during processing with
        the :any:`Scheduler`.

        Parameters
        ----------
        item : :any:`Item`
            The item node in the dependency graph for which to determine the successors
        item_filter : list of :any:`Item` subclasses, optional
            Filter successor items to only include items of the provided type. By default,
            all items are considered. Note that including :any:`ProcedureItem` in the
            ``item_filter`` automatically adds :any:`ProcedureBindingItem` and
            :any:`InterfaceItem` as well, since these are intermediate nodes. Their
            dependencies will also be included until they eventually resolve to a
            :any:`ProcedureItem`.
        """
        item_filter = self._get_item_filter(item_filter)

        successors = ()
        for child in self._graph.successors(item):
            if item_filter is None or isinstance(child, item_filter):
                if isinstance(child, (ProcedureBindingItem, InterfaceItem)):
                    # Intermediate nodes are expanded recursively; note that the
                    # recursion does not re-apply ``item_filter`` to their successors
                    successors += (child,) + self.successors(child)
                else:
                    successors += (child,)
        return successors

    def get_sub_sgraph(self, item, item_filter=None):
        """
        Return the subgraph of ``self._graph`` from source ``item`` as a new instance of
        :any:`SGraph`.

        Parameters
        ----------
        item : :any:`Item`
            The item node in the dependency graph for which to determine the successors
        item_filter : list of :any:`Item` subclasses, optional
            Filter successor items to only include items of the provided type. By default,
            all items are considered. Note that including :any:`ProcedureItem` in the
            ``item_filter`` automatically adds :any:`ProcedureBindingItem` and
            :any:`InterfaceItem` as well, since these are intermediate nodes. Their
            dependencies will also be included until they eventually resolve to a
            :any:`ProcedureItem`.
        """
        item_filter = self._get_item_filter(item_filter)
        # find descendants and add item itself
        nodes = as_tuple(nx.descendants(self._graph, item)) + (item,)
        if item_filter is not None:
            nodes = as_tuple([node for node in nodes if isinstance(node, item_filter)])
        # generate (a copy of) the nx.DiGraph subgraph
        subgraph = self._graph.subgraph(nodes).copy()
        # return this subgraph as instance of SGraph -> sub_sgraph
        return type(self)(subgraph)

    @property
    def depths(self):
        """
        Return a mapping of :any:`Item` nodes to their depth (topological generation)
        in the dependency graph
        """
        topological_generations = list(nx.topological_generations(self._graph))
        depths = {
            item: i_gen
            for i_gen, gen in enumerate(topological_generations)
            for item in gen
        }
        return depths

    def add_node(self, item):
        """
        Add :data:`item` as a node to the dependency graph
        """
        self._graph.add_node(item)

    def add_nodes(self, items):
        """
        Add the given :data:`items` as nodes to the dependency graph
        """
        self._graph.add_nodes_from(items)

    def add_edge(self, edge):
        """
        Add a dependency :data:`edge` to the dependency graph
        """
        self._graph.add_edge(edge[0], edge[1])

    def add_edges(self, edges):
        """
        Add the dependency :data:`edges` to the dependency graph
        """
        self._graph.add_edges_from(edges)

    def export_to_file(self, dotfile_path):
        """
        Generate a dotfile from the current graph

        Parameters
        ----------
        dotfile_path : str or pathlib.Path
            Path to write the dotfile to. A corresponding graphical representation
            will be created with an additional ``.pdf`` appendix.
        """
        try:
            import graphviz as gviz  # pylint: disable=import-outside-toplevel
        except ImportError:
            warning('[Loki] Failed to load graphviz, skipping file export generation...')
            return

        path = Path(dotfile_path)
        graph = gviz.Digraph(format='pdf', strict=True, graph_attr=(('rankdir', 'LR'),))

        # Insert all nodes in the graph
        style = {
            'color': 'black', 'shape': 'box', 'fillcolor': 'limegreen', 'style': 'filled'
        }
        for item in self.items:
            graph.node(item.name.upper(), **style)

        # Insert all edges in the schedulers graph
        graph.edges((a.name.upper(), b.name.upper()) for a, b in self.dependencies)

        try:
            graph.render(path, view=False)
        except gviz.ExecutableNotFound as e:
            warning(f'[Loki] Failed to render callgraph due to graphviz error:\n  {e}')
loki-ecmwf-0.3.6/loki/batch/item.py0000664000175000017500000011633015167130205017274 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from functools import reduce
import sys

from loki.batch.configure import SchedulerConfig, ItemConfig
from loki.expression import (
    TypedSymbol, MetaSymbol, ProcedureSymbol, Variable
)
from loki.frontend import REGEX, RegexParserClass
from loki.ir import (
    Import, CallStatement, TypeDef, ProcedureDeclaration, Interface,
    FindNodes, FindInlineCalls
)
from loki.logging import debug, warning
from loki.module import Module
from loki.subroutine import Subroutine
from loki.tools import as_tuple, flatten, CaseInsensitiveDict
from loki.types import DerivedType


__all__ = [
    'get_all_import_map', 'Item', 'FileItem', 'ModuleItem', 'ProcedureItem',
    'TypeDefItem', 'InterfaceItem', 'ProcedureBindingItem', 'ExternalItem'
]


def get_all_import_map(scope):
    """
    Map of imported symbol names to the :any:`Import` nodes that provide them,
    for :data:`scope` and any parent scopes

    For imports that shadow imports in a parent scope, the innermost import
    takes precedence.

    For imports with an explicit symbol list, the keys are the names of those
    symbols. For imports without one, the keys are the local names taken from
    the ``(use, local)`` pairs of the import's ``rename_list``.

    Parameters
    ----------
    scope : :any:`Scope`
        The scope for which the import map is built. ``None`` yields an empty map.

    Returns
    -------
    CaseInsensitiveDict
        Mapping of local symbol name to the :any:`Import` node providing it
    """
    # Collect imports from the innermost scope outwards. Normalise each scope's
    # imports to a tuple so that list- and tuple-valued ``imports`` attributes
    # (or missing ones) can be concatenated safely.
    imports = ()
    while scope is not None:
        imports += tuple(getattr(scope, 'imports', None) or ())
        scope = getattr(scope, 'parent', None)

    # Build the map outermost-first, so that entries from inner scopes
    # overwrite (i.e., shadow) those from enclosing scopes
    return CaseInsensitiveDict(
        (s.name, imprt)
        for imprt in reversed(imports)
        for s in imprt.symbols or [rename[1] for rename in imprt.rename_list or ()]
    )


class Item(ItemConfig):
    """
    Base class of a work item in the :any:`Scheduler` graph, to which
    a :any:`Transformation` can be applied.

    The :any:`Scheduler` builds a dependency graph consisting of :any:`Item`
    instances as nodes.

    The :attr:`name` of a :class:`Item` refers to the corresponding routine's,
    interface's or type's name using a fully-qualified name in the format
    ``<scope_name>#<local_name>``. The ``<scope_name>`` corresponds to a Fortran
    module that, e.g., a subroutine is declared in, or can be empty if the
    subroutine is not enclosed in a module (i.e. exists in the global scope).
    This is to enable use of routines with the same name that are declared in
    different modules.
    The corresponding parts of the name can be accessed via :attr:`scope_name` and
    :attr:`local_name`.

    For type-bound procedures, the :attr:`local_name` should take the format
    ``<type_name>%<binding_name>``. This may also span across multiple derived types,
    e.g., to allow calls to type bound procedures of a derived type variable that in
    turn is a member of another derived type, e.g.,
    ``<type_name>%<member_name>%<binding_name>``.
    See :class:`ProcedureBindingItem` for more details.

    Relation to Loki IR
    -------------------

    Every :any:`Item` corresponds to a specific node in Loki's internal representation.

    For most cases these IR nodes are scopes:

    * :any:`FileItem`: corresponding to :any:`Sourcefile`
    * :any:`ModuleItem`: corresponding to :any:`Module`
    * :any:`ProcedureItem`: corresponding to :any:`Subroutine`

    The remaining cases are items corresponding to IR nodes that constitute some
    form of intermediate dependency, which are required to resolve the indirection
    to the scope node:

    * :any:`InterfaceItem`: corresponding to :any:`Interface`
    * :any:`TypeDefItem`: corresponding to :any:`TypeDef`
    * :any:`ProcedureBindingItem`: corresponding to the :any:`ProcedureSymbol`
      that is declared in a :any:`ProcedureDeclaration` in a derived type.

    The IR node corresponding to an item can be obtained via the :attr:`ir` property.

    Definitions and dependencies of items
    -------------------------------------

    Each item exhibits two key properties:

    * :attr:`definitions`: A list of all IR nodes that constitute symbols/names
      that are made available by an item. For a :any:`FileItem`, this typically consists
      of all modules and procedures in that sourcefile, and for a :any:`ModuleItem` it
      comprises procedures, interfaces and derived type definitions (module-level
      global variables are not represented by items).
    * :attr:`dependencies`: A list of all IR nodes that introduce a dependency
      on other items, e.g., :any:`CallStatement` or :any:`Import`.

    Item config
    -----------

    Every item has a bespoke configuration derived from the default values in
    :any:`SchedulerConfig`. The schema and accessible attributes are defined in the
    base class :any:`ItemConfig`.

    Attributes
    ----------

    _parser_class : tuple of :any:`RegexParserClass` or None
        The parser classes that need to be active during a parse with the :any:`REGEX`
        frontend to create the IR nodes corresponding to the item type. This
        class attribute is specified by every class implementing a specific item
        type.
    _defines_items : tuple of subclasses of :any:`Item`
        The types of items that definitions of the item may create. This class
        attribute is specified by every class implementing a specific item type.
    _depends_class : tuple of :any:`RegexParserClass` or None
        The parser classes that need to be active during a parse with the :any:`REGEX`
        frontend to create the IR nodes that constitute dependencies in this
        item type. This class attribute is specified by every class implementing
        a specific item type.
    source : :any:`Sourcefile`
        The sourcefile object in which the IR node corresponding to this item is defined.
        The :attr:`ir` property will look-up and yield the IR node in this source file.
    trafo_data : :any:`dict`
        Container object for analysis passes to store analysis data. This can be used
        in subsequent transformation passes.
    plan_data : :any:`dict`
        Container object for plan dry-run passes to store information about
        additional and removed dependencies.

    Parameters
    ----------
    name : str
        Name to identify items in the schedulers graph
    source : :any:`Sourcefile`
        The underlying source file that contains the associated item
    config : dict
        Dict of item-specific config options, see :any:`ItemConfig`
    """

    _parser_class = None
    _defines_items = ()
    _depends_class = None

    def __init__(self, name, source, config=None):
        self.name = name
        self.source = source
        self.trafo_data = {}
        self.plan_data = {}
        super().__init__(config)

    def __repr__(self):
        return f'loki.batch.{self.__class__.__name__}<{self.name}>'

    def __eq__(self, other):
        """
        :class:`Item` objects are considered equal if they refer to the same
        fully-qualified name, i.e., :attr:`name` is identical (case-insensitive)

        This allows also comparison against a string.
        """
        if isinstance(other, Item):
            return self.name.lower() == other.name.lower()
        if isinstance(other, str):
            return self.name.lower() == other.lower()
        return super().__eq__(other)

    def __hash__(self):
        # Equality is case-insensitive (see __eq__), so the hash must be computed
        # from the lower-cased name to keep items that compare equal hashing equally
        return hash(self.name.lower())

    @property
    def definitions(self):
        """
        Return a tuple of the IR nodes this item defines

        By default, this returns an empty tuple and is overwritten by
        derived classes.
        """
        return ()

    @property
    def dependencies(self):
        """
        Return a tuple of IR nodes that constitute dependencies for this item

        This calls :meth:`concretize_dependencies` to trigger a further parse
        with the :any:`REGEX` frontend, including the :attr:`_depends_class` of
        the item. The list of actual dependencies is defined via :meth:`_dependencies`,
        which is overwritten by derived classes.
        """
        self.concretize_dependencies()
        return self._dependencies

    @property
    def _dependencies(self):
        """
        Return a tuple of the IR nodes that constitute dependencies for this item

        This method is used by :attr:`dependencies` to determine the actual
        dependencies after calling :meth:`concretize_dependencies`.

        By default, this returns an empty tuple and is overwritten by
        derived classes.
        """
        return ()

    @property
    def ir(self):
        """
        Return the IR :any:`Node` that the item represents
        """
        return self.source[self.local_name]

    @property
    def scope_ir(self):
        """
        Return the nearest :any:`Scope` IR node that this item either defines
        or is embedded into.
        """
        return self.ir

    transformation_ir = scope_ir
    """
    Return the nearest :any:`Scope` IR node that corresponds to a suitable
    transformation entry point, i.e., is a :any:`ProgramUnit` or :any:`Sourcefile`.

    For most but not all item types this is equivalent to ``scope_ir``.
    """

    def _parser_classes_from_item_type_names(self, item_type_names):
        """
        Helper method that queries the :attr:`Item._parser_class` of all
        :any:`Item` subclasses listed in :data:`item_type_names`

        The item classes are looked up by name in this module and their parser
        classes are combined with ``|``, starting from
        :attr:`RegexParserClass.EmptyClass`.
        """
        item_types = [getattr(sys.modules[__name__], name) for name in item_type_names]
        parser_classes = [p for item_type in item_types if (p := item_type._parser_class) is not None]
        return reduce(lambda x, y: x | y, parser_classes, RegexParserClass.EmptyClass)

    def concretize_definitions(self):
        """
        Trigger a re-parse of the source file corresponding to the current item's scope

        This uses :meth:`_parser_classes_from_item_type_names` to determine all
        :any:`RegexParserClass` that the item's definitions require to be parsed.
        An item's definition classes are listed in :attr:`_defines_items`.
        """
        parser_classes = self._parser_classes_from_item_type_names(self._defines_items)
        if parser_classes and hasattr(self.ir, 'make_complete'):
            self.ir.make_complete(frontend=REGEX, parser_classes=parser_classes)

    def concretize_dependencies(self):
        """
        Trigger a re-parse of the source file corresponding to the current item's scope

        This uses :attr:`_depends_class` to determine all :any:`RegexParserClass` that
        are required to be parsed to find the item's dependencies. The re-parse is
        applied to the outermost parent of :attr:`scope_ir`.
        """
        if not self._depends_class:
            return
        scope = self.scope_ir
        if not scope:
            debug('concretize_dependencies: No scope IR for %s', self.name)
            return
        # Walk up to the outermost enclosing scope
        while scope.parent:
            scope = scope.parent
        if hasattr(scope, 'make_complete'):
            scope.make_complete(frontend=REGEX, parser_classes=self._depends_class)

    def create_definition_items(self, item_factory, config=None, only=None):
        """
        Create the :any:`Item` nodes corresponding to the definitions in the
        current item

        Parameters
        ----------
        item_factory : :any:`ItemFactory`
            The :any:`ItemFactory` to use when creating the items
        config : :any:`SchedulerConfig`, optional
            The scheduler config to use when instantiating new items
        only : list of :any:`Item` classes
            Filter the generated items to include only those provided in the list

        Returns
        -------
        tuple
            The list of :any:`Item` nodes
        """
        if definitions := self.definitions:
            scope_ir = self.scope_ir
            items = as_tuple(flatten(
                item_factory.create_from_ir(node, scope_ir, config)
                for node in definitions
            ))
            items = as_tuple(item for item in items if item is not None)
        else:
            items = ()
        if only:
            items = tuple(item for item in items if isinstance(item, only))
        return items

    def create_dependency_items(self, item_factory, config=None, only=None):
        """
        Create the :any:`Item` nodes corresponding to the dependencies of the
        current item

        Dependencies added during a plan dry-run (``plan_data['additional_dependencies']``)
        are included, items matching the ``disable`` list or listed in
        ``plan_data['removed_dependencies']`` are dropped, and duplicates are removed
        while preserving order.

        Parameters
        ----------
        item_factory : :any:`ItemFactory`
            The :any:`ItemFactory` to use when creating the items
        config : :any:`SchedulerConfig`, optional
            The scheduler config to use when instantiating new items
        only : list of :any:`Item` classes
            Filter the generated items to include only those provided in the list

        Returns
        -------
        tuple
            The list of :any:`Item` nodes
        """
        ignore = [*self.disable, *self.block]
        items = as_tuple(self.plan_data.get('additional_dependencies'))
        if (dependencies := self.dependencies):
            scope_ir = self.scope_ir
            items += tuple(
                item
                for node in dependencies
                for item in as_tuple(item_factory.create_from_ir(node, scope_ir, config, ignore=ignore))
                if item is not None
            )
        if self.disable:
            items = tuple(
                item for item in items
                if not SchedulerConfig.match_item_keys(item.name, self.disable)
            )
        if (removed_dependencies := self.plan_data.get('removed_dependencies')):
            items = tuple(item for item in items if item not in removed_dependencies)

        if only:
            items = tuple(item for item in items if isinstance(item, only))
        return tuple(dict.fromkeys(items))

    @property
    def scope_name(self):
        """
        The name of this item's scope

        This is the part of :attr:`name` before the first ``#``, or ``None``
        if the name contains no ``#``.
        """
        pos = self.name.find('#')
        if pos == -1:
            return None
        return self.name[:pos]

    @property
    def local_name(self):
        """
        The item name without the scope
        """
        return self.name[self.name.find('#')+1:]

    @property
    def scope(self):
        """
        IR object that is the enclosing scope of this :any:`Item`

        The scope is looked up in :attr:`source` by :attr:`scope_name` on every
        access; it is not cached. NOTE(review): the lookup is purely by name, so
        renaming the associated :any:`Module` (e.g. via the
        :any:`DependencyTransformation`) after the item was created may affect
        the result - confirm against callers.

        Returns
        -------
        :any:`Module` or `NoneType`
            ``None`` if the item name has no scope part (no ``#`` separator)
        """
        name = self.scope_name
        if name is None:
            return None
        return self.source[name]

    @property
    def calls(self):
        """
        Return a tuple of local names of subroutines that are called

        This will replace the object name by the type name for calls to
        typebound procedures, but not resolve potential renaming via imports.
        Names are returned in lower case.
        """
        calls = tuple(call for call in self.dependencies if isinstance(call, CallStatement))
        calls = tuple(
            f'{call.name.parents[0].type.dtype.name}{call_name[call_name.index("%"):]}'
            if '%' in (call_name := str(call.name).lower()) else call_name
            for call in calls
        )
        return calls

    @property
    def targets(self):
        """
        Set of "active" child dependencies that are part of the transformation
        traversal.

        This includes all child dependencies of an item that will be
        traversed when applying a :any:`Transformation`, after tree pruning rules
        are applied but without taking item filters into account.

        This means, all items excluded via ``block`` or ``disable`` lists in the
        :any:`SchedulerConfig` are not listed here. However, this will include
        items in the ``ignore`` list, which are not processed but are treated
        as if they were.

        Returns
        -------
        list of str
        """
        # Determine an exclusion list
        exclude = as_tuple(str(t).lower() for t in self.disable)
        exclude += as_tuple(str(t).lower() for t in self.block)
        return self._get_children(exclude=exclude)

    @property
    def targets_and_blocked_targets(self):
        """
        Set of all child dependencies, including those that are not part of the
        traversal, but ignoring ``disabled`` dependencies.

        This includes all child dependencies that are returned by :attr:`Item.targets`
        as well as any that are excluded via the :attr:`ItemConfig.block` list.

        This means, only items excluded via ``disable`` lists in the
        :any:`SchedulerConfig` are not listed here. However, it will include
        items in the ``ignore`` and ``block`` list.

        Returns
        -------
        list of str
        """
        # Determine an exclusion list
        exclude = as_tuple(str(t).lower() for t in self.disable)
        return self._get_children(exclude=exclude)

    def _get_children(self, exclude=None):
        """
        Helper method that returns a list of child dependency names

        This takes :attr:`Item.dependencies` and translates the dependency nodes
        to their name, excluding any dependencies that match the exclusion
        list given in :data:`exclude`.

        This method is used by :attr:`targets` and :attr:`targets_and_blocked_targets`.

        Raises
        ------
        ValueError
            If a dependency node of an unexpected type is encountered
        """
        exclude = as_tuple(exclude)

        # Determine all potential targets from dependencies and filter out excluded targets
        if not (dependencies := self.dependencies):
            return ()

        def _add_new_child(name, is_excluded, child_exclusion_map):
            # Helper utility to add or update an entry; once a name has been
            # marked as excluded it stays excluded
            child_exclusion_map[name] = child_exclusion_map.get(name, False) or is_excluded

        child_exclusion_map = CaseInsensitiveDict()
        import_map = get_all_import_map(self.scope_ir)
        for dependency in dependencies:
            if isinstance(dependency, Import):
                # Exclude all imported symbols if the module is excluded, otherwise
                # exclude only individual imported symbols as required
                is_excluded = self.match_symbol_or_name(dependency.module, exclude)
                _add_new_child(dependency.module, is_excluded, child_exclusion_map)
                for symbol in dependency.symbols or ():
                    is_symbol_excluded = (
                        is_excluded or symbol.type.parameter or
                        self.match_symbol_or_name(symbol, exclude, scope=dependency.module)
                    )
                    _add_new_child(symbol.name, is_symbol_excluded, child_exclusion_map)

            elif isinstance(dependency, Interface):
                for symbol in dependency.symbols:
                    if symbol.name in import_map:
                        scope = import_map[symbol.name].module
                    else:
                        scope = self.scope_name
                    _add_new_child(
                        symbol.name,
                        self.match_symbol_or_name(symbol, exclude, scope=scope),
                        child_exclusion_map
                    )

            elif isinstance(dependency, TypeDef):
                if dependency.name in import_map:
                    scope = import_map[dependency.name].module
                else:
                    scope = self.scope_name
                _add_new_child(
                    dependency.name,
                    self.match_symbol_or_name(dependency.name, exclude, scope=scope),
                    child_exclusion_map
                )

            elif isinstance(dependency, (Subroutine, CallStatement, MetaSymbol, TypedSymbol)):
                # Treating these together to avoid duplicating the control flow
                # for symbol matching
                if isinstance(dependency, CallStatement):
                    symbol = dependency.name
                elif isinstance(dependency, Subroutine):
                    symbol = dependency.procedure_symbol
                else:
                    symbol = dependency

                if '%' in symbol.name:
                    # We check both:
                    # the (potentially imported) type name via the call relative to
                    # the type name, and the (potentially imported) declared symbol itself
                    type_name = symbol.parents[0].type.dtype.name
                    call_name = f'{type_name}{symbol.name[symbol.name.index("%"):]}'
                    if type_name in import_map:
                        scope = import_map[type_name].module
                    else:
                        scope = self.scope_name
                    is_excluded = self.match_symbol_or_name(call_name, exclude, scope=scope)

                    if (declared_name := symbol.parents[0].name) in import_map:
                        scope = import_map[declared_name].module
                    else:
                        scope = self.scope_name
                    is_excluded = is_excluded or self.match_symbol_or_name(symbol, exclude, scope=scope)

                else:
                    if symbol.name in import_map:
                        scope = import_map[symbol.name].module
                    else:
                        scope = self.scope_name
                    is_excluded = self.match_symbol_or_name(symbol, exclude, scope=scope)

                _add_new_child(symbol.name, is_excluded, child_exclusion_map)
            else:
                raise ValueError(f'Unexpected dependency type {type(dependency)} for {dependency}')

        children = tuple(target for target, excluded in child_exclusion_map.items() if not excluded)
        return children

    @property
    def path(self):
        """
        The filepath of the associated source file
        """
        return self.source.path


class FileItem(Item):
    """
    Item class representing a :any:`Sourcefile`

    The name of this item is typically the file path.

    A :any:`FileItem` does not have any direct dependencies. A dependency
    filegraph can be generated by the :any:`SGraph` class using dependencies
    of items defined by nodes inside the file.

    A :any:`FileItem` defines :any:`ModuleItem` and :any:`ProcedureItem` nodes.
    """

    # We do not need to parse anything inside the file for this item type
    _parser_class = None

    # Modules and Procedures can appear in a sourcefile
    _defines_items = ('ModuleItem', 'ProcedureItem')

    @property
    def definitions(self):
        """
        Return the list of definitions in this source file

        This triggers :meth:`concretize_definitions` and returns the nodes
        exposed by the :any:`Sourcefile`'s own ``definitions`` (presumably the
        program units at file level). Definitions nested inside a module are
        not added here; they are exposed by the corresponding :any:`ModuleItem`.
        """
        self.concretize_definitions()
        return self.ir.definitions

    @property
    def ir(self):
        """
        Return the :any:`Sourcefile` associated with this item
        """
        return self.source

    def create_definition_items(self, item_factory, config=None, only=None):
        """
        Create the :any:`Item` nodes corresponding to the definitions in the file

        This overwrites the corresponding method in the base class to enable
        instantiating the top-level scopes in the file item without them being
        available in the :any:`ItemFactory.item_cache`, yet.

        Parameters
        ----------
        item_factory : :any:`ItemFactory`
            The :any:`ItemFactory` to use when creating the items
        config : :any:`SchedulerConfig`, optional
            The scheduler config to use when instantiating new items
        only : list of :any:`Item` classes
            Filter the generated items to include only those provided in the list

        Returns
        -------
        tuple
            The list of :any:`Item` nodes
        """
        items = ()
        for node in self.definitions:
            if isinstance(node, Module):
                items += as_tuple(
                    item_factory.get_or_create_item(ModuleItem, node.name.lower(), self.name, config)
                )
            elif isinstance(node, Subroutine) and not node.parent:
                items += as_tuple(
                    item_factory.get_or_create_item(ProcedureItem, f'#{node.name.lower()}', self.name, config)
                )
            else:
                # Normalise the factory result with as_tuple, like the other branches
                # and the base class implementation, before concatenating
                items += as_tuple(item_factory.create_from_ir(node, self.scope_ir, config))
        items = as_tuple(item for item in items if item is not None)
        if only:
            items = tuple(item for item in items if isinstance(item, only))
        return items


class ModuleItem(Item):
    """
    Item class representing a :any:`Module`

    For this item type the scope name, the local name and the full name all
    coincide with the module's name.

    The definitions of a :any:`ModuleItem` give rise to :any:`ProcedureItem`,
    :any:`InterfaceItem` and :any:`TypeDefItem` nodes. Module-level global
    variables, although importable into other scopes, have no item of their own.

    The only dependencies a :any:`ModuleItem` can have are on other
    :any:`ModuleItem` nodes, introduced via :any:`Import` statements.
    """

    _parser_class = RegexParserClass.ProgramUnitClass
    _defines_items = ('ProcedureItem', 'InterfaceItem', 'TypeDefItem')
    _depends_class = RegexParserClass.ImportClass

    @property
    def definitions(self):
        """
        Return the procedures, derived types and interfaces defined in this
        module (global variables are not part of the result).
        """
        self.concretize_definitions()
        module = self.ir
        routines = module.subroutines
        spec_nodes = FindNodes((TypeDef, Interface)).visit(module.spec)
        return routines + as_tuple(spec_nodes)

    @property
    def _dependencies(self):
        """
        Return the :any:`Import` nodes of this module that constitute dependencies,
        skipping C imports and imports of intrinsic modules.
        """
        def _is_dependency(imprt):
            # C imports and intrinsic modules do not correspond to any item
            return not (imprt.c_import or str(imprt.nature).lower() == 'intrinsic')

        return tuple(filter(_is_dependency, self.ir.imports))

    @property
    def local_name(self):
        """
        Return the module's name, which is identical to the item name
        """
        module_name = self.name
        return module_name


class ProcedureItem(Item):
    """
    Item class representing a :any:`Subroutine`

    The name of this item is comprised of the scope's name in which the procedure
    is declared, i.e., the enclosing module, and the procedure name:
    ``<scope_name>#<procedure_name>``. For procedures that are not declared inside
    a module, the ``<scope_name>`` is an empty string, i.e., the item name becomes
    ``#<procedure_name>``.

    A :any:`ProcedureItem` does not define any child items.

    Dependencies of a :any:`ProcedureItem` can be introduced by

    * imports, i.e., a dependency on :any:`ModuleItem`,
    * the use of derived types, i.e., a dependency on :any:`TypeDefItem`,
    * calls to other procedures, i.e., a dependency on :any:`ProcedureItem` or,
      as an indirection, on :any:`InterfaceItem` or :any:`ProcedureBindingItem`.
    """

    _parser_class = RegexParserClass.ProgramUnitClass
    _depends_class = (
        RegexParserClass.ImportClass | RegexParserClass.InterfaceClass | RegexParserClass.TypeDefClass |
        RegexParserClass.DeclarationClass | RegexParserClass.CallClass | RegexParserClass.PragmaClass
    )

    @property
    def _dependencies(self):
        """
        Return the list of :any:`Import`, :any:`Interface`, :any:`TypeDef`,
        :any:`CallStatement`, and :any:`ProcedureSymbol` (to represent
        calls to functions) nodes that constitute dependencies of this item.

        Calls and inline function calls are de-duplicated by name, calls to
        internal procedures are dropped if ``ignore_internal_procedures`` is set,
        and C imports and imports of intrinsic modules are skipped. Derived types
        used in variable declarations add the matching :any:`TypeDef` or
        :any:`Import` of the enclosing scope, each at most once.
        """
        self_ir = self.ir
        if not self_ir:
            debug('Failed to resolve IR for item %s - cannot compile dependencies', self.name)
            return ()

        # One CallStatement per called name (first-occurrence order, last node wins)
        calls = tuple({call.name.name: call for call in FindNodes(CallStatement).visit(self_ir.ir)}.values())
        internal_procedures = [routine.name.lower() for routine in self_ir.routines]
        inline_calls = tuple({
            call.function.name: call.function
            for call in FindInlineCalls().visit(self_ir.ir)
            if isinstance(call.function, ProcedureSymbol) and not call.function.type.is_intrinsic
        }.values())
        if internal_procedures and self.ignore_internal_procedures:
            calls = tuple(call for call in calls if call.name.name.lower() not in internal_procedures)
            inline_calls = tuple(func for func in inline_calls
                                 if func.name.lower() not in internal_procedures)
        imports = tuple(
            imprt for imprt in self_ir.imports
            if not imprt.c_import and str(imprt.nature).lower() != 'intrinsic'
        )
        interfaces = self_ir.interfaces
        typedefs = ()

        # Create dependencies on type definitions that may have been declared in or
        # imported via the module scope
        if (scope := self.scope):
            # Each derived type only needs to be resolved once, even if several
            # variables are declared with it
            type_names = tuple(dict.fromkeys(
                dtype.name for var in self_ir.variables
                if isinstance((dtype := var.type.dtype), DerivedType)
            ))
            if type_names:
                typedef_map = scope.typedef_map
                import_map = scope.import_map
                typedefs += tuple(dict.fromkeys(
                    typedef for type_name in type_names if (typedef := typedef_map.get(type_name))
                ))
                # Several types may come from the same import statement
                imports += tuple(dict.fromkeys(
                    imprt for type_name in type_names if (imprt := import_map.get(type_name))
                ))
        return imports + interfaces + typedefs + calls + inline_calls

    @property
    def ir(self):
        """
        Return the :any:`Subroutine` IR node this item corresponds to

        For internal procedures the local name contains a second ``#``
        (``<parent>#<internal>``), which requires a two-stage lookup.
        """
        local_name = self.local_name
        if (sep := local_name.find('#')) != -1:
            return self.source[local_name[:sep]][local_name[sep+1:]]
        return self.source[local_name]


class TypeDefItem(Item):
    """
    Item class representing a :any:`TypeDef`

    The name of this item is comprised of the scope's name in which the derived type
    is declared, i.e., the enclosing module, and the type name:
    ``<scope_name>#<type_name>``.

    A :any:`TypeDefItem` defines :any:`ProcedureBindingItem`.

    Dependencies of a :any:`TypeDefItem` are introduced via

    * the use of derived types in declarations of members, i.e., a dependency
      on :any:`TypeDefItem`,
    * imports of derived types, i.e., a dependency on :any:`ModuleItem`.
    """

    _parser_class = RegexParserClass.TypeDefClass
    _defines_items = ('ProcedureBindingItem',)
    _depends_class = RegexParserClass.DeclarationClass

    @property
    def _dependencies(self):
        """
        Return the list of :any:`Import` and :any:`TypeDef` nodes that this item
        depends upon.

        Dependencies are restricted to derived types used in member declarations.
        They are resolved via the typedef and import maps of the enclosing
        :attr:`scope`; import nodes are trimmed to the relevant type names and the
        result is de-duplicated. Without an enclosing scope there is nothing to
        resolve against and an empty tuple is returned.
        """
        imports = ()
        typedefs = ()

        type_names = [
            dtype.name for var in self.ir.variables
            if isinstance((dtype := var.type.dtype), DerivedType)
        ]
        if type_names and (scope := self.scope):
            typedef_map = scope.typedef_map
            import_map = scope.import_map
            typedefs = tuple(typedef for type_name in type_names if (typedef := typedef_map.get(type_name)))
            imports = tuple(imprt for type_name in type_names if (imprt := import_map.get(type_name)))

            def _trim_import_symbol_list(imprt):
                # Trim the import symbol list to relevant names only
                symbols = tuple(symbol for symbol in imprt.symbols if symbol in type_names)
                if symbols == imprt.symbols:
                    return imprt
                return imprt.clone(symbols=symbols)

            imports = tuple(_trim_import_symbol_list(imprt) for imprt in imports)

        return tuple(dict.fromkeys(imports + typedefs))

    @property
    def definitions(self):
        """
        Return the list of :any:`ProcedureDeclaration` nodes that define
        procedure bindings in this item.
        """
        return tuple(
            decl for decl in self.ir.declarations
            if isinstance(decl, ProcedureDeclaration)
        )

    transformation_ir = Item.scope
    """
    The transformation entry point for TypeDefItem is the scope in which the
    typedef is declared.
    """


class InterfaceItem(Item):
    """
    Item class representing a :any:`Interface` declared in a module

    The name of this item is comprised of the scope's name in which the interface
    is declared, i.e., the enclosing module, and the interface name:
    ``<scope_name>#<interface_name>``.

    A :any:`InterfaceItem` does not define any child items.

    The dependency of an :any:`InterfaceItem` is the procedure it declares,
    i.e., a :any:`ProcedureItem` or another :any:`InterfaceItem`.

    This does not constitute a work item when applying transformations across the
    call tree in the :any:`Scheduler` and is skipped by most transformations during
    the processing phase.
    However, it is necessary to provide the dependency link from calls to procedures
    declared via an interface to their implementation in a Fortran routine.
    """

    _parser_class = RegexParserClass.InterfaceClass

    @property
    def _dependencies(self):
        """
        Return the list of :any:`ProcedureSymbol` this interface declares.

        For every node in the interface body, its ``procedure_symbol`` is used if
        it has one, otherwise its ``symbols``; the results are flattened.
        """
        return as_tuple(flatten(
            getattr(node, 'procedure_symbol', getattr(node, 'symbols', ()))
            for node in self.ir.body
        ))

    @property
    def scope_ir(self):
        """
        Return the :any:`Module` in which the interface is declared.
        """
        return self.scope


class ProcedureBindingItem(Item):
    """
    Item class for a procedure binding of a Fortran derived type

    The item name consists of three parts: the name of the scope that
    declares the derived type, the name of the derived type, and the name of
    the binding, i.e., ``scope_name#type_name%binding_name``.

    For bindings on members of (nested) derived types, the part following the
    derived type name can comprise several ``%``-separated components, e.g.,
    ``scope_name#type_name%member_name%binding_name``.

    A :any:`ProcedureBindingItem` never defines child items.

    Its dependency is the procedure that the binding refers to. This can be a
    :any:`ProcedureItem`, an :any:`InterfaceItem`, or another
    :any:`ProcedureBindingItem` (to resolve generic bindings or calls to
    bindings of members in nested derived types).

    The :any:`Scheduler` does not treat a :any:`ProcedureBindingItem` as a
    work item when applying transformations across the dependency tree, and
    most transformations skip it during processing. However, it is needed to
    link calls to type-bound procedures with their implementation in a
    Fortran routine.
    """

    _parser_class = RegexParserClass.TypeDefClass | RegexParserClass.CallClass
    _depends_class = RegexParserClass.DeclarationClass

    def __init__(self, name, source, config=None):
        # A procedure binding name always contains at least one '%' separator
        assert '%' in name
        super().__init__(name, source, config)

    @property
    def ir(self):
        """
        The :any:`ProcedureSymbol` declared for this binding in the derived type.

        If the derived type cannot be found in :attr:`source`, the enclosing
        scope is first completed with the :any:`REGEX` frontend (using this
        item's ``_parser_class``) and the lookup is repeated.

        Raises
        ------
        RuntimeError
            If the derived type or the binding's declaration cannot be found.
        """
        parts = self.local_name.split('%')
        typedef = self.source[parts[0]]
        if not typedef:
            self.scope.make_complete(frontend=REGEX, parser_classes=self._parser_class)
            typedef = self.source[parts[0]]

        if typedef:
            declared_symbols = (
                symbol for decl in typedef.declarations for symbol in decl.symbols
            )
            for symbol in declared_symbols:
                # Compare names explicitly, as the symbol could be declared with a dimension
                if symbol.name.lower() == parts[1]:
                    return symbol

        raise RuntimeError(f'Declaration for {self.name} not found')

    @property
    def scope_ir(self):
        """
        The :any:`TypeDef` in which this procedure binding appears.

        If the binding's declaration cannot be resolved, the derived type may
        stem from a transient import. In that case, the symbol imported under
        the derived type's name in the enclosing scope is returned instead.
        If this fails as well, a warning is logged and ``None`` is returned.
        """
        try:
            return self.ir.scope
        except RuntimeError as excinfo:
            scope_name = self.name.split('#')[0]
            type_name = self.local_name.split('%')[0]
            scope = self.source[scope_name]
            imported_typedef = scope.imported_symbol_map.get(type_name) if scope else None
            if not imported_typedef:
                warning(excinfo)
                return None
            debug("%s - transient import, using imported symbol", excinfo)
            return imported_typedef

    @property
    def _dependencies(self):
        """
        The :any:`ProcedureSymbol` objects that this binding resolves to.

        * Generic binding: the symbols of all bindings named by the generic.
        * Specific binding with rename: the procedure of the (single) bind name.
        * Specific binding without rename: the object of the same name in
          :attr:`source`.
        * Binding on a derived-type member: a :any:`ProcedureSymbol` whose
          parent chain reproduces the (possibly nested) member path.

        An empty tuple is returned if the binding's declaration cannot be found.
        """
        try:
            symbol = self.ir
        except RuntimeError as excinfo:
            # The binding symbol could not be found; it might originate from an import
            debug("%s - cannot determine dependencies", excinfo)
            return ()

        name_parts = self.local_name.split('%')

        if len(name_parts) != 2:
            # Type-bound procedure on a member: build the (possibly nested)
            # intermediate member symbols first...
            parent = symbol
            member_path = f'{symbol.name}'
            for component in name_parts[2:-1]:
                member_path = f'{member_path}%{component}'
                parent = Variable(name=member_path, parent=parent, scope=parent.scope)
            # ...then instantiate the final symbol explicitly as a ProcedureSymbol
            return as_tuple(ProcedureSymbol(
                name=f'{member_path}%{name_parts[-1]}', parent=parent, scope=parent.scope
            ))

        if symbol.type.dtype.is_generic:
            # Generic binding: resolve every binding named by the generic
            bound_names = as_tuple(symbol.type.bind_names)
            return tuple(symbol.scope.variable_map[str(name)] for name in bound_names)

        if symbol.type.bind_names:
            # Specific binding with rename
            assert len(symbol.type.bind_names) == 1
            return as_tuple(symbol.type.bind_names[0].type.dtype.procedure)

        # Specific binding without rename: the procedure has the binding's name
        return as_tuple(self.source[symbol.name])

    transformation_ir = Item.scope
    """
    The transformation entry point is the scope in which the typedef is declared
    to which this procedure binding is attached.
    """


class ExternalItem(Item):
    """
    Item class for an external dependency that cannot be resolved

    The item's name may either be fully qualified (scope name and local
    name) or consist of a local name only.

    An :any:`ExternalItem` neither defines child items nor depends on other
    items, and the :any:`Scheduler` does not treat it as a work item when
    applying transformations across the call tree. Accessing :attr:`ir`,
    :attr:`scope` or :attr:`path` raises a :any:`RuntimeError`.

    Parameters
    ----------
    origin_cls :
        The :any:`Item` subclass that this item stands in for.
    is_generated : bool
        Flag marking items that are expected to be generated by the build.
    """

    def __init__(self, name, source, config=None, origin_cls=None, is_generated=False):
        self.origin_cls = origin_cls
        self.is_generated = is_generated
        super().__init__(name, source, config)

    def _unavailable(self, attr_name):
        """Create the error reported when ``attr_name`` is accessed on this item"""
        return RuntimeError(f'No .{attr_name} available for ExternalItem `{self.name}`')

    @property
    def ir(self):
        """
        Not available for external items: always raises a :any:`RuntimeError`
        """
        raise self._unavailable('ir')

    @property
    def scope(self):
        """
        Not available for external items: always raises a :any:`RuntimeError`
        """
        raise self._unavailable('scope')

    @property
    def path(self):
        """
        Not available for external items: always raises a :any:`RuntimeError`
        """
        raise self._unavailable('path')
loki-ecmwf-0.3.6/loki/batch/sfilter.py0000664000175000017500000000511215167130205020001 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import networkx as nx

from loki.batch.item import Item, ExternalItem


__all__ = ['SFilter']


class SFilter:
    """
    Filtered iterator over a :any:`SGraph`

    This class changes the iteration behaviour over the dependency graph
    stored in :any:`SGraph`. Items are visited in topological order (or in
    reverse topological order), and only those matching the configured
    filters are returned.

    Example use::

      items = scheduler.items
      reversed_items = as_tuple(SFilter(scheduler.sgraph, reverse=True))
      procedure_bindings = as_tuple(SFilter(scheduler.sgraph, item_filter=ProcedureBindingItem))

    Parameters
    ----------
    sgraph : :any:`SGraph`
        The graph over which to iterate
    item_filter : :any:`Item` subclass, or list/tuple of :any:`Item` subclasses, optional
        Only include items that match the provided type(s). If not given, all
        items match (i.e., the filter is :any:`Item`).
    reverse : bool, optional
        Iterate over the dependency graph in reverse direction
    exclude_ignored : bool, optional
        Exclude :any:`Item` objects that have the ``is_ignored`` property set
    include_external : bool, optional
        Do not skip :any:`ExternalItem` in the iterator. External items are
        matched against :data:`item_filter` via their ``origin_cls``, or as
        :any:`ExternalItem` if no ``origin_cls`` is set.
    """

    def __init__(self, sgraph, item_filter=None, reverse=False, exclude_ignored=False, include_external=False):
        self.sgraph = sgraph
        self.reverse = reverse
        if not item_filter:
            # No filter given: every item type matches
            self.item_filter = Item
        elif isinstance(item_filter, (list, set, frozenset)):
            # issubclass() only accepts a class, a tuple of classes or a union type
            self.item_filter = tuple(item_filter)
        else:
            self.item_filter = item_filter
        self.exclude_ignored = exclude_ignored
        self.include_external = include_external

    def __iter__(self):
        if self.reverse:
            self._iter = iter(reversed(list(nx.topological_sort(self.sgraph._graph))))
        else:
            self._iter = iter(nx.topological_sort(self.sgraph._graph))
        return self

    def __next__(self):
        # Exhaustion of the underlying iterator, not the truthiness of a node,
        # ends the search for the next matching item
        for node in self._iter:
            # Determine the node type but skip externals if applicable
            if isinstance(node, ExternalItem):
                if not self.include_external:
                    continue
                # ``origin_cls`` defaults to None, which issubclass() would reject;
                # fall back to the ExternalItem class itself in that case
                node_cls = node.origin_cls or type(node)
            else:
                node_cls = type(node)
            if issubclass(node_cls, self.item_filter) and not (self.exclude_ignored and node.is_ignored):
                # We found the next item matching the filter (and which is not ignored, if applicable)
                return node
        raise StopIteration
loki-ecmwf-0.3.6/loki/batch/scheduler.py0000664000175000017500000007402115167130205020314 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from enum import Enum, auto
from os.path import commonpath
from pathlib import Path
from codetiming import Timer

from loki.batch.configure import SchedulerConfig
from loki.batch.item import (
    FileItem, ModuleItem, ProcedureItem, ProcedureBindingItem,
    InterfaceItem, TypeDefItem, ExternalItem
)
from loki.batch.item_factory import ItemFactory
from loki.batch.pipeline import Pipeline
from loki.batch.sfilter import SFilter
from loki.batch.sgraph import SGraph
from loki.batch.transformation import Transformation

from loki.frontend import FP, REGEX, RegexParserClass
from loki.tools import as_tuple, CaseInsensitiveDict, flatten

from loki.logging import info, perf, warning, error

__all__ = ['ProcessingStrategy', 'Scheduler']


class ProcessingStrategy(Enum):
    """
    Enumeration of the processing modes supported by :any:`Scheduler.process`

    :any:`Scheduler.process` can apply a :any:`Transformation` or a
    :any:`Pipeline` to the items of a :any:`Scheduler` graph in different
    ways. The members of this class name the permissible modes, and
    :attr:`DEFAULT` selects the mode that is used when none is specified.
    """

    SEQUENCE = 1
    """Apply the transformations of a pipeline one after another

    For every transformation in a pipeline, :any:`Transformation.apply` is
    invoked on all items of the graph (in the traversal order requested by the
    transformation's manifest) before the next transformation is started.
    """

    PLAN = 2
    """Planning mode, following the :any:`ProcessingStrategy.SEQUENCE` strategy

    Invokes :any:`Transformation.plan` instead of :any:`Transformation.apply`
    for every transformation.
    """

    DEFAULT = SEQUENCE
    """Strategy used when none is specified, i.e., :any:`ProcessingStrategy.SEQUENCE`"""


class Scheduler:
    """
    Work queue manager to discover and capture dependencies for a given
    call tree, and apply transformations for batch processing

    Using a given list of :data:`paths` and :data:`seed_routines` (can be
    inferred from :data:`config`), all scopes and symbols defined in the
    source tree are discovered and a dependency graph is created. The nodes
    of the dependency graph are :any:`Item` objects, each corresponding to
    a specific Loki IR node.

    The dependency graph is re-generated on-the-fly and can be filtered for
    specific dependency classes during traversals (see :any:`SFilter`).
    All items are stored in the cache of the associated :attr:`item_factory`.

    Under the hood, the Scheduler is initialised in a three-stage procedure:

    1. A `discovery` step, where the minimum top-level definitions (modules
       and procedures) are determined for every source file in the path list.
    2. A `populate` step, which instantiates a first :any:`SGraph` dependency
       graph by chasing dependencies starting from the provided seed nodes.
    3. Optionally, a full parse is triggered for all :any:`Sourcefile` that
       appear in an :any:`Item` in the dependency graph.

    The first two stages rely on an incremental, incomplete parsing of the
    source files that extract only the minimum set of symbols in each file.
    This is driven by the :any:`Item._parser_class` and :any:`Item._depends_class`
    attributes, which declare the minimum :any:`RegexParserClass` classes to
    use with the :any:`REGEX` frontend.

    To discover dependencies, the item's IR is `concretized`. This calls
    ``Scope.make_complete`` with the minimum set of additional parser classes
    (as defined in the ``_parser_class`` attribute) that are required to discover
    the dependencies (e.g., calls, imports).
    When creating the corresponding dependency's item, the defining scope's item
    (e.g., a module containing a derived type declaration) is queried for its
    ``definitions``, which in turn may trigger also a `concretize` step with the
    ``_parser_class`` of all item types that are listed in the scope item's
    ``defines_items`` attribute.

    A :any:`Transformation` can be applied across all nodes in the dependency
    graph using the :meth:`process` method. The class properties in the
    transformation implementation (such as :any:`Transformation.reverse_traversal`,
    :any:`Transformation.traverse_file_graph` or :any:`Transformation.item_filter`)
    determine which nodes should be processed.

    Attributes
    ----------
    config : :any:`SchedulerConfig`
        The config object describing the Scheduler's behaviour
    full_parse : bool
        Flag to indicate a full parse of scheduler items
    paths : list of :any:`pathlib.Path`
        List of paths where sourcefiles are searched
    seeds : list of str
        Names of seed routines that are the root of dependency graphs
    build_args : dict
        List of frontend arguments that are given to :any:`Sourcefile.from_file`
        when performing a full parse
    item_factory : :any:`ItemFactory`
        Instance of the factory class for :any:`Item` creation and caching

    Parameters
    ----------
    paths : str or list of str
        List of paths to search for automated source file detection.
    config : dict or str, optional
        Configuration dict or path to scheduler configuration file
    seed_routines : list of str, optional
        Names of routines from which to populate the callgraph initially.
        If not provided, these will be inferred from the given config.
    preprocess : bool, optional
        Flag to trigger CPP preprocessing (by default `False`).
    includes : list of str, optional
        Include paths to pass to the C-preprocessor.
    defines : list of str, optional
        Symbol definitions to pass to the C-preprocessor.
    definitions : list of :any:`Module`, optional
        :any:`Module` object(s) that may supply external type or procedure
        definitions.
    xmods : str, optional
        Path to directory to find and store ``.xmod`` files when using
        the OMNI frontend.
    omni_includes: list of str, optional
        Additional include paths to pass to the preprocessor run as part of
        the OMNI frontend parse. If set, this **replaces** (!)
        :data:`includes`, otherwise :data:`omni_includes` defaults to the
        value of :data:`includes`.
    full_parse: bool, optional
        Flag indicating whether a full parse of all sourcefiles is required.
        By default a full parse is executed; use this flag to suppress it.
    frontend : :any:`Frontend`, optional
        Frontend to use for full parse of source files (default :any:`FP`).
    output_dir : str or path
        Directory for the output to be written to
    """

    # TODO: Should be user-definable!
    source_suffixes = ['.f90', '.F90', '.f', '.F']

    def __init__(self, paths, config=None, seed_routines=None, preprocess=False,
                 includes=None, defines=None, definitions=None, xmods=None,
                 omni_includes=None, full_parse=True, frontend=FP, output_dir=None):
        # Accept a ready-made SchedulerConfig, a path to a config file,
        # or a (possibly empty) config dict
        if isinstance(config, SchedulerConfig):
            scheduler_config = config
        elif isinstance(config, (str, Path)):
            scheduler_config = SchedulerConfig.from_file(config)
        else:
            scheduler_config = SchedulerConfig.from_dict(config or {})
        self.config = scheduler_config

        self.full_parse = full_parse

        # Source search paths and lower-cased seed names; without explicit
        # seed routines, the routine names from the config are used
        self.paths = list(map(Path, as_tuple(paths)))
        seed_names = as_tuple(seed_routines) or self.config.routines.keys()
        self.seeds = tuple(name.lower() for name in seed_names)

        # All build arguments that are passed on to `Sourcefile` constructors
        self.build_args = dict(
            definitions=definitions,
            preprocess=preprocess,
            includes=includes,
            defines=defines,
            xmods=xmods,
            omni_includes=omni_includes,
            frontend=frontend,
            output_dir=output_dir,
        )

        # The item factory caches all items of the callgraph
        self.item_factory = ItemFactory()

        self._discover()

        if not self.full_parse:
            return

        self._parse_items()
        # Attach interprocedural call-tree information
        self._enrich()

    @Timer(logger=info, text='[Loki::Scheduler] Performed initial source scan in {:.2f}s')
    def _discover(self):
        """
        Scan all source paths and create light-weight :any:`Sourcefile` objects for each file

        Directories in :attr:`paths` are searched recursively for files with any of the
        :attr:`source_suffixes`, while other entries are used as given. Files are scanned
        with the :any:`REGEX` frontend for top-level program units only. Afterwards,
        definition items are created for every :any:`FileItem` in the item cache
        (including those created by transformations), and the :any:`SGraph` is rebuilt
        from :attr:`seeds`.
        """
        frontend_args = {
            'preprocess': self.build_args['preprocess'],
            'includes': self.build_args['includes'],
            'defines': self.build_args['defines'],
            'parser_classes': RegexParserClass.ProgramUnitClass,
            'frontend': REGEX
        }

        # Create a list of initial files to scan with the fast REGEX frontend
        path_list = [
            path.glob(f'**/*{ext}') if path.is_dir() else path
            for path in self.paths for ext in self.source_suffixes
        ]
        # Filter duplicates and flatten, then sort: iterating a set of paths yields an
        # order that varies between runs (string hash randomisation), which would make
        # the order of item creation non-reproducible
        path_list = sorted(set(flatten(path_list)))

        # Instantiate FileItem instances for all files in the search path
        for path in path_list:
            self.item_factory.get_or_create_file_item_from_path(path, self.config, frontend_args)

        # Instantiate the basic list of items for files and top-level program units
        #  in each file, i.e., modules and subroutines
        #  Note that we do this separate from the FileItem instantiation above to enable discovery
        #  also for FileItems that have been created as part of a transformation
        file_items = [
            file_item
            for file_item in self.item_factory.item_cache.values()
            if isinstance(file_item, FileItem)
        ]
        for file_item in file_items:
            definition_items = {
                item.name: item
                for item in file_item.create_definition_items(item_factory=self.item_factory, config=self.config)
            }
            self.item_factory.item_cache.update(definition_items)

        # (Re-)build the SGraph after discovery for later traversals
        self._sgraph = SGraph.from_seed(self.seeds, self.item_factory, self.config)

    @property
    def sgraph(self):
        """
        The :any:`SGraph` constructed from the :attr:`seeds` of the Scheduler.

        The graph is (re-)built during discovery and after a full parse; this
        property only returns the stored instance and does not create a new one.
        """
        return self._sgraph

    @property
    def items(self):
        """
        Every :any:`Item` in the dependency graph of the :any:`Scheduler`.

        This is forwarded from :attr:`sgraph`.
        """
        graph = self.sgraph
        return graph.items

    @property
    def dependencies(self):
        """
        All individual pairs of :any:`Item` that represent a dependency
        and form an edge in the :any:`Scheduler` call graph.

        This is forwarded from :attr:`sgraph`.
        """
        return self.sgraph.dependencies

    @property
    def definitions(self):
        """
        The definitions provided by the source files in the callgraph.

        Returns
        -------
        tuple
            The concatenated ``definitions`` of all items in :attr:`file_graph`.
        """
        collected = []
        for file_item in self.file_graph:
            collected.extend(file_item.definitions)
        return tuple(collected)

    @property
    def file_graph(self):
        """
        Alternative dependency graph based on relations between source files

        Unless imports are enabled in the config, only files that contain a
        :any:`ProcedureItem` of the dependency graph are considered.

        Returns
        -------
        :any:`SGraph`
            A dependency graph containing only :any:`FileItem` nodes
        """
        if self.config.enable_imports:
            item_filter = None
        else:
            item_filter = ProcedureItem
        graph = self.sgraph
        return graph.as_filegraph(self.item_factory, self.config, item_filter=item_filter)

    def __getitem__(self, name):
        """
        Return the first item in the Scheduler's dependency graph that compares
        equal to :data:`name`, or ``None`` if there is no such item.
        """
        return next((item for item in self.items if item == name), None)

    def __iter__(self):
        # Iterate directly over the nodes of the underlying graph
        return iter(self.sgraph._graph)

    @Timer(logger=info, text='[Loki::Scheduler] Performed full source parse in {:.2f}s')
    def _parse_items(self):
        """
        Prepare processing by triggering a full parse of the items in
        the execution plan and enriching subroutine calls.

        Source files are completed in reverse file-graph order, each with the
        frontend arguments derived from :attr:`build_args` (with the scheduler's
        :attr:`definitions` appended) and any item-specific config. The
        :any:`SGraph` is rebuilt afterwards to pick up all new connections.
        """
        frontend_defaults = dict(self.build_args)
        frontend_defaults['definitions'] = as_tuple(frontend_defaults['definitions']) + self.definitions

        file_items = SFilter(self.file_graph, reverse=True)
        for file_item in file_items:
            item_args = self.config.create_frontend_args(file_item.name, frontend_defaults)
            file_item.source.make_complete(**item_args)

        self._sgraph = SGraph.from_seed(self.seeds, self.item_factory, self.config)

    @Timer(logger=perf, text='[Loki::Scheduler] Enriched call tree in {:.2f}s')
    def _enrich(self):
        """
        For items that have a specific enrichment list provided as part of their
        config, try to provide this information

        Every :any:`ProcedureItem` is enriched with the scheduler's
        :attr:`definitions`, extended by the IR of all items named in the
        item's ``enrich`` list (after completing their source files).
        """
        scheduler_definitions = self.definitions
        for item in SFilter(self.sgraph, item_filter=ProcedureItem):
            # Start from the scheduler graph's definitions and append meta-info
            # from outside the callgraph
            definitions = scheduler_definitions
            for enrich_name in as_tuple(item.enrich):
                enrich_items = as_tuple(
                    self.sgraph._create_item(enrich_name, item_factory=self.item_factory, config=self.config)
                )
                for enrich_item in enrich_items:
                    enrich_item.source.make_complete(
                        **self.config.create_frontend_args(enrich_item.source.path, self.build_args)
                    )
                definitions = definitions + tuple(enrich_item.ir for enrich_item in enrich_items)
            item.ir.enrich(definitions, recurse=True)

    def rekey_item_cache(self):
        """
        Rebuild the mapping of item names to items in the :attr:`item_factory`'s cache

        This is required when a :any:`Transformation` renames items during processing,
        and is triggered automatically at the end of the :meth:`process` method if
        the transformation has :any:`Transformation.renames_items` specified.

        This update also updates :attr:`config` entries that are affected by the renaming.

        The rebuild proceeds in the following steps:

        1. Cache entries whose key differs from the item's current name are
           recorded as renamed (old key -> new name).
        2. Items whose program unit no longer exists in their source file (or
           whose parent scope was removed) are recorded for deletion.
        3. Items whose parent scope was renamed are renamed accordingly.
        4. Entries in ``config.routines`` and :attr:`seeds` that match an old
           name are moved to the new name.
        5. :any:`FileItem` entries that share their source with a renamed item
           are renamed to ``duplicate of <name>``, so that the unmodified file
           is re-discovered by a subsequent :meth:`_discover`.
        6. The cache is rebuilt with keys matching the item names, omitting
           deleted entries.
        """
        # Find invalid item cache entries
        # (maps old cache key -> new item name)
        renamed_keys = {
            key: item.name for key, item in self.item_factory.item_cache.items()
            if item.name != key
        }

        # Find deleted item cache entries
        deleted_keys = set()
        for key, item in self.item_factory.item_cache.items():
            if isinstance(item, FileItem):
                continue
            if isinstance(item, ModuleItem):
                # NOTE(review): this tests the *current* name against the old cache keys in
                # renamed_keys; testing ``key`` may have been intended — confirm
                if item.name not in renamed_keys and item.name not in item.source:
                    # The module was in a file (likely with something else) and has been deleted
                    deleted_keys.add(key)
            elif not isinstance(item, ExternalItem):
                if not item.scope_name:
                    # IR node without a scope (i.e., a Procedure without a module)
                    if item.local_name not in item.source:
                        # ...has been deleted
                        deleted_keys.add(key)
                elif item.scope_name in renamed_keys:
                    # The parent module has been renamed...
                    if item.local_name not in item.source[renamed_keys[item.scope_name]]:
                        # ...and the contained item has been deleted from the module
                        deleted_keys.add(key)
                else:
                    if item.scope_name not in item.source:
                        # The parent module has been removed
                        deleted_keys.add(key)

        # Rename item scopes where necessary
        # (renamed_keys is extended here so that the config/seed update below
        # also covers items that were renamed only via their parent scope)
        for key, item in self.item_factory.item_cache.items():
            if item.scope_name in renamed_keys and key not in deleted_keys:
                item.name = f'{renamed_keys[item.scope_name]}#{item.local_name}'
                renamed_keys[key] = item.name

        # Search for invalid item cache keys in config entries
        for old_name, new_name in renamed_keys.items():
            if matched_keys := self.config.match_item_keys(old_name, self.config.routines):
                # Move the config entry (a copy) from the old to the new key
                for key in matched_keys:
                    self.config.routines[new_name] = self.config.routines[key].copy()
                    del self.config.routines[key]
            if matched_keys := self.config.match_item_keys(old_name, self.seeds):
                self.seeds = tuple(
                    new_name if seed in matched_keys else seed
                    for seed in self.seeds
                )

        # Find FileItem cache entries for renamed cache entries and rename them.
        # This allows to clone program units _and_ use references to unmodified program units
        # in the original file within the same SGraph. The unmodified program units are
        # re-discovered when running _discover() afterwards.
        if renamed_keys:
            for key, file_item in self.item_factory.item_cache.items():
                if isinstance(file_item, FileItem):
                    # The generator's ``key`` is local to the generator expression and
                    # does not overwrite the outer loop variable used below
                    if any(file_item.source is self.item_factory.item_cache[key].source for key in renamed_keys):
                        file_item.name = f'duplicate of {file_item.name}'
                        renamed_keys[key] = file_item.name

        # Rebuild item_cache to make keys match entries
        # NOTE(review): deleted_keys holds cache keys but is compared with item.name; this
        # relies on deleted items having kept their original name — confirm
        self.item_factory.item_cache = CaseInsensitiveDict(
            (item.name, item) for item in self.item_factory.item_cache.values()
            if item.name not in deleted_keys
        )

    def process(self, transformation, proc_strategy=ProcessingStrategy.DEFAULT):
        """
        Apply a :any:`Transformation` or a :any:`Pipeline` to all :attr:`items`
        in the scheduler's graph.

        A single :any:`Transformation` is handed to
        :meth:`process_transformation`, while a :any:`Pipeline` is handed to
        :meth:`process_pipeline`, which applies each contained transformation
        in turn over the full dependency graph of the scheduler.

        Parameters
        ----------
        transformation : :any:`Transformation` or :any:`Pipeline`
            The transformation or transformation pipeline to apply
        proc_strategy : :any:`ProcessingStrategy`
            The processing strategy to use when applying the given
            :data:`transformation` to the scheduler's graph.

        Raises
        ------
        RuntimeError
            If :data:`transformation` is neither a :any:`Transformation`
            nor a :any:`Pipeline`.
        """
        if isinstance(transformation, Transformation):
            self.process_transformation(transformation=transformation, proc_strategy=proc_strategy)
            return

        if isinstance(transformation, Pipeline):
            self.process_pipeline(pipeline=transformation, proc_strategy=proc_strategy)
            return

        error('[Loki::Scheduler] Batch processing requires Transformation or Pipeline object')
        raise RuntimeError(f'Could not batch process {transformation}')

    def process_pipeline(self, pipeline, proc_strategy=ProcessingStrategy.DEFAULT):
        """
        Process a given :any:`Pipeline` by applying its associated
        transformations in turn.

        Each transformation in the pipeline's ``transformations`` is applied via
        :meth:`process_transformation` over the full dependency graph before the
        next one is started.

        Parameters
        ----------
        pipeline : :any:`Pipeline`
            The transformation pipeline to apply
        proc_strategy : :any:`ProcessingStrategy`
            The processing strategy to use when applying the given
            :data:`pipeline` to the scheduler's graph.
        """
        for transformation in pipeline.transformations:
            self.process_transformation(transformation, proc_strategy=proc_strategy)

    def process_transformation(self, transformation, proc_strategy=ProcessingStrategy.DEFAULT):
        """
        Process all :attr:`items` in the scheduler's graph

        By default, the traversal is performed in topological order, which
        ensures that an item is processed before the items it depends upon
        (e.g., via a procedure call)
        This order can be reversed in the :any:`Transformation` manifest by
        setting :any:`Transformation.reverse_traversal` to ``True``.

        The scheduler applies the transformation to the program unit scope corresponding to
        each item in the scheduler's graph, determined by the :any:`Item.transformation_ir`
        property. For example, for a :any:`ProcedureItem`, the transformation is
        applied to the corresponding :any:`Subroutine` object.

        Optionally, the traversal can be performed on a source file level only,
        if the transformation has set :any:`Transformation.traverse_file_graph`
        to ``True``. This uses the :attr:`file_graph` to process the dependency tree.
        If combined with a :any:`Transformation.item_filter`, only source files with
        at least one object corresponding to an item of that type are processed.

        Parameters
        ----------
        transformation : :any:`Transformation`
            The transformation to apply over the dependency tree
        proc_strategy : :any:`ProcessingStrategy`
            The processing strategy to use when applying the given
            :data:`transformation` to the scheduler's graph.
        """
        def _get_definition_items(_item, sgraph_items):
            # For backward-compatibility with the DependencyTransform and LinterTransformation
            if not transformation.traverse_file_graph:
                return None

            # Recursively obtain all definition items but exclude any that are not part of the original SGraph
            items = ()
            for item in _item.create_definition_items(item_factory=self.item_factory, config=self.config):
                # Recursion gives us only items that are included in the SGraph, or the parent scopes
                # of items included in the SGraph
                child_items = _get_definition_items(item, sgraph_items)
                # If the current item has relevant children, or is included in the SGraph itself, we
                # include it in the list of items
                if child_items or item in sgraph_items:
                    if transformation.process_ignored_items or not item.is_ignored:
                        items += (item,) + child_items
            return items

        if proc_strategy not in (ProcessingStrategy.SEQUENCE, ProcessingStrategy.PLAN):
            error(f'[Loki::Scheduler] Processing {proc_strategy} is not implemented!')
            raise RuntimeError(f'Could not batch process {transformation}')

        trafo_name = transformation.__class__.__name__
        log = f'[Loki::Scheduler] Applied transformation <{trafo_name}>' + ' in {:.2f}s'
        with Timer(logger=info, text=log):

            # Extract the graph iteration properties from the transformation
            item_filter = as_tuple(transformation.item_filter)
            if transformation.traverse_file_graph:
                sgraph = self.sgraph
                graph = sgraph.as_filegraph(
                    self.item_factory, self.config, item_filter=item_filter,
                    exclude_ignored=not transformation.process_ignored_items
                )
                sgraph_items = sgraph.items
                traversal = SFilter(
                    graph, reverse=transformation.reverse_traversal,
                    include_external=self.config.default.get('strict', True)
                )
            else:
                graph = self.sgraph
                sgraph_items = graph.items
                traversal = SFilter(
                    graph, item_filter=item_filter, reverse=transformation.reverse_traversal,
                    exclude_ignored=not transformation.process_ignored_items,
                    include_external=self.config.default.get('strict', True)
                )

            # Collect common transformation arguments
            kwargs = {
                'depths': graph.depths,
                'build_args': self.build_args,
                'plan_mode': proc_strategy == ProcessingStrategy.PLAN,
            }

            if transformation.renames_items or transformation.creates_items:
                kwargs['item_factory'] = self.item_factory
                kwargs['scheduler_config'] = self.config

            for _item in traversal:
                if isinstance(_item, ExternalItem):
                    if kwargs['plan_mode'] and _item.is_generated:
                        continue
                    raise RuntimeError(f'Cannot apply {trafo_name} to {_item.name}: Item is marked as external.')

                transformation.apply(
                    _item.transformation_ir, item=_item, items=_get_definition_items(_item, sgraph_items),
                    sub_sgraph=graph.get_sub_sgraph(_item, item_filter=item_filter),
                    role=_item.role, mode=_item.mode, targets=_item.targets,
                    **kwargs
                )

        if transformation.renames_items:
            self.rekey_item_cache()

        if transformation.creates_items:
            self._discover()
            if self.full_parse:
                self._parse_items()

    def callgraph(self, path, with_file_graph=False, with_legend=False):
        """
        Generate a callgraph visualization and dump to file.

        The graph is rendered as PDF via :mod:`graphviz`. If the ``graphviz``
        Python package cannot be imported, a warning is issued and nothing is
        written; if the Graphviz executables are not found, a warning is issued
        instead of raising.

        Nodes are colored by item type, ignored items are drawn translucent and
        replicated items are drawn as rounded diamonds. Dependency targets that
        are blocked (listed in ``targets_and_blocked_targets`` but not in
        ``targets``) are added as light red nodes.

        Parameters
        ----------
        path : str or pathlib.Path
            Path to write the callgraph figure to.
        with_file_graph : bool or str or pathlib.Path
            Visualize file dependencies in an additional file. Can be set to `True` or a file path to write to.
            If `True`, the file graph is written next to :data:`path`, with ``_file_graph``
            appended to the file name stem. File nodes are labelled relative to the
            longest common directory of all files in the file graph.
        with_legend : bool
            Add a key/legend to the plot. Can be enabled by setting the argument to `True`.
        """
        try:
            import graphviz as gviz  # pylint: disable=import-outside-toplevel
        except ImportError:
            warning('[Loki] Failed to load graphviz, skipping callgraph generation...')
            return

        item_colors = {
            FileItem: '#c0c0c0',       # gray
            ModuleItem: '#2080ff',     # blue
            ProcedureItem: '#60e080',  # green
            TypeDefItem: '#ffc832',    # yellow
            InterfaceItem: '#c0ff40',  # light-green
            ProcedureBindingItem: '#00dcc8', # turquoise
            ExternalItem: '#dc2000'    # red
        }

        cg_path = Path(path)
        callgraph = gviz.Digraph(format='pdf', strict=True, graph_attr=(('rankdir', 'LR'),))

        node_style = {
            'color': 'black',
            'shape': 'box',
            'style': 'filled'
        }

        # Insert all nodes in the schedulers graph
        for item in self.items:
            style = node_style.copy()
            alpha_channel = '33' if item.is_ignored else 'ff'
            style['fillcolor'] = item_colors.get(type(item), '#333333') + alpha_channel
            if item.replicate:
                style['shape'] = 'diamond'
                style['style'] += ',rounded'
            callgraph.node(item.name.upper(), **style)

        # Insert all edges in the schedulers graph
        for parent, child in self.dependencies:
            callgraph.edge(parent.name.upper(), child.name.upper())

        # Insert all nodes we were told to either block or ignore
        for item in self.items:
            blocked_children = set(item.targets_and_blocked_targets) - set(item.targets)
            for child in blocked_children:
                style = node_style.copy()
                style['fillcolor'] = '#ff141499'  # light red
                callgraph.node(child.upper(), **style)
                callgraph.edge(item.name.upper(), child.upper())

        if with_legend:
            for cls, color in item_colors.items():
                style = node_style.copy()
                style['fillcolor'] = color
                callgraph.node(cls.__name__, **style)

        try:
            callgraph.render(cg_path, view=False)
        except gviz.ExecutableNotFound as e:
            warning(f'[Loki] Failed to render callgraph due to graphviz error:\n  {e}')

        if with_file_graph:
            if with_file_graph is True:
                fg_path = cg_path.with_name(f'{cg_path.stem}_file_graph{cg_path.suffix}')
            else:
                fg_path = Path(with_file_graph)
            fg = gviz.Digraph(format='pdf', strict=True, graph_attr=(('rankdir', 'LR'),))
            file_graph = self.file_graph

            # Shorten node labels by stripping the common base directory of all files.
            # Names that are not among the graph's items are used unchanged.
            labels = self._file_graph_labels(item.name for item in file_graph.items)

            def _label(name):
                return labels.get(str(name), str(name))

            for item in file_graph:
                style = node_style.copy()
                alpha_channel = '33' if item.is_ignored else 'ff'
                style['fillcolor'] = item_colors.get(type(item), '#333333') + alpha_channel
                if item.replicate:
                    style['shape'] = 'diamond'
                    style['style'] += ',rounded'
                fg.node(_label(item.name), **style)

            for parent, child in file_graph.dependencies:
                fg.edge(_label(parent.name), _label(child.name))

            try:
                fg.render(fg_path, view=False)
            except gviz.ExecutableNotFound as e:
                warning(f'[Loki] Failed to render filegraph due to graphviz error:\n  {e}')

    @staticmethod
    def _file_graph_labels(names):
        """
        Create shortened display labels for the file paths of a file graph

        The longest common directory of all paths is stripped from each label.
        If the common path is the file itself (e.g., the graph contains only a
        single file), the file name is used, so that labels are never empty.
        If no common path can be determined (no paths at all, a mix of absolute
        and relative paths, or relative paths without a shared prefix), the
        paths are used unchanged.

        Parameters
        ----------
        names : iterable of str or path-like
            The file paths to label

        Returns
        -------
        dict
            Mapping from ``str(name)`` to the display label for each path
        """
        paths = [str(name) for name in names]
        try:
            basedir = commonpath(paths)
        except ValueError:
            # Raised for an empty sequence or when mixing absolute and relative paths
            return {p: p for p in paths}
        if not basedir:
            return {p: p for p in paths}

        labels = {}
        for p in paths:
            try:
                rel = Path(p).relative_to(basedir)
            except ValueError:
                labels[p] = p
                continue
            # An empty relative path means the common path is the file itself
            labels[p] = str(rel) if rel.parts else Path(p).name
        return labels

    @Timer(logger=perf, text='[Loki::Scheduler] Wrote CMake plan file in {:.2f}s')
    def write_cmake_plan(self, filepath, rootpath=None):
        """
        Write the CMake "plan file" for the scheduler's dependency graph.

        A :any:`CMakePlanTransformation` is processed over the graph with the
        :any:`ProcessingStrategy.PLAN` strategy, and the collected plan is then
        written out. See :any:`CMakePlanTransformation` for the specification
        of that file.

        Parameters
        ----------
        filepath : str or Path
            Destination path of the CMake plan file.
        rootpath : str or Path (optional)
            Root directory; if provided, all paths in the CMake file are made
            relative to it.
        """
        info(f'[Loki] Scheduler writing CMake plan: {filepath}')

        # Imported locally rather than at module level
        from loki.transformations.build_system.plan import CMakePlanTransformation  # pylint: disable=import-outside-toplevel

        plan_trafo = CMakePlanTransformation(rootpath=rootpath)
        self.process(plan_trafo, proc_strategy=ProcessingStrategy.PLAN)
        plan_trafo.write_plan(filepath)
loki-ecmwf-0.3.6/loki/batch/pipeline.py0000664000175000017500000001203015167130205020133 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from inspect import signature, Parameter

from loki.batch.transformation import Transformation
from loki.tools import as_tuple, flatten


class Pipeline:
    """
    A transformation pipeline that combines multiple :any:`Transformation`
    passes and allows to apply them in unison.

    The associated :any:`Transformation` objects are constructed from keyword
    arguments in the constructor, so shared keywords get same initial value.

    For each transformation class, keyword arguments are matched against the
    named constructor parameters of all classes in its method resolution order
    (MRO). If several classes in the MRO declare the same parameter, the
    most-derived declaration takes precedence, so that default values that are
    overridden in subclasses are honoured. Default values (other than ``None``)
    of parameters not provided by the caller are passed on explicitly, unless
    the parameter is already bound by the positional ``*args``.

    Pipelines can be combined with transformations or other pipelines via
    ``+``. Note that this modifies the pipeline operand in place (the left-hand
    one, if both operands are pipelines) and returns it.

    Attributes
    ----------
    transformations : list of :any:`Transformation`
        The list of transformations applied to a source in this pipeline

    Parameters
    ----------
    classes : tuple of types
        A tuple of types from which to instantiate :any:`Transformation` objects.
    *args : optional
        Positional arguments that are passed on to the constructors of
        all transformations
    **kwargs : optional
        Keyword arguments that are matched to the constructor
        signature of the transformations.
    """

    def __init__(self, *args, classes=None, **kwargs):
        self.transformations = []
        for cls in as_tuple(classes):

            # Get all relevant constructor parameters from the MRO,
            # but exclude catch-all args, like ``*args`` and ``**kwargs``
            t_parameters = self._get_constructor_parameters(cls)

            # Filter kwargs for this transformation class specifically
            t_kwargs = {k: v for k, v in kwargs.items() if k in t_parameters}

            # We need to apply our own default, if we are to honour inheritance.
            # Parameters without a default are left for the caller to provide, and
            # parameters bound positionally by ``*args`` are skipped, as passing
            # them again as a keyword would clash with the positional value.
            positional = self._get_positionally_bound(cls, args)
            t_kwargs.update({
                k: param.default for k, param in t_parameters.items()
                if k not in t_kwargs and k not in positional
                and param.default is not None and param.default is not Parameter.empty
            })

            # Then instantiate with the default *args and the derived **t_kwargs
            self.transformations.append(cls(*args, **t_kwargs))

    @staticmethod
    def _get_constructor_parameters(cls):
        """
        Collect the named constructor parameters of :data:`cls` and all its base classes

        Classes are visited from the most basic to the most derived one, so that
        a parameter declared by a subclass overrides the declaration (and the
        default value) of the same parameter in its base classes.

        Returns
        -------
        dict
            Mapping of parameter name to :any:`inspect.Parameter`
        """
        parameters = {}
        for c in reversed(cls.__mro__):
            for name, param in signature(c).parameters.items():
                if param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                    parameters[name] = param
        return parameters

    @staticmethod
    def _get_positionally_bound(cls, args):
        """
        Determine the constructor parameters of :data:`cls` that are bound by
        the positional arguments :data:`args`

        Returns
        -------
        set of str
            Names of the parameters bound by :data:`args`; empty if :data:`args`
            is empty or cannot be bound (the constructor call itself will then
            report the error).
        """
        if not args:
            return set()
        try:
            return set(signature(cls).bind_partial(*args).arguments)
        except (TypeError, ValueError):
            return set()

    def __str__(self):
        """ Pretty-print pipeline details """
        trafo_str = '\n  '.join(flatten(str(t).splitlines() for t in self.transformations))
        return f'<{self.__class__.__name__}\n  {trafo_str}\n>'

    def __add__(self, other):
        """ Support native addition via ``+`` operands (modifies and returns ``self``) """
        if isinstance(other, Transformation):
            self.append(other)
            return self
        if isinstance(other, Pipeline):
            self.extend(other)
            return self
        raise TypeError(f'[Loki::Pipeline] Can not append {other} to pipeline!')

    def __radd__(self, other):
        """ Support native addition via ``+`` operands (modifies and returns a pipeline operand) """
        if isinstance(other, Transformation):
            self.prepend(other)
            return self
        if isinstance(other, Pipeline):
            other.extend(self)
            return other
        raise TypeError(f'[Loki::Pipeline] Can not append {other} to pipeline!')

    def prepend(self, transformation):
        """
        Prepend a fully instantiated :any:`Transformation` object to this pipeline.

        Parameters
        ----------
        transformation : :any:`Transformation`
            Transformation object to prepend
        """
        assert isinstance(transformation, Transformation)

        self.transformations.insert(0, transformation)

    def append(self, transformation):
        """
        Append a fully instantiated :any:`Transformation` object to this pipeline.

        Parameters
        ----------
        transformation : :any:`Transformation`
            Transformation object to append
        """
        assert isinstance(transformation, Transformation)

        self.transformations.append(transformation)

    def extend(self, pipeline):
        """
        Append all :any:`Transformation` objects of a given :any:`Pipeline`

        Parameters
        ----------
        pipeline : :any:`Pipeline`
            Pipeline whose transformations will be appended
        """
        assert isinstance(pipeline, Pipeline)

        self.transformations.extend(pipeline.transformations)

    def apply(self, source, **kwargs):
        """
        Apply each associated :any:`Transformation` to :data:`source`

        It dispatches to the respective :meth:`apply` of each
        :any:`Transformation` in the order specified in the constructor.

        Parameters
        ----------
        source : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
            The source item to transform.
        **kwargs : optional
            Keyword arguments that are passed on to the methods defining the
            actual transformation.
        """
        for trafo in self.transformations:
            trafo.apply(source, **kwargs)
loki-ecmwf-0.3.6/loki/types/0000775000175000017500000000000015167130205016043 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/types/__init__.py0000664000175000017500000000134115167130205020153 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

""" 
The Loki type system to store and shared type information on symbols and IR nodes.
"""

from loki.types.datatypes import *  # noqa
from loki.types.derived_type import *  # noqa
from loki.types.module_type import *  # noqa
from loki.types.procedure_type import *  # noqa
from loki.types.scope import *  # noqa
from loki.types.symbol_table import *  # noqa
loki-ecmwf-0.3.6/loki/types/tests/0000775000175000017500000000000015167130205017205 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/types/tests/__init__.py0000664000175000017500000000057015167130205021320 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/types/tests/test_derived_types.py0000664000175000017500000014515515167130205023477 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# pylint: disable=too-many-lines
from sys import getrecursionlimit
from inspect import stack

import re
import pytest
import numpy as np

from loki import Module, Subroutine, fgen
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes, FindVariables
from loki.jit_build import Builder, jit_compile, jit_compile_lib, Obj
from loki.types import BasicType, DerivedType, ProcedureType


@pytest.fixture(name='builder')
def fixture_builder(tmp_path):
    """
    Provide a JIT :any:`Builder` that uses ``tmp_path`` for both sources and
    build artefacts, and clear the :any:`Obj` cache once the test is done.
    """
    jit_builder = Builder(source_dirs=tmp_path, build_dir=tmp_path)
    yield jit_builder
    Obj.clear_cache()


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_symbols_nested(tmp_path, frontend):
    """
    Test basic scoping behaviour of nested derived types

    Component accesses (including those through a nested derived type) must
    be scoped in the routine and chained to the correct parent variable.
    """

    fcode_mod = """
module my_types_mod
  implicit none

  type inner
    integer :: here
  end type inner

  type outer
    type(inner) :: was
    real(kind=4) :: red_herring
  end type outer
contains

  subroutine test_der_type(rick, dave)
    type(outer), intent(inout) :: rick, dave

    rick%red_herring = 42.0
    rick%was%here = 67

    dave%red_herring = 66.6
    dave%was%here = 6711
  end subroutine test_der_type
end module my_types_mod
"""
    module = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    routine = module['test_der_type']
    assert len(routine.variables) == 2

    rick, dave = (routine.variable_map[name] for name in ('rick', 'dave'))
    assert rick and dave and rick.type.dtype == dave.type.dtype

    # The test expects four variables per argument, in this order: the argument
    # itself, its scalar component, the nested component and the innermost component
    found = list(FindVariables().visit(routine.body))
    expectations = (
        (0, rick, 'rick', 'rick%red_herring', 'rick%was', 'rick%was%here'),
        (4, dave, 'dave', 'dave%red_herring', 'dave%was', 'dave%was%here'),
    )
    for offset, var, name, herring_name, was_name, here_name in expectations:
        base, herring, was, here = found[offset:offset + 4]
        assert base == name == var
        assert herring == herring_name and herring.scope == routine
        assert was == was_name and was.parent == var and was.scope == routine
        assert here == here_name and here.parent == was and here.scope == routine


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_arithmetic(tmp_path, frontend):
    """
    Test simple vector/matrix arithmetic with a derived type via a JIT compile

    First checks the type and shape information attached to derived type
    components used in the body of ``simple_loops`` (including components of
    a nested derived type), then JIT-compiles the module and checks the
    numerical results for explicit-shape component arrays.
    """

    fcode = """
module derived_types_mod
  integer, parameter :: jprb = selected_real_kind(13,300)

  type explicit
    real(kind=jprb) :: scalar, vector(3), matrix(3, 3)
    real(kind=jprb) :: red_herring
  end type explicit

  type nested
    real(kind=jprb) :: a_scalar, a_vector(3)
    type(explicit) :: another_item
  end type nested
contains

  subroutine simple_loops(item, item2, item3)
    type(explicit), intent(inout) :: item, item2
    type(nested), intent(inout) :: item3
    real(kind=jprb) :: vals(3) = (/ 1., 2., 3. /)
    integer :: i, j, n

    n = 3
    do i=1, n  ! Explicit vector loop
       item%vector(i) = item%vector(i) + item%scalar
    end do

    do j=1, n  ! Explicit two-level vector loops
       do i=1, n
          item%matrix(i, j) = item%matrix(i, j) + item%scalar
       end do
    end do

    item2%vector(:) = 666.
    do i=1, 3  ! Use of vector notations
       item2%matrix(:, i) = vals(i)
    end do

    ! Test vector notation on nested derived types
    item3%a_vector(:) = 666.
    item3%another_item%vector(:) = 999.
    do i=1, 3
       item3%another_item%matrix(:, i) = vals(i)
    end do
  end subroutine simple_loops
end module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['simple_loops']

    # Ensure type info is attached correctly
    # Collect every occurrence (unique=False) of a variable that has a parent, i.e.
    # derived type components, including intermediate parents of nested accesses
    # (``item3%another_item``); the index-based checks below rely on the order in
    # which the variables are found
    item_vars = [v for v in FindVariables(unique=False).visit(routine.body) if v.parent]
    assert all(v.type.dtype is BasicType.REAL or isinstance(v.type.dtype, DerivedType) for v in item_vars)
    assert item_vars[0].name == 'item%vector' and item_vars[0].shape == (3,)
    assert item_vars[1].name == 'item%vector' and item_vars[1].shape == (3,)
    assert item_vars[2].name == 'item%scalar' and item_vars[2].type.shape is None
    assert item_vars[3].name == 'item%matrix' and item_vars[3].shape == (3, 3)
    assert item_vars[4].name == 'item%matrix' and item_vars[4].shape == (3, 3)
    assert item_vars[5].name == 'item%scalar' and item_vars[5].type.shape is None
    assert item_vars[6].name == 'item2%vector' and item_vars[6].shape == (3,)
    assert item_vars[7].name == 'item2%matrix' and item_vars[7].shape == (3, 3)
    assert item_vars[8].name == 'item3%a_vector' and item_vars[8].shape == (3,)
    # The intermediate parent of a nested access resolves to the ``explicit`` typedef
    assert item_vars[9].name == 'item3%another_item' and item_vars[9].type.dtype.typedef == module['explicit']
    assert item_vars[10].name == 'item3%another_item%vector' and item_vars[10].shape == (3,)
    assert item_vars[11].name == 'item3%another_item' and item_vars[11].type.dtype.typedef == module['explicit']
    assert item_vars[12].name == 'item3%another_item%matrix' and item_vars[12].shape == (3, 3)

    # JIT-compile the module and create input objects
    filepath = tmp_path/(f'derived_types_simple_loops_{frontend}.f90')
    mod = jit_compile(module, filepath=filepath, objname='derived_types_mod')

    # Only ``item`` needs initial values; the components of ``item2`` and ``item3``
    # that are checked below are only assigned to by ``simple_loops``
    item, item2, item3 = mod.explicit(), mod.explicit(), mod.nested()
    item.scalar = 2.
    item.vector[:] = 5.
    item.matrix[:, :] = 4.

    # Execute compiled code and check constructed outputs
    # ``item``: vector 5 + 2 and matrix 4 + 2; for the other matrices column ``i``
    # is set to ``vals(i)``, so every row reads [1., 2., 3.]
    mod.simple_loops(item, item2, item3)
    assert (item.vector == 7.).all() and (item.matrix == 6.).all()
    assert (item2.vector == 666.).all()
    assert (item2.matrix == np.array([[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]])).all()
    assert (item3.a_vector == 666.).all()
    assert (item3.another_item.vector == 999.).all()
    assert (item3.another_item.matrix == np.array([[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]])).all()


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_deferred_arrays(tmp_path, frontend):
    """
    Test simple vector/matrix arithmetic with a derived type
    with dynamically allocated arrays via JIT-compilation.

    Covers direct use of allocatable components as well as an allocatable
    array of derived type temporaries whose components are allocated,
    accumulated into the argument and freed inside the routine.
    """

    fcode = """
module derived_types_mod
  integer, parameter :: jprb = selected_real_kind(13,300)

  type deferred
    real(kind=jprb), allocatable :: scalar, vector(:), matrix(:, :)
    real(kind=jprb), allocatable :: red_herring
  end type deferred
contains

  subroutine alloc_deferred(item)
    type(deferred), intent(inout) :: item
    allocate(item%vector(3))
    allocate(item%matrix(3, 3))
  end subroutine alloc_deferred

  subroutine free_deferred(item)
    type(deferred), intent(inout) :: item
    deallocate(item%vector)
    deallocate(item%matrix)
  end subroutine free_deferred

  subroutine test_deferred_arrays(item)
    type(deferred), intent(inout) :: item
    real(kind=jprb) :: vals(3) = (/ 1., 2., 3. /)
    integer :: i

    item%vector(:) = 666.

    do i=1, 3
       item%matrix(:, i) = vals(i)
    end do
  end subroutine test_deferred_arrays

  subroutine test_deferred_arrays_with_temporary(item)
    type(deferred), intent(inout) :: item
    type(deferred), allocatable :: item2(:)
    real(kind=jprb) :: vals(3) = (/ 1., 2., 3. /)
    integer :: i, j

    allocate(item2(4))

    do j=1, 4
      call alloc_deferred(item2(j))
      item2(j)%vector(:) = 666.
      do i=1, 3
        item2(j)%matrix(:, i) = vals(i)
      end do
    end do

    item%vector(:) = 0.
    item%matrix(:,:) = 0.
    do j=1, 4
      item%vector(:) = item%vector(:) + item2(j)%vector(:)
      do i=1, 3
          item%matrix(:,i) = item%matrix(:,i) + item2(j)%matrix(:,i)
      end do
      call free_deferred(item2(j))
    end do

    deallocate(item2)
  end subroutine test_deferred_arrays_with_temporary
end module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    filepath = tmp_path/(f'derived_types_array_indexing_deferred_{frontend}.f90')
    mod = jit_compile(module, filepath=filepath, objname='derived_types_mod')

    # Direct use: allocate the components from Fortran, fill them, check and free
    item = mod.deferred()
    mod.alloc_deferred(item)
    mod.test_deferred_arrays(item)
    assert (item.vector == 666.).all()
    assert (item.matrix == np.array([[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]])).all()
    mod.free_deferred(item)

    # With temporaries: the routine sums the components of 4 temporaries, each
    # holding the same values as above, hence the factor 4
    item2 = mod.deferred()
    mod.alloc_deferred(item2)
    mod.test_deferred_arrays_with_temporary(item2)
    assert (item2.vector == 4 * 666.).all()
    assert (item2.matrix == 4 * np.array([[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]])).all()
    mod.free_deferred(item2)


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_caller(tmp_path, frontend):
    """
    Test a simple call to another routine specifying a derived type as argument

    The caller sets one component itself and delegates the arithmetic on the
    remaining components to the callee; both updates must be visible after
    the JIT-compiled call.
    """

    fcode = """
module derived_types_mod
  integer, parameter :: jprb = selected_real_kind(13,300)

  type explicit
    real(kind=jprb) :: scalar, vector(3), matrix(3, 3)
    real(kind=jprb) :: red_herring
  end type explicit
contains

  subroutine simple_loops(item)
    type(explicit), intent(inout) :: item
    integer :: i, j, n

    n = 3
    do i=1, n
       item%vector(i) = item%vector(i) + item%scalar
    end do

    do j=1, n
       do i=1, n
          item%matrix(i, j) = item%matrix(i, j) + item%scalar
       end do
    end do
  end subroutine simple_loops

  subroutine derived_type_caller(item)
    ! simple call to another routine specifying a derived type as argument
    type(explicit), intent(inout) :: item

    item%red_herring = 42.
    call simple_loops(item)
  end subroutine derived_type_caller

end module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    compiled = jit_compile(
        module, filepath=tmp_path/f'derived_types_derived_type_caller_{frontend}.f90',
        objname='derived_types_mod'
    )

    # Populate the argument, then let the caller update it (partly via the callee)
    arg = compiled.explicit()
    arg.scalar = 2.
    arg.vector[:] = 5.
    arg.matrix[:, :] = 4.
    arg.red_herring = -1.
    compiled.derived_type_caller(arg)

    # Arrays are incremented by ``scalar`` in the callee, ``red_herring`` is set by the caller
    assert (arg.vector == 7.).all()
    assert (arg.matrix == 6.).all()
    assert arg.red_herring == 42.


@pytest.mark.parametrize('frontend', available_frontends())
def test_case_sensitivity(tmp_path, frontend):
    """
    Some abuse of the case agnostic behaviour of Fortran

    Components are declared and referenced with mixed case in Fortran, but
    accessed in lower case from the JIT-compiled Python wrapper.
    """

    fcode = """
module derived_types_mod
  integer, parameter :: jprb = selected_real_kind(13,300)

  type case_sensitive
    real(kind=jprb) :: u, v, T
    real(kind=jprb) :: q, A
  end type case_sensitive
contains

  subroutine check_case(item)
    type(case_sensitive), intent(inout) :: item

    item%u = 1.0
    item%v = 2.0
    item%t = 3.0
    item%q = -1.0
    item%A = -5.0
  end subroutine check_case
end module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    lib = jit_compile(
        module, filepath=tmp_path/f'derived_types_case_sensitivity_{frontend}.f90',
        objname='derived_types_mod'
    )

    expected = {'u': 1.0, 'v': 2.0, 't': 3.0, 'q': -1.0, 'a': -5.0}

    # Zero all components, then let the Fortran routine overwrite them
    obj = lib.case_sensitive()
    for component in expected:
        setattr(obj, component, 0.)
    lib.check_case(obj)

    for component, value in expected.items():
        assert getattr(obj, component) == value


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_bind_c(frontend, tmp_path):
    """
    Test that the ``BIND(C)`` attribute of a derived type is parsed and regenerated
    """
    # Example code from F2008, Note 15.13
    fcode = """
module derived_type_bind_c
    ! typedef struct {
    ! int m, n;
    ! float r;
    ! } myctype;

    USE, INTRINSIC :: ISO_C_BINDING
    TYPE, BIND(C) :: MYFTYPE
      INTEGER(C_INT) :: I, J
      REAL(C_FLOAT) :: S
    END TYPE MYFTYPE
end module derived_type_bind_c
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    typedef = module.typedef_map['myftype']

    # The attribute is recorded on the TypeDef and reproduced by fgen
    assert typedef.bind_c is True
    generated = fgen(typedef)
    assert ', BIND(C)' in generated


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_inheritance(frontend, tmp_path):
    """
    Test the ``abstract`` and ``extends`` attributes of derived type definitions
    """
    fcode = """
module derived_type_private_mod
    implicit none

    type, abstract :: base_type
        integer :: val
    end type base_type

    type, extends(base_type) :: some_type
        integer :: other_val
    end type some_type

contains

    function base_proc(self) result(result)
        class(base_type) :: self
        integer :: result
        result = self%val
    end function base_proc

    function some_proc(self) result(result)
        class(some_type) :: self
        integer :: result
        result = self%val + self%other_val
    end function some_proc
end module derived_type_private_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    base_td, derived_td = module.typedef_map['base_type'], module.typedef_map['some_type']

    # Only the base type is abstract
    assert base_td.abstract is True
    assert derived_td.abstract is False

    # Only the derived type names a parent type
    assert base_td.extends is None
    assert derived_td.extends.lower() == 'base_type'

    # Neither type carries BIND(C)
    assert base_td.bind_c is False
    assert derived_td.bind_c is False

    # The attributes are reproduced by fgen
    assert 'type, abstract' in fgen(base_td).lower()
    assert 'extends(base_type)' in fgen(derived_td).lower()


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_private(frontend, tmp_path):
    """
    Test an explicitly ``private`` derived type in a module with ``public`` default
    """
    fcode = """
module derived_type_private_mod
    implicit none
    public
    TYPE, private :: PRIV_TYPE
      INTEGER :: I, J
    END TYPE PRIV_TYPE
end module derived_type_private_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    typedef = module.typedef_map['priv_type']

    # The type-level attribute overrides the module default and is regenerated
    assert typedef.private is True and typedef.public is False
    assert ', PRIVATE' in fgen(typedef)


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_public(frontend, tmp_path):
    """
    Test an explicitly ``public`` derived type in a module with ``private`` default
    """
    fcode = """
module derived_type_public_mod
    implicit none
    private
    TYPE, public :: PUB_TYPE
      INTEGER :: I, J
    END TYPE PUB_TYPE
end module derived_type_public_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    typedef = module.typedef_map['pub_type']

    # The type-level attribute overrides the module default and is regenerated
    assert typedef.public is True and typedef.private is False
    assert ', PUBLIC' in fgen(typedef)


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_private_comp(frontend, tmp_path):
    """
    Test ``private`` statements inside derived type definitions

    A ``private`` statement may appear in the component part of a derived type
    (before ``contains``) or in its type-bound procedure part (after ``contains``).
    Both must be represented as :any:`Intrinsic` nodes in the right order and be
    reproduced by :any:`fgen`.
    """
    fcode = """
module derived_type_private_comp_mod
    implicit none

    type, abstract :: base_type
        integer :: val
    end type base_type

    type, extends(base_type) :: some_private_comp_type
        private
        integer :: other_val
    contains
        procedure :: proc => other_proc
    end type some_private_comp_type

    type, extends(base_type) :: type_bound_proc_type
        integer :: other_val
    contains
        private
        procedure :: proc => other_proc
    end type type_bound_proc_type

contains

    function other_proc(self) result(result)
        class(type_bound_proc_type) :: self
        integer :: result
        result = self%val
    end function other_proc

end module derived_type_private_comp_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    some_private_comp_type = module.typedef_map['some_private_comp_type']
    type_bound_proc_type = module.typedef_map['type_bound_proc_type']

    # ``private`` in the type-bound procedure part follows ``contains``
    intrinsic_nodes = FindNodes(ir.Intrinsic).visit(type_bound_proc_type.body)
    assert len(intrinsic_nodes) == 2
    assert intrinsic_nodes[0].text.lower() == 'contains'
    assert intrinsic_nodes[1].text.lower() == 'private'

    assert re.search(
        r'^\s+contains$\s+private', fgen(type_bound_proc_type), re.I | re.MULTILINE
    ) is not None

    # OMNI gets the below wrong as it doesn't retain the private statement for components
    if frontend != OMNI:
        # ``private`` in the component part precedes ``contains``
        intrinsic_nodes = FindNodes(ir.Intrinsic).visit(some_private_comp_type.body)
        assert len(intrinsic_nodes) == 2
        assert intrinsic_nodes[0].text.lower() == 'private'
        assert intrinsic_nodes[1].text.lower() == 'contains'

        # ``private`` must stand on its own line and be followed by ``contains``
        assert re.search(
            r'^\s+private$(\s.*?){2}\s+contains', fgen(some_private_comp_type), re.I | re.MULTILINE
        ) is not None


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_procedure_designator(frontend, tmp_path):
    """
    Test type-bound procedure declarations with and without a procedure
    designator (``name => target``), using mixed-case spellings, and the type
    information derived for the module and for a routine that uses it.
    """
    mcode = """
module derived_type_procedure_designator_mod
  implicit none
  type :: some_type
    integer :: val
  contains
    procedure :: SOME_PROC => some_TYPE_some_proc
    PROCEDURE :: some_FUNC => SOME_TYPE_SOME_FUNC
    PROCEDURE :: OTHER_PROC
  end type some_type

  TYPE other_type
    real :: val
  END TYPE other_type
contains
  subroutine some_type_some_proc(self, val)
    class(some_type) :: self
    integer, intent(in) :: val
    self%val = val
  end subroutine some_type_some_proc

  function some_type_some_func(self)
    integer :: some_type_some_func
    CLASS(SOME_TYPE) :: self
    some_type_some_func = self%val
  end function some_type_some_func

  subroutine other_proc(self)
    class(some_type) :: self
    self%val = self%val + 1
  end subroutine other_proc
end module derived_type_procedure_designator_mod
    """.strip()

    fcode = """
subroutine derived_type_procedure_designator(val)
  use derived_type_procedure_designator_mod
  implicit none
  integer, intent(out) :: val
  type(some_type) :: tp

  call tp%some_proc(3)
  val = tp%some_func()
end subroutine derived_type_procedure_designator
    """.strip()

    module = Module.from_source(mcode, frontend=frontend, xmods=[tmp_path])
    assert 'some_type' in module.typedef_map
    assert 'other_type' in module.typedef_map
    assert 'some_type' in module.symbol_attrs
    assert 'other_type' in module.symbol_attrs

    # First, with external definitions (generates xmod for OMNI)
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=[module], xmods=[tmp_path])

    for name in ('some_type', 'other_type'):
        assert name in routine.symbol_attrs
        assert routine.symbol_attrs[name].imported is True
        assert isinstance(routine.symbol_attrs[name].dtype, DerivedType)
        assert isinstance(routine.symbol_attrs[name].dtype.typedef, ir.TypeDef)

    # Make sure type-bound procedure declarations exist
    some_type = module.typedef_map['some_type']
    proc_decls = FindNodes(ir.ProcedureDeclaration).visit(some_type.body)
    assert len(proc_decls) == 3
    assert all(decl.interface is None for decl in proc_decls)

    proc_symbols = {s.name.lower(): s for d in proc_decls for s in d.symbols}
    assert set(proc_symbols.keys()) == {'some_proc', 'some_func', 'other_proc'}
    assert all(s.scope is some_type for s in proc_symbols.values())
    assert all(isinstance(s.type.dtype, ProcedureType) for s in proc_symbols.values())

    # Bindings with a designator point to module procedures; each binding is
    # checked for both its target name and the scope of that target
    assert proc_symbols['some_proc'].type.bind_names == ('some_type_some_proc',)
    assert proc_symbols['some_proc'].type.bind_names[0].scope is module
    assert proc_symbols['some_func'].type.bind_names == ('some_type_some_func',)
    assert proc_symbols['some_func'].type.bind_names[0].scope is module
    # Without a designator there is no separate bind name
    assert proc_symbols['other_proc'].type.bind_names is None
    assert all(proc.type.initial is None for proc in proc_symbols.values())

    # Verify type representation in bound routines
    some_type_some_proc = module['some_type_some_proc']
    self = some_type_some_proc.symbol_map['self']
    assert isinstance(self.type.dtype, DerivedType)
    assert self.type.dtype.typedef is some_type
    assert self.type.polymorphic is True
    decls = FindNodes(ir.VariableDeclaration).visit(some_type_some_proc.spec)
    assert 'CLASS(SOME_TYPE)' in fgen(decls[0]).upper()

    # Verify type representation in using routine
    assert isinstance(routine.symbol_attrs['tp'].dtype, DerivedType)
    assert isinstance(routine.symbol_attrs['tp'].dtype.typedef, ir.TypeDef)
    assert routine.symbol_attrs['tp'].polymorphic is None
    assert routine.symbol_attrs['tp'].dtype.typedef is some_type
    decls = FindNodes(ir.VariableDeclaration).visit(routine.spec)
    # decls[0] declares ``val``, decls[1] declares ``tp``
    assert 'TYPE(SOME_TYPE)' in fgen(decls[1]).upper()

    # TODO: verify correct type association of calls to type-bound procedures

    # Next, without external definitions
    routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    assert 'some_type' not in routine.symbol_attrs
    assert 'other_type' not in routine.symbol_attrs
    assert isinstance(routine.symbol_attrs['tp'].dtype, DerivedType)
    assert routine.symbol_attrs['tp'].dtype.typedef == BasicType.DEFERRED

    # TODO: verify correct type association of calls to type-bound procedures


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_bind_attrs(frontend, tmp_path):
    """
    Verify that binding attributes on type-bound procedures (``PASS``,
    ``PASS(arg)``, ``NOPASS``, ``NON_OVERRIDABLE``, ``PUBLIC``/``PRIVATE``)
    are captured on the procedure symbols' types and written back by ``fgen``.
    """
    fcode = """
module derived_types_bind_attrs_mod
    implicit none

    type some_type
        integer :: val
    contains
        PROCEDURE, PASS, NON_OVERRIDABLE :: pass_proc
        PROCEDURE, NOPASS, PUBLIC :: no_pass_proc
        PROCEDURE, PASS(this), private :: pass_arg_proc
    end type some_type

contains

    subroutine pass_proc(self)
        class(some_type) :: self
    end subroutine pass_proc

    subroutine no_pass_proc(val)
        integer, intent(inout) :: val
    end subroutine no_pass_proc

    subroutine pass_arg_proc(val, this)
        integer, intent(inout) :: val
        class(some_type) :: this
    end subroutine pass_arg_proc

end module derived_types_bind_attrs_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    typedef = module.typedef_map['some_type']

    decls = FindNodes(ir.ProcedureDeclaration).visit(typedef.body)
    assert len(decls) == 3
    assert all(d.interface is None for d in decls)

    symbol_map = {s.name.lower(): s for d in decls for s in d.symbols}
    assert set(symbol_map) == {'pass_proc', 'no_pass_proc', 'pass_arg_proc'}

    # PASS + NON_OVERRIDABLE, no access specifier
    attrs = symbol_map['pass_proc'].type
    assert attrs.pass_attr is True
    assert attrs.non_overridable is True
    assert attrs.private is None
    assert attrs.public is None

    # NOPASS + PUBLIC
    attrs = symbol_map['no_pass_proc'].type
    assert attrs.pass_attr is False
    assert attrs.non_overridable is None
    assert attrs.private is None
    assert attrs.public is True

    # PASS with a named argument + PRIVATE
    attrs = symbol_map['pass_arg_proc'].type
    assert attrs.pass_attr == 'this'
    assert attrs.private is True
    assert attrs.public is None

    # Backend output, keyed by the name of each declaration's first symbol
    decl_by_name = {d.symbols[0].name: d for d in decls}

    generated = fgen(decl_by_name['pass_proc'])
    assert ', PASS' in generated
    assert ', NON_OVERRIDABLE' in generated

    generated = fgen(decl_by_name['no_pass_proc'])
    assert ', NOPASS' in generated
    assert ', PUBLIC' in generated

    generated = fgen(decl_by_name['pass_arg_proc'])
    assert ', PASS(this)' in generated
    assert ', PRIVATE' in generated


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_bind_deferred(frontend, tmp_path):
    """
    Check a ``DEFERRED`` binding with an explicit interface and a named
    ``PASS`` argument in an abstract type: both the IR attributes and the
    regenerated Fortran are verified.
    """
    # Example from https://www.ibm.com/docs/en/xffbg/121.141?topic=types-abstract-deferred-bindings-fortran-2003
    fcode = """
module derived_type_bind_deferred_mod
implicit none
TYPE, ABSTRACT :: FILE_HANDLE
   CONTAINS
   PROCEDURE(OPEN_FILE), DEFERRED, PASS(HANDLE) :: OPEN
END TYPE

INTERFACE
    SUBROUTINE OPEN_FILE(HANDLE)
        IMPORT FILE_HANDLE
        CLASS(FILE_HANDLE), INTENT(IN):: HANDLE
    END SUBROUTINE OPEN_FILE
END INTERFACE
end module derived_type_bind_deferred_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    # Body is expected to hold the CONTAINS node followed by the declaration
    typedef = module.typedef_map['file_handle']
    assert len(typedef.body) == 2

    decl = typedef.body[1]
    assert decl.interface == 'open_file'

    binding_type = decl.symbols[0].type
    assert binding_type.deferred is True
    assert binding_type.pass_attr.lower() == 'handle'

    generated = fgen(decl)
    assert ', DEFERRED' in generated
    assert ', PASS(HANDLE)' in generated.upper()


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_final_generic(frontend, tmp_path):
    """
    Check ``GENERIC`` and ``FINAL`` bindings in a derived type: declaration
    flags, bound names, procedure types and backend output.
    """
    fcode = """
module derived_type_final_generic_mod
    implicit none

    type hdf5_file
        logical :: is_open = .false.
        integer :: file_id
    contains
        procedure :: open_file => hdf5_file_open
        procedure, private :: hdf5_file_load_int
        procedure, private :: hdf5_file_load_real
        generic, public :: load => hdf5_file_load_int, hdf5_file_load_real
        final :: hdf5_file_close
    end type hdf5_file

contains

    subroutine hdf5_file_open (self, filepath)
        class(hdf5_file) :: self
        character(len=*), intent(in) :: filepath
        self%file_id = LEN(filepath)  ! dummy operation
        self%is_open = .true.
    end subroutine hdf5_file_open

    subroutine hdf5_file_load_int (self, val)
        class(hdf5_file) :: self
        integer, intent(out) :: val
        val = 0
        if (self%is_open) then
            val = self%file_id  ! dummy operation
        end if
    end subroutine hdf5_file_load_int

    subroutine hdf5_file_load_real (self, val)
        class(hdf5_file) :: self
        real, intent(out) :: val
        val = 0.
        if (self%is_open) then
            val = real(self%file_id)  ! dummy operation
        end if
    end subroutine hdf5_file_load_real

    subroutine hdf5_file_close (self)
        type(hdf5_file) :: self
        if (self%is_open) then
            self%file_id = 0
            self%is_open = .false.
        end if
    end subroutine hdf5_file_close
end module derived_type_final_generic_mod
    """.strip()

    mod = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    typedef = mod.typedef_map['hdf5_file']
    decls = FindNodes(ir.ProcedureDeclaration).visit(typedef.body)
    assert len(decls) == 5

    # Declaration order: three plain bindings, then GENERIC, then FINAL
    *plain_decls, generic_decl, final_decl = decls
    assert all(d.final is False for d in (*plain_decls, generic_decl))
    assert all(d.generic is False for d in plain_decls)

    procs_by_name = {p.name.lower(): p for d in decls for p in d.symbols}

    # GENERIC binding: one generic name resolving to two specific procedures
    assert generic_decl.generic is True
    generic_code = fgen(generic_decl).lower()
    assert 'generic, public ::' in generic_code
    assert 'load => ' in generic_code
    assert generic_decl.symbols == ('load',)
    load_type = generic_decl.symbols[0].type
    assert load_type.bind_names == ('hdf5_file_load_int', 'hdf5_file_load_real')
    assert load_type.dtype.name == 'load'
    assert load_type.dtype.is_generic is True
    for specific in load_type.bind_names:
        assert specific.type.dtype.name == specific.name
        assert specific == procs_by_name[specific.name]

    # FINAL binding
    assert final_decl.final is True
    assert final_decl.generic is False
    assert 'final ::' in fgen(final_decl).lower()
    assert final_decl.symbols == ('hdf5_file_close',)
    assert final_decl.symbols[0].type.dtype.name == 'hdf5_file_close'


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_clone(frontend, tmp_path):
    """
    Cloning a typedef under a new name yields an independent copy: each
    typedef's variables are scoped to that typedef, and the generated code
    differs only in the type name.
    """
    fcode = """
module derived_types_clone_mod
  integer, parameter :: jprb = selected_real_kind(13,300)

  type explicit
    real(kind=jprb) :: scalar, vector(3), matrix(3, 3)
    real(kind=jprb) :: red_herring
  end type explicit
end module
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    original = module.typedef_map['explicit']
    cloned = original.clone(name='other')

    assert (original.name, cloned.name) == ('explicit', 'other')
    for typedef in (cloned, original):
        assert all(v.scope is typedef for v in typedef.variables)

    cloned_code = fgen(cloned)
    assert fgen(original) == cloned_code.replace('other', 'explicit')


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_linked_list(frontend, tmp_path):
    """
    A derived type whose component points to the type itself (a linked list)
    must be set up without infinite recursion, and nested member variables
    must be creatable on demand to an arbitrary depth.
    """
    fcode = """
module derived_type_linked_list
    implicit none

    type list_t
        integer :: payload
        type(list_t), pointer :: next => null()
    end type list_t

    type(list_t), pointer :: beg => null()
    type(list_t), pointer :: cur => null()

contains

    subroutine find(val, this)
        integer, intent(in) :: val
        type(list_t), pointer, intent(inout) :: this
        type(list_t), pointer :: x
        this => null()
        x => beg
        do while (associated(x))
            if (x%payload == val) then
                this => x
                return
            endif
            x => x%next
        end do
    end subroutine find
end module derived_type_linked_list
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    list_t = module.typedef_map['list_t']

    # Module-level pointers: linked to list_t, with correctly scoped members
    for name in ('beg', 'cur'):
        assert name in module.variables
        var = module.variable_map[name]
        assert isinstance(var.type.dtype, DerivedType)
        assert var.type.dtype.typedef is list_t

        components = list_t.variables
        assert all(v.scope is list_t for v in components)
        assert 'payload' in components
        assert 'next' in components

        members = var.variables
        assert all(v.scope is module for v in members)
        assert f'{name}%payload' in members
        assert f'{name}%next' in members

    # Routine-level variables (dummy argument and local)
    routine = module['find']
    for name in ('this', 'x'):
        var = routine.variable_map[name]
        assert var.type.dtype.typedef is list_t

        member_map = var.variable_map
        assert 'payload' in member_map
        assert 'next' in member_map
        assert all(v.scope is var.scope for v in var.variables)

    # Follow the ``next`` chain as deep as is safe: at most 1000 levels, and
    # always with a 50-frame margin below the interpreter's recursion limit
    depth = min(1000, getrecursionlimit() - len(stack()) - 50)
    node = routine.variable_map['x']
    expected_name = 'x'
    for _ in range(depth):
        node = node.variable_map['next']
        assert node
        assert node.type.dtype.typedef is list_t
        expected_name += '%next'
        assert node.name == expected_name


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_nested_procedure_call(frontend, tmp_path):
    """
    An inline call to a type-bound function reached through a derived-type
    component (``this%file%exists(...)``) must be an :any:`InlineCall` whose
    function symbol and each parent carry the right type information.
    """
    fcode = """
module derived_type_nested_proc_call_mod
    implicit none

    type netcdf_file_raw
        private
    contains
        procedure, public :: exists => raw_exists
    end type

    type netcdf_file
        type(netcdf_file_raw) :: file
    contains
        procedure :: exists
    end type netcdf_file

contains

    function exists(this, var_name) result(is_present)
        class(netcdf_file)           :: this
        character(len=*), intent(in) :: var_name
        logical :: is_present

        is_present = this%file%exists(var_name)
    end function exists

    function raw_exists(this, var_name) result(is_present)
        class(netcdf_file_raw)      :: this
        character(len=*), intent(in) :: var_name
        logical :: is_present

        is_present = .true.
    end function raw_exists

end module derived_type_nested_proc_call_mod
    """.strip()

    mod = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    assignments = FindNodes(ir.Assignment).visit(mod['exists'].body)
    assert len(assignments) == 1
    call = assignments[0].rhs
    assert isinstance(call, sym.InlineCall)
    assert fgen(call).lower() == 'this%file%exists(var_name)'

    # The called function: this%file%exists
    func = call.function
    assert isinstance(func, sym.ProcedureSymbol)
    assert isinstance(func.type.dtype, ProcedureType)

    # Its parent: this%file, of type netcdf_file_raw
    component = func.parent
    assert component and isinstance(component.type.dtype, DerivedType)
    assert component.type.dtype.name == 'netcdf_file_raw'
    assert component.type.dtype.typedef is mod['netcdf_file_raw']

    # Its grandparent: this, of type netcdf_file
    owner = component.parent
    assert owner
    assert isinstance(owner.type.dtype, DerivedType)
    assert owner.type.dtype.name == 'netcdf_file'
    assert owner.type.dtype.typedef is mod['netcdf_file']


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_sequence(frontend, tmp_path):
    """
    A ``SEQUENCE`` statement in a derived type definition must survive the
    round trip through the IR and the Fortran backend.
    """
    # F2008, Note 4.18
    fcode = """
module derived_type_sequence
    implicit none
    TYPE NUMERIC_SEQ
        SEQUENCE
        INTEGER :: INT_VAL
        REAL :: REAL_VAL
        LOGICAL :: LOG_VAL
    END TYPE NUMERIC_SEQ
end module derived_type_sequence
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    generated = fgen(module.typedef_map['numeric_seq'])
    assert 'SEQUENCE' in generated


@pytest.fixture(name='shadowed_typedef_symbols_fcode')
def fixture_shadowed_typedef_symbols_fcode(tmp_path, builder):
    """
    Provide Fortran source for a module in which a typedef component shares its
    name (case-insensitively) with a module parameter used for its shape and
    default value.

    Before yielding, the source is compiled as-is and run to record reference
    results, so that any later mismatch is due to Loki and not to the test code.
    """
    # Use a bespoke module name to avoid name clashes
    module_name = f'rad_rand_numb_{tmp_path.name[-4:]}'

    # Excerpt from ecrad's radiation_random_numbers.F90
    fcode = f"""
module {module_name}

  implicit none

  public :: rng_type, IRngNative

  enum, bind(c)
    enumerator IRngNative      ! Built-in Fortran-90 RNG
  end enum

  integer, parameter            :: jpim = selected_int_kind(9)
  integer, parameter            :: jprb = selected_real_kind(13,300)
  integer(kind=jpim), parameter :: NMaxStreams = 512

  type rng_type

    integer(kind=jpim) :: itype = IRngNative
    real(kind=jprb)    :: istate(NMaxStreams)
    integer(kind=jpim) :: nmaxstreams = NMaxStreams
    integer(kind=jpim) :: iseed = 0

  end type rng_type

contains

  subroutine rng_default(istate_dim, maxstreams)
    integer, intent(out) :: istate_dim, maxstreams
    type(rng_type) :: rng
    integer :: dim(1)
    rng = rng_type(istate=0._jprb)
    dim = shape(rng%istate)
    istate_dim = dim(1)
    maxstreams = rng%nmaxstreams
  end subroutine rng_default

  subroutine rng_init(istate_dim, maxstreams)
    integer, intent(out) :: istate_dim, maxstreams
    type(rng_type) :: rng
    integer :: dim(1)
    rng = rng_type(nmaxstreams=256, istate=0._jprb)
    dim = shape(rng%istate)
    istate_dim = dim(1)
    maxstreams = rng%nmaxstreams
  end subroutine rng_init

end module {module_name}
    """.strip()

    # Verify that this code behaves as expected
    source_path = tmp_path/'radiation_random_numbers.F90'
    source_path.write_text(fcode)

    lib = jit_compile_lib([source_path], path=tmp_path, name=module_name, builder=builder)
    compiled_mod = getattr(lib, module_name)
    default_dim, default_streams = compiled_mod.rng_default()
    init_dim, init_streams = compiled_mod.rng_init()

    # The array shape stays at the parameter value; only the component default
    # is overridden by the structure constructor in rng_init
    assert default_dim == 512
    assert default_streams == 512
    assert init_dim == 512
    assert init_streams == 256

    yield fcode


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_rescope_symbols_shadowed(tmp_path, shadowed_typedef_symbols_fcode, frontend):
    """
    Check symbol scoping when a typedef component shadows a module parameter
    of the same name, and that the regenerated module still compiles and
    reproduces the reference results.
    """
    # Parse into Loki IR
    module = Module.from_source(shadowed_typedef_symbols_fcode, frontend=frontend, xmods=[tmp_path])
    module_param = module.variable_map['nmaxstreams']
    assert module_param.scope is module

    # Verify scope of variables in type def
    rng_type = module.typedef_map['rng_type']
    istate = rng_type.variable_map['istate']
    component = rng_type.variable_map['nmaxstreams']

    assert istate in ('istate(nmaxstreams)', 'istate(1:nmaxstreams)')
    assert istate.scope is rng_type

    shape_expr = istate.dimensions[0]
    assert shape_expr == 'nmaxstreams'
    assert shape_expr.scope

    # FIXME: Use of NMaxStreams from parent scope is in the wrong scope (LOKI-52)
    #assert istate.dimensions[0].scope is module

    assert component.scope is rng_type

    if frontend == OMNI:
        # FIXME: OMNI doesn't retain the initializer expressions in the typedef
        return

    from loki.expression import Scalar  # pylint: disable=import-outside-toplevel

    # The component's default value must reference the module parameter
    initial = component.type.initial
    assert initial == 'NMaxStreams'
    assert initial.scope is module
    assert initial == module_param
    assert isinstance(initial, Scalar)

    # Test the outcome works as expected
    filepath = tmp_path/f'{module.name}_{frontend}.F90'
    compiled = jit_compile(module, filepath=filepath, objname=module.name)

    default_dim, default_streams = compiled.rng_default()
    init_dim, init_streams = compiled.rng_init()

    assert default_dim == 512
    assert default_streams == 512
    assert init_dim == 512
    assert init_streams == 256


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_types_character_array_subscript(frontend, tmp_path):
    """
    A substring range applied to an element of a character-array component
    (``config%some_name(i)(lo:hi)``) must be represented as a
    :any:`StringSubscript` and printed back unchanged.
    """
    fcode = """
module derived_type_char_arr_mod
    implicit none

    type char_arr_type
        character(len=511) :: some_name(3) = ["","",""]
    end type char_arr_type

contains

    subroutine some_routine(config)
        type(char_arr_type), intent(in) :: config
        integer :: i, strlen
        do i=1,3
            if (config%some_name(i)(1:1) == '/') then
                print *, 'absolute path'
            end if
            strlen = len_trim(config%some_name(i))
            if (config%some_name(i)(strlen-2:strlen) == '.nc') then
                print *, 'netcdf file'
            end if
        end do
    end subroutine some_routine
end module derived_type_char_arr_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['some_routine']

    # Left-hand side of each IF condition is the substring expression
    substrings = [cond.condition.left for cond in FindNodes(ir.Conditional).visit(routine.body)]
    assert all(isinstance(expr, sym.StringSubscript) for expr in substrings)

    expected = ['config%some_name(i)(1:1)', 'config%some_name(i)(strlen - 2:strlen)']
    assert [fgen(expr) for expr in substrings] == expected


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_types_nested_subscript(frontend, tmp_path):
    """
    A call to a type-bound procedure reached through subscripted derived-type
    arrays (``outers(i)%inner(j)%some_routine``) must keep every subscript in
    both its string form and the regenerated Fortran.
    """
    fcode = """
module derived_types_nested_subscript
    implicit none

    type inner_type
        integer :: val
    contains
        procedure :: some_routine
    end type inner_type

    type outer_type
        type(inner_type) :: inner(3)
    end type outer_type

contains

    subroutine some_routine(this, val)
        class(inner_type), intent(inout) :: this
        integer, intent(in) :: val
        this%val = val
    end subroutine some_routine

    subroutine driver(outers)
        type(outer_type), intent(inout) :: outers(5)
        integer :: i, j

        do i=1,5
            do j=1,3
                call outers(i)%inner(j)%some_routine(i*10 + j)
            end do
        end do
    end subroutine driver

end module derived_types_nested_subscript
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    calls = FindNodes(ir.CallStatement).visit(module['driver'].body)
    assert len(calls) == 1

    expected = 'outers(i)%inner(j)%some_routine'
    callee = calls[0].name
    assert str(callee) == expected
    assert fgen(callee) == expected


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_types_nested_type(frontend, tmp_path):
    """
    Type-bound calls and component accesses through nested derived types must
    link every parent in the chain to the correct typedef, both inside the
    defining module and in an external routine that imports only the outer type.
    """
    fcode_module = """
module some_mod
    implicit none

    type some_type
        integer :: val
    contains
        procedure :: some_routine
    end type some_type

    type other_type
        type(some_type) :: data
    contains
        procedure :: other_routine
    end type other_type

contains

    subroutine some_routine(this)
        class(some_type), intent(inout) :: this
        this%val = 5
    end subroutine some_routine

    subroutine other_routine(this)
        class(other_type), intent(inout) :: this
        call this%data%some_routine
    end subroutine other_routine
end module some_mod
    """.strip()

    fcode_driver = """
subroutine driver
    use some_mod, only: other_type
    implicit none
    type(other_type) :: var
    integer :: val
    call var%other_routine
    call var%data%some_routine
    val = var%data%val
end subroutine driver
    """.strip()

    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=[module], xmods=[tmp_path])

    some_type = module['some_type']
    other_type = module['other_type']

    def assert_dtype(var, type_name, typedef):
        """Assert that ``var`` has derived type ``type_name`` linked to ``typedef``"""
        assert var.type.dtype.name == type_name
        assert var.type.dtype.typedef is typedef

    # Inside the module: call this%data%some_routine
    stmt = module['other_routine'].body.body[0]
    assert isinstance(stmt, ir.CallStatement)
    assert isinstance(stmt.name.type.dtype, ProcedureType)
    data = stmt.name.parent
    assert data and isinstance(data.type.dtype, DerivedType)
    assert_dtype(data, 'some_type', some_type)
    this = data.parent
    assert this and isinstance(this.type.dtype, DerivedType)
    assert_dtype(this, 'other_type', other_type)

    # In the driver: both calls are type-bound with a derived-type parent
    calls = FindNodes(ir.CallStatement).visit(driver.body)
    assert len(calls) == 2
    for stmt in calls:
        assert isinstance(stmt.name.type.dtype, ProcedureType)
        assert stmt.name.parent and isinstance(stmt.name.parent.type.dtype, DerivedType)

    # call var%other_routine
    assert_dtype(calls[0].name.parent, 'other_type', other_type)

    # call var%data%some_routine
    data = calls[1].name.parent
    assert_dtype(data, 'some_type', some_type)
    assert data.parent
    assert_dtype(data.parent, 'other_type', other_type)

    # val = var%data%val
    assignment = driver.body.body[-1]
    assert isinstance(assignment, ir.Assignment)
    rhs = assignment.rhs
    assert rhs.type.dtype is BasicType.INTEGER
    assert rhs.parent and isinstance(rhs.parent.type.dtype, DerivedType)
    assert_dtype(rhs.parent, 'some_type', some_type)
    assert_dtype(rhs.parent.parent, 'other_type', other_type)


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_types_abstract_deferred_procedure(frontend, tmp_path):
    """
    Deferred bindings in an abstract type, where each binding has the same name
    as its abstract interface, must become procedure symbols scoped to the
    typedef. Their bind name must be a symbol of the same name scoped to the
    module.
    """
    fcode = """
module some_mod
    implicit none
    type, abstract :: abstract_type
        contains
        procedure (some_proc), deferred :: some_proc
        procedure (other_proc), deferred :: other_proc
    end type abstract_type

    abstract interface
        subroutine some_proc(this)
            import abstract_type
            class(abstract_type), intent(in) :: this
        end subroutine some_proc
    end interface

    abstract interface
        subroutine other_proc(this)
            import abstract_type
            class(abstract_type), intent(inout) :: this
        end subroutine other_proc
    end interface
end module some_mod
    """.strip()
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    typedef = module['abstract_type']
    assert typedef.abstract is True

    bindings = typedef.variables
    assert bindings == ('some_proc', 'other_proc')
    for binding in bindings:
        assert isinstance(binding, sym.ProcedureSymbol)
        attrs = binding.type
        assert isinstance(attrs.dtype, ProcedureType)
        assert attrs.dtype.name.lower() == binding.name.lower()
        assert attrs.bind_names == (binding,)
        assert binding.scope is typedef
        assert attrs.bind_names[0].scope is module

    # The typedef has no imported symbols of its own
    assert typedef.imported_symbols == ()
    assert not typedef.imported_symbol_map


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_type_symbol_inheritance(frontend, tmp_path):
    """
    Extended types must expose their own components and bindings first, then
    those inherited from their ancestors. Overridden bindings map to the
    overriding procedure, and every declared symbol is scoped to the typedef.
    """
    fcode = """
module some_mod
implicit none
type :: base_type
    integer :: memberA
    real :: memberB
    contains
    procedure :: init => init_base_type
    procedure :: final => final_base_type
    procedure :: copy
end type base_type

type, extends(base_type) :: extended_type
    integer :: memberC
    contains
    procedure :: init => init_extended_type
    procedure :: final => final_extended_type
    procedure :: do_something
end type extended_type

type, extends(extended_type) :: extended_extended_type
    integer :: memberD
    contains
    procedure :: init => init_extended_extended_type
    procedure :: final => final_extended_extended_type
    procedure :: do_something => do_something_else
end type extended_extended_type

contains

subroutine init_base_type(self)
  class(base_type) :: self
end subroutine init_base_type
subroutine final_base_type(self)
  class(base_type) :: self
end subroutine final_base_type
subroutine copy(self)
  class(base_type) :: self
end subroutine copy

subroutine init_extended_type(self)
  class(extended_type) :: self
end subroutine init_extended_type
subroutine final_extended_type(self)
  class(extended_type) :: self
end subroutine final_extended_type
subroutine do_something(self)
  class(extended_type) :: self
end subroutine do_something

subroutine init_extended_extended_type(self)
  class(extended_extended_type) :: self
end subroutine init_extended_extended_type
subroutine final_extended_extended_type(self)
  class(extended_extended_type) :: self
end subroutine final_extended_extended_type
subroutine do_something_else(self)
  class(extended_extended_type) :: self
end subroutine do_something_else
end module some_mod
""".strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    def check_typedef(typedef, expected_vars, bound, unbound):
        """
        Check variable order, bind names and scoping of ``typedef``.

        ``bound`` maps a variable index to its expected first bind name;
        ``unbound`` lists the indices of bindings that have no bind names.
        """
        variables = typedef.variables
        assert variables == expected_vars
        for index, bind_name in bound.items():
            assert variables[index].type.bind_names[0] == bind_name
        for index in unbound:
            assert not variables[index].type.bind_names
        assert all(s.scope is typedef for d in typedef.declarations for s in d.symbols)
        assert typedef.imported_symbols == ()
        assert not typedef.imported_symbol_map

    base_type = module['base_type']
    extended_type = module['extended_type']
    extended_extended_type = module['extended_extended_type']

    check_typedef(
        base_type,
        ('memberA', 'memberB', 'init', 'final', 'copy'),
        bound={2: 'init_base_type', 3: 'final_base_type'},
        unbound=(4,)
    )

    check_typedef(
        extended_type,
        ('memberC', 'init', 'final', 'do_something', 'memberA', 'memberB', 'copy'),
        bound={1: 'init_extended_type', 2: 'final_extended_type'},
        unbound=(3, 6)
    )
    # Check for non-empty declarations
    assert all(decl.symbols for decl in extended_type.declarations)

    check_typedef(
        extended_extended_type,
        ('memberD', 'init', 'final', 'do_something', 'memberC', 'memberA', 'memberB', 'copy'),
        bound={
            1: 'init_extended_extended_type',
            2: 'final_extended_extended_type',
            3: 'do_something_else'
        },
        unbound=(7,)
    )
    # Check for non-empty declarations
    assert all(decl.symbols for decl in extended_extended_type.declarations)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('qualified_import', (True, False))
def test_derived_type_inheritance_missing_parent(frontend, qualified_import, tmp_path):
    """
    Test that a derived type extending a type from another module keeps a
    DEFERRED parent type when that module is unknown, and links the actual
    parent type once the module is supplied via ``definitions``.
    """
    fcode_parent = """
module parent_mod
    implicit none
    type, abstract, public :: parent_type
        integer :: val
    end type parent_type
end module parent_mod
    """.strip()

    only_clause = ", only: parent_type" if qualified_import else ""
    fcode_derived = f"""
module derived_mod
    use parent_mod{only_clause}
    implicit none
    type, public, extends(parent_type) :: derived_type
        integer :: val2
    end type derived_type
contains
    subroutine do_something(this)
        class(derived_type), intent(inout) :: this
        this%val = 1
        this%val2 = 2
    end subroutine do_something
end module derived_mod
    """.strip()

    parent_module = Module.from_source(fcode_parent, frontend=frontend, xmods=[tmp_path])

    # No definitions given: parsing must succeed, the parent type stays DEFERRED
    plain_derived = Module.from_source(fcode_derived, frontend=frontend, xmods=[tmp_path])
    assert plain_derived['derived_type'].parent_type == BasicType.DEFERRED

    # Parent module given as definition: the parent TypeDef is linked directly
    enriched_derived = Module.from_source(
        fcode_derived, frontend=frontend, xmods=[tmp_path], definitions=[parent_module]
    )
    assert enriched_derived['derived_type'].parent_type is parent_module['parent_type']


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'OMNI does not handle type-bound procedures')]
))
def test_derived_type_type_bound_call_proctype(frontend):
    """
    Test correct procedure type information on type-bound procedures.

    A call ``f%calc(...)`` through a type-bound binding must resolve (via the
    enriched ``functor_mod`` definition) to the ``calc`` subroutine that
    implements the binding.
    """

    fcode_functor = """
module functor_mod
  implicit none
  type functor
    contains
    procedure :: calc
  end type functor
contains

  subroutine calc(this, c)
    type(functor) :: this
    real, intent(inout) :: c
    !$loki routine seq
    c = c + 1
  end subroutine calc
end module functor_mod
"""

    # NOTE: ``a`` is a dummy argument (it carries ``intent(inout)``) and the loop
    # index ``i`` is declared explicitly, so the fixture is valid Fortran under
    # ``implicit none``.
    fcode_kernel = """
module kernel_mod
  use functor_mod
  implicit none
contains

  subroutine kernel(ncol, istart, iend, a)
    integer, intent(in) :: ncol, istart, iend
    real, intent(inout) :: a(ncol)
    type(functor) :: f
    integer :: i

    do i = istart,iend
      call f%calc(a(i))
    end do
  end subroutine kernel
end module kernel_mod
"""

    functor_mod = Module.from_source(fcode_functor, frontend=frontend)
    kernel_mod = Module.from_source(fcode_kernel, definitions=functor_mod, frontend=frontend)
    kernel = kernel_mod['kernel']

    calls = FindNodes(ir.CallStatement).visit(kernel.body)
    assert len(calls) == 1
    assert calls[0].name == 'f%calc'
    # The type-bound call must be linked to the implementing subroutine
    assert calls[0].routine and calls[0].routine == functor_mod['calc']
loki-ecmwf-0.3.6/loki/types/tests/test_procedure_types.py0000664000175000017500000000565415167130205024044 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Function, Module, Subroutine
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes
from loki.types import ProcedureType


@pytest.mark.parametrize('frontend', available_frontends())
def test_procedure_type(tmp_path, frontend):
    """
    Test that `ProcedureType` links back to the procedure object when it is
    defined in the parsed source: for a called subroutine, an inline-called
    function, and a statement function.
    """

    fcode_mod = """
module my_mod
implicit none

contains

  subroutine test_routine(n, a)
    integer, intent(in) :: n
    real(kind=4), intent(inout) :: a(3)
    real(kind=4) :: smoke
    real(kind=4) :: pants, on, fire
    pants(on, fire) = on + fire

    call me_maybe(n, a)

    smoke = on_the_water(a(3))
  end subroutine test_routine

  subroutine me_maybe(n, a)
    integer, intent(in) :: n
    real(kind=4), intent(inout) :: a(3)

    a(1) = on_the_water(a(2))
  end subroutine me_maybe

  function on_the_water(b) result(rick)
    real(kind=4), intent(in) :: b
    real(kind=4) :: rick

    rick = 2*b
  end function on_the_water
end module my_mod
"""
    module = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    routine = module['test_routine']

    # Module procedures are parsed as the right class and registered with a ProcedureType
    for proc_name, proc_class in (('me_maybe', Subroutine), ('on_the_water', Function)):
        assert isinstance(module[proc_name], proc_class)
        assert isinstance(module.symbol_attrs[proc_name].dtype, ProcedureType)

    # CALL statement: procedure type points at the Subroutine
    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    call_type = calls[0].name.type.dtype
    assert isinstance(call_type, ProcedureType)
    assert str(call_type) == 'me_maybe' and repr(call_type) == ''
    assert call_type.procedure == module['me_maybe']

    # Inline call on an assignment RHS: procedure type points at the Function
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 1
    rhs = assigns[0].rhs
    assert isinstance(rhs, sym.InlineCall)
    func_type = rhs.function.type.dtype
    assert str(func_type) == 'on_the_water' and repr(func_type) == ''
    assert func_type.procedure == module['on_the_water']

    # Statement function: procedure type points at the StatementFunction node
    # (statement functions are not supported in OMNI)
    if frontend != OMNI:
        stmt_funcs = FindNodes(ir.StatementFunction).visit(routine.spec)
        assert len(stmt_funcs) == 1
        stmt_func_type = stmt_funcs[0].variable.type.dtype
        assert isinstance(stmt_func_type, ProcedureType)
        assert str(stmt_func_type) == 'pants' and repr(stmt_func_type) == ''
        assert stmt_func_type.procedure == stmt_funcs[0]
loki-ecmwf-0.3.6/loki/types/tests/test_types.py0000664000175000017500000005032115167130205021763 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
from random import choice
import pytest

from loki import (
    Sourcefile, Module, Subroutine, BasicType, SymbolAttributes,
    DerivedType, TypeDef, FCodeMapper, DataType, fgen, ProcedureType,
    FindNodes, ProcedureDeclaration
)
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Module-scoped fixture ``here``: the directory containing this test file."""
    test_file = Path(__file__)
    return test_file.parent


def test_basic_type():
    """
    Tests the conversion of strings to `BasicType`.

    Fortran type names are checked in fixed upper-case, lower-case and
    alternating-case spellings, so every run covers the same cases. One
    randomly-cased spelling per name is checked as well. Fortran is not
    case-sensitive.
    """
    assert all(t == BasicType(t.value) for t in BasicType)
    assert all(isinstance(t, DataType) for t in BasicType)

    assert all(t == BasicType.from_name(t.name) for t in BasicType)
    assert all(t == BasicType.from_str(t.name) for t in BasicType)

    fortran_type_map = {'LOGICAL': BasicType.LOGICAL, 'INTEGER': BasicType.INTEGER,
                        'REAL': BasicType.REAL, 'CHARACTER': BasicType.CHARACTER,
                        'COMPLEX': BasicType.COMPLEX}

    def alternating_case(name):
        # e.g. 'REAL' -> 'rEaL'
        return ''.join(c.upper() if i % 2 else c.lower() for i, c in enumerate(name))

    def random_case(name):
        # Randomly change case of single letters (FORTRAN is not case-sensitive)
        return ''.join(choice((str.upper, str.lower))(c) for c in name)

    # Deterministic variants first, so a case-handling regression is reproducible
    case_variants = (str.upper, str.lower, alternating_case, random_case)
    test_map = {variant(s): t for s, t in fortran_type_map.items() for variant in case_variants}

    assert all(t == BasicType.from_fortran_type(s) for s, t in test_map.items())
    assert all(t == BasicType.from_str(s) for s, t in test_map.items())

    c99_type_map = {'bool': BasicType.LOGICAL, '_Bool': BasicType.LOGICAL,
                    'short': BasicType.INTEGER, 'unsigned short': BasicType.INTEGER,
                    'signed short': BasicType.INTEGER, 'int': BasicType.INTEGER,
                    'unsigned int': BasicType.INTEGER, 'signed int': BasicType.INTEGER,
                    'long': BasicType.INTEGER, 'unsigned long': BasicType.INTEGER,
                    'signed long': BasicType.INTEGER, 'long long': BasicType.INTEGER,
                    'unsigned long long': BasicType.INTEGER, 'signed long long': BasicType.INTEGER,
                    'float': BasicType.REAL, 'double': BasicType.REAL, 'long double': BasicType.REAL,
                    'char': BasicType.CHARACTER, 'float _Complex': BasicType.COMPLEX,
                    'double _Complex': BasicType.COMPLEX, 'long double _Complex': BasicType.COMPLEX}

    assert all(t == BasicType.from_c99_type(s) for s, t in c99_type_map.items())
    assert all(t == BasicType.from_str(s) for s, t in c99_type_map.items())


@pytest.mark.parametrize('frontend', available_frontends())
def test_type_declaration_attributes(frontend):
    """
    Test that declaration attributes (PARAMETER, INTENT, TARGET, ALLOCATABLE,
    POINTER, CONTIGUOUS) are recorded in the routine's symbol table.
    """
    fcode = """
subroutine test_type_declarations(b, c)
    integer, parameter :: a = 4
    integer, intent(in) :: b
    real(kind=a), target, intent(inout) :: c(:)
    real(kind=a), allocatable :: d(:)
    real(kind=a), pointer, contiguous :: e(:)

end subroutine test_type_declarations
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    attrs = routine.symbol_attrs

    assert attrs['a'].parameter
    assert attrs['b'].intent == 'in'

    c_attrs = attrs['c']
    assert c_attrs.target
    assert c_attrs.intent == 'inout'

    assert attrs['d'].allocatable

    e_attrs = attrs['e']
    assert e_attrs.pointer
    assert e_attrs.contiguous


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Segfault with pragmas in derived types')]))
def test_pragmas(frontend):
    """
    Test that `!$loki dimension` pragmas on derived type members set the
    intended shapes of their deferred-shape components.
    """
    fcode = """
module types

  integer, parameter :: jprb = selected_real_kind(13,300)

  type pragma_type
    !$loki dimension(3,3)
    real(kind=jprb), dimension(:,:), pointer :: matrix
    !$loki dimension(klon,klat,2)
    real(kind=jprb), pointer :: tensor(:, :, :)
  end type pragma_type

contains

  subroutine alloc_pragma_type(item)
    type(pragma_type), intent(inout) :: item
    allocate(item%matrix(5,5))
    allocate(item%tensor(3,4,5))
  end subroutine

  subroutine free_pragma_type(item)
    type(pragma_type), intent(inout) :: item
    deallocate(item%matrix)
    deallocate(item%tensor)
  end subroutine

end module types
"""
    to_fortran = FCodeMapper()

    source = Sourcefile.from_source(fcode, frontend=frontend)
    types_module = source['types']
    pragma_type = types_module.symbol_attrs['pragma_type'].dtype

    # The derived type is linked to its definition in the module
    assert pragma_type.typedef is types_module.typedef_map['pragma_type']

    # Member shapes follow the pragmas, not the deferred-shape declarations
    members = pragma_type.typedef.variables
    assert to_fortran(members[0].shape) == '(3, 3)'
    assert to_fortran(members[1].shape) == '(klon, klat, 2)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_type_derived_type(frontend, tmp_path):
    """
    Test that variables of a derived type defined in the same module are
    recognised, linked to the type definition, and have their members
    scoped to (and registered in) the using routine.
    """

    fcode = """
module test_type_derived_type_mod
  implicit none
  integer, parameter :: a_kind = 4

  type my_struct
    real(kind=a_kind) :: a(:), b(:,:)
  end type my_struct

  contains
  subroutine test_type_derived_type(a, b, c)
    type(my_struct), target, intent(inout) :: a
    type(my_struct), allocatable :: b(:)
    type(my_struct), pointer :: c

  end subroutine test_type_derived_type
end module test_type_derived_type_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_type_derived_type']

    a, b, c = routine.variables
    struct_vars = (a, b, c)

    # (variable, expected symbol class, declaration attribute that must be set)
    expectations = (
        (a, sym.Scalar, 'target'),
        (b, sym.Array, 'allocatable'),
        (c, sym.Scalar, 'pointer'),
    )
    for var, symbol_class, attr_name in expectations:
        assert isinstance(var, symbol_class)
        assert isinstance(var.type.dtype, DerivedType)
        assert getattr(var.type, attr_name)

    # Derived types link to the type definition, and members are accessible
    for var in struct_vars:
        assert len(var.type.dtype.typedef.variables) == 2
    for var in struct_vars:
        assert len(var.variables) == 2
    # ... and members are attached to the routine's scope
    for var in struct_vars:
        assert all(member.scope is routine for member in var.variables)

    # Every member variable has an entry in the local symbol table
    for parent_name in ('a', 'b', 'c'):
        assert routine.symbol_attrs[f'{parent_name}%a'].shape == (':',)
        assert routine.symbol_attrs[f'{parent_name}%b'].shape == (':', ':')


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI cannot deal with deferred type info')]))
def test_type_module_imports(frontend):
    """
    Test the detection of known / unknown symbol types from module imports.

    The routine is parsed once without context (imports stay DEFERRED) and
    once with the defining module supplied (full type information attached).
    """
    fcode = """
subroutine test_type_module_imports(arg_a, arg_b)
  use my_types_mod, only: a_kind, a_dim, a_type
  implicit none

  real(kind=a_kind), intent(in) :: arg_a(a_dim)
  type(a_type), intent(in) :: arg_b
end subroutine test_type_module_imports
"""
    # Without a-priori context info, every imported symbol is DEFERRED
    routine = Subroutine.from_source(fcode, frontend=frontend)
    for imported in ('a_kind', 'a_dim', 'a_type'):
        assert routine.symbol_attrs[imported].dtype == BasicType.DEFERRED

    # Local variable info is correct, as far as it is known
    arg_a, arg_b = routine.variables
    assert arg_a.type.kind.type.compare(routine.symbol_attrs['a_kind'], ignore=('imported',))
    assert arg_a.dimensions[0].type.compare(routine.symbol_attrs['a_dim'])
    arg_b_dtype = arg_b.type.dtype
    assert isinstance(arg_b_dtype, DerivedType)
    assert arg_b_dtype.typedef == BasicType.DEFERRED

    fcode_module = """
module my_types_mod
  implicit none

  integer, parameter :: a_kind = 4
  integer(kind=a_kind) :: a_dim

  type a_type
    real(kind=a_kind), allocatable :: a(:), b(:,:)
  end type a_type
end module my_types_mod
"""
    module = Module.from_source(fcode_module, frontend=frontend)
    routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend)

    # Module variables and types are imported with their type information
    kind_attrs = routine.symbol_attrs['a_kind']
    assert kind_attrs.dtype == BasicType.INTEGER
    assert kind_attrs.parameter
    assert kind_attrs.initial == 4
    dim_attrs = routine.symbol_attrs['a_dim']
    assert dim_attrs.dtype == BasicType.INTEGER
    assert dim_attrs.kind == 'a_kind'
    assert isinstance(routine.symbol_attrs['a_type'].dtype.typedef, TypeDef)

    # The external type definition is linked and keeps its own symbol table
    arg_b = routine.variable_map['arg_b']
    arg_b_typedef = arg_b.type.dtype.typedef
    assert isinstance(arg_b_typedef, TypeDef)
    assert arg_b_typedef.symbol_attrs != routine.symbol_attrs

    # Member variables are re-scoped to the routine
    member_a, member_b = arg_b.variables
    assert ','.join(str(dim) for dim in member_a.dimensions) == ':'
    assert ','.join(str(dim) for dim in member_b.dimensions) == ':,:'
    assert member_a.type.kind == member_b.type.kind == 'a_kind'
    assert member_a.scope is routine
    assert member_b.scope is routine

    # Every member variable has an entry in the local symbol table
    for member_name, shape in (('a', (':',)), ('b', (':', ':'))):
        assert routine.symbol_attrs[f'arg_b%{member_name}'].shape == shape


@pytest.mark.parametrize('frontend', available_frontends())
def test_type_char_length(frontend):
    """
    Test the various beautiful ways in which Fortran allows character
    lengths (and kinds) to be specified, and their normalised regeneration.
    """
    fcode = f"""
subroutine test_type_char_length
    implicit none
    character*80  :: kill_it_with_fire
    character(60) :: if_you_insist
    character(len=21) :: okay
    character(len=*) :: oh_dear
    character(len=:) :: come_on
    character :: you_gotta_be_kidding_me*20
    character(*) :: whatever(5)
    character(10, 1) :: this_is_getting_silly
    {'character(11, kind=1) :: i_mean' if frontend != OMNI else ''}
    character(len=12, kind=1) :: WHAT
    character(kind=1) :: DO_YOU_WANT
    character(kind=1, len=13) :: FROM_ME
    character*(*) :: where_do_I_begin
    character :: and_how_does_it_end*(*)
end subroutine test_type_char_length
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    var_map = routine.variable_map
    is_omni = frontend == OMNI

    expected_lengths = (
        ('kill_it_with_fire', '80'),
        ('if_you_insist', '60'),
        ('okay', '21'),
        ('oh_dear', '*'),
        ('come_on', ':'),
        ('you_gotta_be_kidding_me', '20'),
        ('whatever', '*'),
        ('this_is_getting_silly', '10'),
        ('what', '12'),
        ('from_me', '13'),
        ('and_how_does_it_end', '*'),
    )
    for var_name, length in expected_lengths:
        assert var_map[var_name].type.length == length
    assert var_map['whatever'].shape == ('5',)
    assert var_map['what'].type.kind == '1'
    assert var_map['do_you_want'].type.length is None

    if not is_omni:
        # OMNI swallows these kind specifiers, and fails with a syntax error on ``i_mean``
        assert var_map['this_is_getting_silly'].type.kind == '1'
        assert var_map['i_mean'].type.length == '11'
        assert var_map['i_mean'].type.kind == '1'
        assert var_map['do_you_want'].type.kind == '1'
        assert var_map['from_me'].type.kind == '1'

    code = routine.to_fortran()
    for length in ('80', '60', '21', '*', ':', '20'):
        assert f'CHARACTER(LEN={length}) ::' in code

    if is_omni:
        expected_decls = [f'CHARACTER(LEN={length}) :: ' for length in (10, 13)]
        expected_decls.append('CHARACTER(LEN=12, KIND=1) :: ')
    else:
        expected_decls = [f'CHARACTER(LEN={length}, KIND=1) :: ' for length in range(10, 14)]
        expected_decls.append('CHARACTER(KIND=1) :: ')
    for decl in expected_decls:
        assert decl in code


@pytest.mark.parametrize('frontend', available_frontends())
def test_type_kind_value(frontend):
    """
    Test the various ways in which kind parameters can be specified
    (``*N`` suffix, positional and ``kind=`` keyword, literal or named).

    Variables are grouped by name prefix (``int_8_``, ``int_jpim_``, ...).
    Each group must be non-empty and match its expected kind. Previously the
    prefix was derived from the kind string, so under OMNI the named-kind
    variables were never selected and the checks passed vacuously.
    """
    fcode = """
subroutine test_type_kind_value
    implicit none

    integer, parameter :: jprb = selected_real_kind(13,300)
    integer, parameter :: jpim = selected_int_kind(9)

    integer*8 :: int_8_s
    integer(8) :: int_8_p
    integer(kind=8) :: int_8_k

    integer(jpim) :: int_jpim_p
    integer(kind=jpim) :: int_jpim_k

    real*16 :: real_16_s
    real(16) :: real_16_p
    real(kind=16) :: real_16_k

    real(jprb) :: real_jprb_p
    real(kind=jprb) :: real_jprb_k
end subroutine test_type_kind_value
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # OMNI reports the kind of the named parameters as their defining expression
    if frontend == OMNI:
        jpim_kind, jprb_kind = 'selected_int_kind(9)', 'selected_real_kind(13, 300)'
    else:
        jpim_kind, jprb_kind = 'jpim', 'jprb'

    # variable-name prefix -> (Fortran type name, expected kind)
    expected_kinds = {
        'int_8_': ('INTEGER', '8'),
        'int_jpim_': ('INTEGER', jpim_kind),
        'real_16_': ('REAL', '16'),
        'real_jprb_': ('REAL', jprb_kind),
    }

    for prefix, (type_name, kind) in expected_kinds.items():
        matches = [var for var in routine.variables if var.name.lower().startswith(prefix)]
        # Guard against the name filter silently selecting nothing
        assert matches, f'No variables found with prefix {prefix!r}'
        for var in matches:
            assert str(var.type.kind).lower() == kind
            assert f'{type_name}(KIND={kind.upper()})' in str(fgen(var.type)).upper()


@pytest.mark.parametrize('frontend', available_frontends())
def test_type_contiguous(frontend):
    """
    Test that a pointer dummy argument declared with the F2008 ``contiguous``
    attribute keeps both the ``pointer`` and ``contiguous`` flags.
    """
    fcode = """
subroutine routine_contiguous(vec)
  integer, parameter :: jprb = selected_real_kind(13,300)
  real(kind=jprb), pointer, contiguous :: vec(:)

  vec(:) = 2.
end subroutine routine_contiguous
    """
    source = Sourcefile.from_source(fcode, frontend=frontend, preprocess=True)
    routine = source['routine_contiguous']

    args = routine.arguments
    assert len(args) == 1
    vec_type = args[0].type
    assert vec_type.contiguous and vec_type.pointer


@pytest.mark.parametrize('frontend', available_frontends())
def test_type_procedure_pointer_declaration(frontend, tmp_path):
    """
    Test procedure declarations with explicit and implicit interfaces,
    procedure pointers (including a NULL() initialisation) and a procedure
    pointer component inside a derived type.
    """
    # Example code from F2008 standard, Note 12.15
    fcode = """
MODULE some_mod

ABSTRACT INTERFACE
    FUNCTION REAL_FUNC (X)
        REAL, INTENT (IN) :: X
        REAL :: REAL_FUNC
    END FUNCTION REAL_FUNC
END INTERFACE
INTERFACE
    SUBROUTINE SUB (X)
        REAL, INTENT (IN) :: X
    END SUBROUTINE SUB
END INTERFACE

!-- Some external or dummy procedures with explicit interface.
PROCEDURE (REAL_FUNC) :: BESSEL, GFUN
PROCEDURE (SUB) :: PRINT_REAL

!-- Some procedure pointers with explicit interface,
!-- one initialized to NULL().
PROCEDURE (REAL_FUNC), POINTER :: P, R => NULL()
PROCEDURE (REAL_FUNC), POINTER :: PTR_TO_GFUN

!-- A derived type with a procedure pointer component ...
TYPE STRUCT_TYPE
    PROCEDURE (REAL_FUNC), POINTER :: COMPONENT
END TYPE STRUCT_TYPE

!-- ... and a variable of that type.
TYPE(STRUCT_TYPE) :: STRUCT

!-- An external or dummy function with implicit interface
PROCEDURE (REAL) :: PSI

END MODULE some_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    symbol_map = module.symbol_map

    # FIXME: Because of our broken way of capturing function return types this gets the wrong
    #        variable type currently...
    assert isinstance(symbol_map['real_func'], (sym.Scalar, sym.ProcedureSymbol))

    decl_map = {
        declared.name.lower(): decl
        for decl in FindNodes(ProcedureDeclaration).visit(module.spec)
        for declared in decl.symbols
    }

    pointer_names = {'p', 'r', 'ptr_to_gfun'}
    null_init_names = {'r'}

    # Every procedure symbol ('real_func' excluded, see FIXME) is declared with a ProcedureType
    for proc_name in ('sub', 'bessel', 'gfun', 'print_real', 'p', 'r', 'ptr_to_gfun', 'psi'):
        assert proc_name in module.symbols
        proc_symbol = symbol_map[proc_name]
        proc_attrs = proc_symbol.type
        assert isinstance(proc_symbol, sym.ProcedureSymbol)
        assert isinstance(proc_attrs.dtype, ProcedureType)

        if proc_name not in pointer_names:
            assert proc_attrs.pointer is None
        else:
            assert proc_attrs.pointer is True
            assert ', POINTER' in fgen(decl_map[proc_name])

        if proc_name not in null_init_names:
            assert proc_attrs.initial is None
        else:
            assert fgen(proc_attrs.initial).upper() == 'NULL()'

    # Each symbol carries the procedure type of its interface
    interface_groups = (
        (('bessel', 'gfun', 'p', 'r', 'ptr_to_gfun'), 'REAL_FUNC', True),  # 'real_func'
        (('print_real', 'sub'), 'SUB', False),
    )
    for proc_names, interface, is_function in interface_groups:
        for proc_name in proc_names:
            proc_dtype = symbol_map[proc_name].type.dtype
            assert proc_dtype.name.upper() == interface
            assert proc_dtype.is_function is is_function
            if proc_name in decl_map:
                decl = decl_map[proc_name]
                assert decl.interface == interface
                assert f'PROCEDURE({interface})' in fgen(decl).upper()
            if proc_name in null_init_names:
                assert f'{proc_name.upper()} => NULL()' in fgen(decl_map[proc_name]).upper()

    # The procedure pointer component in the derived type is sane
    struct_type = module.typedef_map['struct_type']
    component_decls = FindNodes(ProcedureDeclaration).visit(struct_type.body)
    assert len(component_decls) == 1
    component_decl = component_decls[0]
    assert component_decl.symbols == ('component',)
    component_attrs = component_decl.symbols[0].type
    assert component_attrs.dtype.name.upper() == 'REAL_FUNC'
    assert component_attrs.pointer is True
    assert component_decl.interface == 'real_func'

    # The variable of that type links to the type definition
    struct_dtype = symbol_map['struct'].type.dtype
    assert struct_dtype.name.upper() == 'STRUCT_TYPE'
    assert struct_dtype.typedef is struct_type

    # The external procedure with implicit interface is sane
    psi = symbol_map['psi']
    psi_dtype = psi.type.dtype
    assert isinstance(psi, sym.ProcedureSymbol)
    assert isinstance(psi_dtype, ProcedureType)
    assert psi_dtype.name.upper() == 'PSI'
    assert psi_dtype.return_type.compare(SymbolAttributes(BasicType.REAL))
    assert decl_map['psi'].interface == BasicType.REAL
    assert 'PROCEDURE(REAL)' in fgen(decl_map['psi']).upper()


@pytest.mark.parametrize('frontend', available_frontends())
def test_type_attach_scope_kind(frontend, tmp_path):
    """
    Validate that nested variables are assigned to the right scope. Examples
    are kind parameters and their initial values, which are shadowed by a
    local variable of the same name (``dp``) in the routine.
    """
    fcode = """
module phys_mod
use iso_fortran_env
implicit none

integer, parameter :: dp = REAL64
integer, parameter :: lp = dp     !! lp : "local" precision
integer, parameter :: ip = INT64
integer, parameter :: nspecies = 5

contains

subroutine phys_kernel_LU_SOLVER_COMPACT(dim1,dim2,i1,i2)
    integer(kind=ip),intent(in) :: dim1, dim2, i1,i2
    real(kind=lp) :: dp(i1:i2), temp_hor1(i1:i2)
    real(kind=lp) :: temp_out(i1:i2,nspecies), out_lev_m_1(i1:i2,nspecies)

end subroutine phys_kernel_LU_SOLVER_COMPACT

end module phys_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['phys_kernel_lu_solver_compact']
    local_vars = routine.variable_map

    assert local_vars['temp_out'].scope is routine
    # The module-level ``dp`` and the routine-local ``dp`` live in different scopes
    assert module.variable_map['dp'].scope is module
    assert local_vars['dp'].scope is routine

    # The kind-parameter checks below are not applied for OMNI
    if frontend == OMNI:
        return

    kind_param = local_vars['temp_out'].type.kind
    assert kind_param == 'lp'
    assert kind_param.scope is module

    assert kind_param.initial == 'dp'
    assert kind_param.initial.scope is module
loki-ecmwf-0.3.6/loki/types/tests/test_scope.py0000664000175000017500000000313515167130205021731 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
A collection of tests for :any:`SymbolAttrs`, :any:`SymbolTable` and :any:`Scope`.
"""

from loki.types import SymbolAttributes, BasicType


def test_symbol_attributes():
    """
    Tests attaching, looking up and deleting arbitrary attributes on
    :any:`SymbolAttributes`
    """
    attrs = SymbolAttributes('integer', a='a', b=True, c=None)
    assert attrs.dtype == BasicType.INTEGER
    assert attrs.a == 'a'
    assert attrs.b

    # Attributes passed as None and attributes never set both read back as None
    for unset_name in ('c', 'foofoo'):
        assert getattr(attrs, unset_name) is None

    # Attributes can be attached on the fly and removed again
    attrs.foofoo = 'bar'
    assert attrs.foofoo == 'bar'
    del attrs.foofoo
    assert attrs.foofoo is None

    # Assigning None removes an existing attribute
    attrs.b = None
    assert attrs.b is None


def test_symbol_attributes_compare():
    """
    Test the `compare` method of `SymbolAttributes`, which can exclude
    selected attributes from the comparison.
    """
    int_flag_set = SymbolAttributes('integer', a='a', b=True, c=None)
    int_flag_unset = SymbolAttributes('integer', a='a', b=False, c=None)
    real_flag_set = SymbolAttributes('real', a='a', b=True, c=None)

    # A single differing attribute makes the comparison fail, in both directions
    assert not int_flag_set.compare(int_flag_unset)
    assert not int_flag_unset.compare(int_flag_set)

    # ...unless that attribute is ignored (given as a single name or a list)
    assert int_flag_set.compare(int_flag_unset, ignore='b')
    assert int_flag_unset.compare(int_flag_set, ignore=['b'])

    # A different data type makes the comparison fail
    assert not int_flag_set.compare(real_flag_set)
loki-ecmwf-0.3.6/loki/types/symbol_table.py0000664000175000017500000002553415167130205021102 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Representation of symbol tables and scopes in
:doc:`internal_representation`
"""

import weakref

from loki.tools import as_tuple
from loki.types.datatypes import BasicType, DataType


__all__ = ['SymbolAttributes', 'SymbolTable']


class SymbolAttributes:
    """
    Representation of a symbol's attributes, such as data type and declared
    properties

    It has a fixed :any:`DataType` associated with it, available as property
    :attr:`SymbolAttributes.dtype`.

    Any other properties can be attached on-the-fly, thus allowing to store
    arbitrary metadata for a symbol, e.g., declaration attributes such as
    ``POINTER``, ``ALLOCATABLE``, or the shape of an array, or structural
    information, e.g., whether a variable is a loop index, argument, etc.

    There is no need to check for the presence of attributes, undefined
    attributes can be queried and default to `None`. Assigning `None` to an
    attribute removes it; `None` is never stored.

    Parameters
    ----------
    dtype : :any:`DataType`
        The data type associated with the symbol. Any other value is converted
        via :meth:`BasicType.from_str`.
    **kwargs : optional
        Any attributes that should be stored as properties (`None` values are
        skipped)
    """

    def __init__(self, dtype, **kwargs):
        if isinstance(dtype, DataType):
            self.dtype = dtype
        else:
            self.dtype = BasicType.from_str(dtype)

        for k, v in kwargs.items():
            if v is not None:
                self.__setattr__(k, v)

    def __hash__(self):
        # Only the attribute *names* are hashed, order-independently: values may
        # be unhashable, and objects that compare equal always carry the same
        # set of names because ``None`` values are never stored.
        return hash(frozenset(self.__dict__))

    def __setattr__(self, name, value):
        if value is None:
            # ``None`` means "unset": drop the attribute if it is stored on the
            # instance, but never store ``None`` itself. A stored ``None`` would
            # make ``__hash__`` disagree with ``compare``, which treats missing
            # and ``None`` attributes as equal.
            if name in self.__dict__:
                delattr(self, name)
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when regular lookup fails: undefined attributes read as None
        if name not in dir(self):
            return None
        return object.__getattribute__(self, name)

    def __delattr__(self, name):
        object.__delattr__(self, name)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def __repr__(self):
        parameters = [str(self.dtype)]
        for k, v in self.__dict__.items():
            if k in ['dtype', 'source']:
                continue
            if isinstance(v, bool):
                if v:
                    parameters += [str(k)]
            else:
                parameters += [f'{k}={str(v)}']
        return f'<{self.__class__.__name__} {", ".join(parameters)}>'

    def __eq__(self, other):
        """
        Compare :any:`SymbolAttributes` on all attributes, without exceptions
        (i.e. no attribute is ignored).

        Objects of any other type are never equal. Returning ``NotImplemented``
        makes Python fall back to the default comparison instead of raising
        ``AttributeError`` on ``other.__dict__``.
        """
        if not isinstance(other, SymbolAttributes):
            return NotImplemented
        return self.compare(other, ignore=None)

    def clone(self, **kwargs):
        """
        Clone the :any:`SymbolAttributes`, optionally overwriting any attributes

        Attributes that should be removed should simply be given as `None`.
        """
        args = self.__dict__.copy()
        args.update(kwargs)
        return self.__class__(**args)

    def compare(self, other, ignore=None):
        """
        Compare :any:`SymbolAttributes` objects while ignoring a set of select attributes.

        An attribute missing on one side compares equal to `None` on the other.

        Parameters
        ----------
        other : :any:`SymbolAttributes`
            The object to compare with
        ignore : iterable, optional
            Names of attributes to ignore while comparing.

        Returns
        -------
        bool
        """
        ignore_attrs = as_tuple(ignore)
        keys = set(as_tuple(self.__dict__.keys()) + as_tuple(other.__dict__.keys()))
        return all(self.__dict__.get(k) == other.__dict__.get(k)
                   for k in keys if k not in ignore_attrs)


class SymbolTable(dict):
    """
    Lookup table for symbol types that maps symbol names to :any:`SymbolAttributes`

    It is used to store types for declared variables, defined types or imported
    symbols within their respective scope. If its associated scope is nested
    into an enclosing scope, it allows to perform recursive look-ups in parent
    scopes.

    The interface of this table behaves like a :any:`dict`.

    Parameters
    ----------
    parent : :any:`SymbolTable`, optional
        The symbol table of the parent scope for recursive look-ups.
    case_sensitive : bool, optional
        Respect the case of symbol names in lookups (default: `False`).
    **kwargs :
        Initial entries of the table, mapping symbol names to
        :any:`SymbolAttributes`. As with :meth:`update`, names are normalized
        via :meth:`format_lookup_name` and values are cloned.
    """

    def __new__(cls, *args, case_sensitive=False, **kwargs):
        """
        Set the lookup function on object creation, so that they are safe to pickle
        """
        obj = super(SymbolTable, cls).__new__(cls, *args, **kwargs)
        obj._case_sensitive = case_sensitive
        if obj.case_sensitive:
            obj.format_lookup_name = SymbolTable._case_sensitive_format_lookup_name
        else:
            obj.format_lookup_name = SymbolTable._not_case_sensitive_format_lookup_name
        return obj

    def __getnewargs_ex__(self):
        """
        Arguments for :meth:`__new__` when unpickling or copying the table

        Pickle restores the table entries (through :meth:`__setitem__`) *before*
        the instance state. The lookup-name formatter must therefore already be
        correct at construction time. Otherwise names in a case-sensitive table
        would be lower-cased while unpickling.
        """
        return (), {'case_sensitive': self.case_sensitive}

    def __init__(self, parent=None, *, case_sensitive=False, **kwargs):  # pylint: disable=unused-argument
        # ``case_sensitive`` is consumed by ``__new__``. It is accepted here only so
        # that it is not forwarded to ``dict.__init__``, which would insert it into
        # the table as a bogus entry.
        super().__init__()
        self._parent = weakref.ref(parent) if parent is not None else None
        if kwargs:
            # Route initial entries through ``update`` to format names and clone values
            self.update(kwargs)

    @property
    def parent(self):
        """
        The symbol table of the parent scope

        Returns
        -------
        :any:`SymbolTable` or `None`
        """
        return self._parent() if self._parent is not None else None

    @parent.setter
    def parent(self, parent):
        self._parent = weakref.ref(parent) if parent is not None else None

    @property
    def case_sensitive(self):
        """
        Indicate if the :any:`SymbolTable` is case-sensitive when looking up
        names

        Returns
        -------
        `bool`
        """
        return self._case_sensitive  # pylint: disable=no-member

    def format_lookup_name(self, name):  # pylint: disable=method-hidden
        """
        Format a variable name for look-up (e.g., convert to lower case if
        case-insensitive)

        This placeholder is shadowed per instance in :meth:`__new__` by one of
        the static formatting helpers below.

        Parameters
        ----------
        name : `str`
            the name to look up

        Returns
        -------
        str :
            the name used for look-ups
        """

    @staticmethod
    def _case_sensitive_format_lookup_name(name):
        name = name.partition('(')[0]  # Remove any dimension parameters
        return name

    @staticmethod
    def _not_case_sensitive_format_lookup_name(name):
        name = name.lower()
        name = name.partition('(')[0]  # Remove any dimension parameters
        return name

    def _lookup_formatted_name(self, name, recursive):
        """
        Helper routine to recursively look for a symbol in the table.

        Look-ups should always be done via :meth:`lookup` as this makes sure
        the look-up name is formatted according to the expected format.

        Parameters
        ----------
        name : `str`
            the name to look for, formatted according to :meth:`format_lookup_name`
        recursive : `bool`
            recursive look-up in parent tables

        Returns
        -------
        :any:`SymbolAttributes` or `None`
            A clone of the stored entry, or `None` if not found
        """
        value = super().get(name, None)
        if value is None and recursive and self.parent is not None:
            return self.parent._lookup_formatted_name(name, recursive)
        return value.clone() if value is not None else None

    def lookup(self, name, recursive=True):
        """
        Look-up a symbol in the symbol table and return the type or `None` if not found.

        Parameters
        ----------
        name : `str`
            Name of the type or symbol
        recursive : `bool`, optional
            If no entry by that name is found, try to find it in the table of the parent scope

        Returns
        -------
        :any:`SymbolAttributes` or `None`
            A clone of the stored entry, or `None` if not found
        """
        formatted_name = self.format_lookup_name(name)  # pylint: disable=assignment-from-no-return
        value = self._lookup_formatted_name(formatted_name, recursive)
        return value

    def __contains__(self, key):
        return super().__contains__(self.format_lookup_name(key))

    def __getitem__(self, key):
        value = self.lookup(key, recursive=False)
        if value is None:
            raise KeyError(key)
        # ``lookup`` already returns a fresh clone
        return value

    def get(self, key, default=None):
        """
        Get a symbol's entry without recursive lookup

        Parameters
        ----------
        key : `str`
            Name of the type or symbol
        default : optional
            Return this value if :attr:`key` is not found in the table

        Returns
        -------
        :any:`SymbolAttributes` or :data:`default`
            A clone of the stored entry, or :data:`default`
        """
        value = self.lookup(key, recursive=False)
        return value if value is not None else default

    def __setitem__(self, key, value):
        assert isinstance(value, SymbolAttributes)
        name_parts = self.format_lookup_name(key)  # pylint: disable=assignment-from-no-return
        super().__setitem__(name_parts, value.clone())

    def __hash__(self):
        return hash(tuple(self.keys()))

    def __repr__(self):
        return f'<{type(self).__name__} object at {hex(id(self))}>'

    def __getstate__(self):
        _ignored = ('_parent', )
        return {k: v for k, v in self.__dict__.items() if k not in _ignored}

    def __setstate__(self, s):
        self.__dict__.update(s)

        # Weak references cannot be pickled, so the parent link is dropped
        self._parent = None

    def setdefault(self, key, default=None):
        """
        Insert a default value for a key into the table if it does not exist

        Parameters
        ----------
        key : `str`
            Name of the type or symbol
        default : optional
            The default value to store for the key. Defaults to
            ``SymbolAttributes(BasicType.DEFERRED)``.

        Returns
        -------
        :any:`SymbolAttributes`
            A clone of the entry stored for :attr:`key` after the call
        """
        if default is None:
            default = SymbolAttributes(BasicType.DEFERRED)
        assert isinstance(default, SymbolAttributes)
        value = super().setdefault(self.format_lookup_name(key), default.clone())
        return value.clone()

    def update(self, other):
        """
        Update this symbol table with entries from :attr:`other`

        :attr:`other` may be a mapping or an iterable of ``(name, value)`` pairs.
        Names are normalized via :meth:`format_lookup_name` and values are cloned.
        """
        if isinstance(other, dict):
            other = {self.format_lookup_name(k): v.clone() for k, v in other.items()}
        else:
            other = {self.format_lookup_name(k): v.clone() for k, v in other}
        super().update(other)

    def clone(self, **kwargs):
        """
        Create a copy of the symbol table with the option to override individual
        parameters

        Parameters
        ----------
        **kwargs :
            Any parameters from the constructor of :any:`SymbolTable`

        Returns
        -------
        :any:`SymbolTable`
            The clone symbol table with copies of all :any:`SymbolAttributes`
        """
        if self.case_sensitive and 'case_sensitive' not in kwargs:
            kwargs['case_sensitive'] = self.case_sensitive
        # Compare against None: an empty parent table is falsy but still a valid parent
        if self.parent is not None and 'parent' not in kwargs:
            kwargs['parent'] = self.parent
        obj = type(self)(**kwargs)
        obj.update(self)
        return obj
loki-ecmwf-0.3.6/loki/types/scope.py0000664000175000017500000001227515167130205017535 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Representation of symbol tables and scopes in
:doc:`internal_representation`
"""

from dataclasses import dataclass, field, InitVar
import weakref

from loki.tools import WeakrefProperty
from loki.types.symbol_table import SymbolTable


__all__ = ['Scope']


@dataclass(frozen=True)
class Scope:
    """
    Scoping object that manages type caching and derivation for typed symbols.

    The :any:`Scope` provides a symbol table that uniquely maps a symbol's name
    to its :any:`SymbolAttributes` or, for a derived type definition, directly
    to its :any:`DerivedType`.

    See :any:`SymbolTable` for more details on how to look-up symbols.

    Parameters
    ----------
    parent : :any:`Scope`, optional
        The enclosing scope, thus allowing recursive look-ups
    symbol_attrs : :any:`SymbolTable`, optional
        Use the given symbol table instead of instantiating a new one
    """

    symbol_attrs: SymbolTable = field(default_factory=SymbolTable, init=False)
    parent: InitVar[object] = WeakrefProperty(default=None, frozen=True)

    def __post_init__(self, parent=None):
        self._reset_parent(parent)

        assert isinstance(self.symbol_attrs, SymbolTable)
        # Link the symbol table to the parent's table (or detach it) for recursive look-ups
        self.symbol_attrs.parent = None if self.parent is None else self.parent.symbol_attrs

    def __repr__(self):
        """
        String representation.
        """
        return f'Scope<{id(self)}>'

    @property
    def parents(self):
        """
        All parent scopes enclosing the current scope, with the top-level
        scope at the end of the list

        Returns
        -------
        tuple
            The list of parent scopes
        """
        parent = self.parent
        # Explicit None check: a parent scope object must not be skipped just
        # because it happens to evaluate as falsy
        if parent is not None:
            return parent.parents + (parent,)
        return ()

    def rescope_symbols(self):
        """
        Make sure all symbols declared and used inside this node belong
        to a scope in the scope hierarchy
        """
        from loki.ir import AttachScopes  # pylint: disable=import-outside-toplevel,cyclic-import
        AttachScopes().visit(self, scope=self)

    def make_complete(self, **frontend_args):
        """
        Trigger a re-parse of the object if incomplete to produce a full Loki IR

        See :any:`ProgramUnit.make_complete` for more details.

        This method relays the call to the next ``make_complete`` in the MRO,
        if one exists, and then to the :attr:`parent` scope, if there is one.
        """
        if hasattr(super(), 'make_complete'):
            super().make_complete(**frontend_args)
        # A top-level scope has no parent to relay the call to
        if self.parent is not None:
            self.parent.make_complete(**frontend_args)

    def clone(self, **kwargs):
        """
        Create a copy of the scope object with the option to override individual
        parameters

        Note that this will also create a copy of the symbol table via
        :any:`SymbolTable.clone` and force rescoping of variables,
        unless :attr:`symbol_attrs` and :attr:`rescope_symbols` are explicitly
        specified.

        Parameters
        ----------
        **kwargs : Any parameter from the constructor

        Returns
        -------
        `type(self)`
            The cloned scope object
        """
        if self.parent is not None and 'parent' not in kwargs:
            kwargs['parent'] = self.parent
        if 'symbol_attrs' not in kwargs:
            kwargs['symbol_attrs'] = self.symbol_attrs.clone(parent=kwargs.get('parent'))
            kwargs['rescope_symbols'] = True

        if hasattr(self, '_rebuild'):
            # When cloning IR nodes with a Scope mix-in we need to use the
            # rebuild mechanism
            return self._rebuild(**kwargs)  # pylint: disable=no-member
        return type(self)(**kwargs)

    def get_symbol_scope(self, name):
        """
        Find the scope in which :attr:`name` is declared

        This performs a recursive lookup in the :any:`SymbolTable` to find
        the scope in which :attr:`name` is declared. Note, that this may be
        the scope with a :any:`Import` of this name and not the original
        declaration.

        Parameters
        ----------
        name : `str`
            The name of the symbol to look for

        Returns
        -------
        :any:`Scope` or `None`
            The scope object in which the symbol is declared, or `None` if
            not found
        """
        scope = self
        while scope is not None:
            if name in scope.symbol_attrs:
                return scope
            scope = scope.parent
        return None

    def _reset_parent(self, parent):
        """
        Private method to reset the parent of a :any:`Scope` and
        update the symbol table accordingly.

        Parameters
        ----------
        parent : :any:`Scope`, optional
            The enclosing scope, thus allowing recursive look-ups
        """
        # Write to __dict__ directly because the dataclass is frozen
        self.__dict__['_parent'] = weakref.ref(parent) if parent is not None else None

        if self.parent is not None:
            self.symbol_attrs.parent = self.parent.symbol_attrs
loki-ecmwf-0.3.6/loki/types/datatypes.py0000664000175000017500000000615715167130205020424 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of classes to represent type information for symbols used throughout
:doc:`internal_representation`
"""

from enum import Enum

from loki.tools import flatten


__all__ = ['DataType', 'BasicType']


class DataType:
    """
    Common base class of all data types that can be attached to a symbol,
    such as :any:`BasicType`, :any:`DerivedType`, :any:`ProcedureType` or
    :any:`ModuleType`.
    """


class BasicType(DataType, int, Enum):
    """
    Intrinsic data types, named after their FORTRAN counterparts.

    The available members are

    - :any:`LOGICAL`
    - :any:`INTEGER`
    - :any:`REAL`
    - :any:`CHARACTER`
    - :any:`COMPLEX`

    plus :any:`DEFERRED`, which marks a data type that is not known (e.g., an
    imported symbol whose definition is not available).

    The ``from_*`` class methods heuristically convert string representations
    of FORTRAN and C99 types.
    """

    DEFERRED = -1
    LOGICAL = 1
    INTEGER = 2
    REAL = 3
    CHARACTER = 4
    COMPLEX = 5

    @classmethod
    def from_str(cls, value):
        """
        Convert :data:`value` by trying :meth:`from_name`, :meth:`from_fortran_type`
        and :meth:`from_c99_type` in turn.

        Raises
        ------
        ValueError
            If none of the converters recognizes :data:`value`
        """
        for converter in (cls.from_name, cls.from_fortran_type, cls.from_c99_type):
            try:
                return converter(value)
            except KeyError:
                continue
        raise ValueError(f'Unknown data type: {value}')

    @classmethod
    def from_name(cls, value):
        """
        Return the member whose name is exactly :data:`value` (raises :class:`KeyError` otherwise).
        """
        return cls.__members__[value]

    @classmethod
    def from_fortran_type(cls, value):
        """
        Map a (case-insensitive) FORTRAN type name to a member (raises :class:`KeyError` otherwise).
        """
        fortran_names = {
            'logical': cls.LOGICAL,
            'integer': cls.INTEGER,
            'real': cls.REAL,
            'double precision': cls.REAL,
            'character': cls.CHARACTER,
            'complex': cls.COMPLEX,
            'double complex': cls.COMPLEX,
        }
        return fortran_names[value.lower()]

    @classmethod
    def from_c99_type(cls, value):
        """
        Map a C99 type name to a member (raises :class:`KeyError` otherwise).
        """
        base_integers = ('short', 'int', 'long', 'long long')
        integers = base_integers + tuple(
            f'{sign} {name}' for name in base_integers for sign in ('signed', 'unsigned')
        )
        groups = (
            (cls.LOGICAL, ('bool', '_Bool')),
            (cls.INTEGER, integers),
            (cls.REAL, ('float', 'double', 'long double')),
            (cls.CHARACTER, ('char',)),
            (cls.COMPLEX, ('float _Complex', 'double _Complex', 'long double _Complex')),
        )
        c99_names = {name: dtype for dtype, names in groups for name in names}
        return c99_names[value]
loki-ecmwf-0.3.6/loki/types/module_type.py0000664000175000017500000000474415167130205020754 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

""" Representation of a module type for module definitions. """

import weakref

from loki.tools import LazyNodeLookup
from loki.types.datatypes import BasicType, DataType


__all__ = ['ModuleType']


class ModuleType(DataType):
    """
    Representation of a module definition.

    This serves as a caching mechanism for module definitions in symbol tables.

    Parameters
    ----------
    name : str, optional
        The name of the module. Can be skipped if :data:`module`
        is provided (not in the form of a :any:`LazyNodeLookup`)
    module : :any:`Module` or :any:`LazyNodeLookup`, optional
        The module this type represents
    """

    def __init__(self, name=None, module=None):
        from loki.module import Module  # pylint: disable=import-outside-toplevel,cyclic-import
        super().__init__()
        assert name or isinstance(module, Module)
        if module is None or isinstance(module, LazyNodeLookup):
            self._module = module
            self._name = name
        else:
            self._module = weakref.ref(module)
            # Cache all properties for when module link becomes inactive
            assert name is None or name.lower() == self.module.name.lower()
            self._name = self.module.name

    @property
    def name(self):
        """
        The name of the module

        This looks up the name in the linked :attr:`module` if available, otherwise
        returns the name stored during instantiation of the :any:`ModuleType` object.
        """
        return self._name if self.module is BasicType.DEFERRED else self.module.name

    @property
    def module(self):
        """
        The :any:`Module` object represented by this type

        If not provided during instantiation or if the underlying :any:`weakref` is dead,
        this returns :any:`BasicType.DEFERRED`.
        """
        if self._module is None:
            return BasicType.DEFERRED
        if self._module() is None:
            return BasicType.DEFERRED
        return self._module()

    def __str__(self):
        return self.name

    def __repr__(self):
        # Previously an empty string, which made debugging output useless
        return f'<ModuleType {self.name}>'
loki-ecmwf-0.3.6/loki/types/procedure_type.py0000664000175000017500000001371315167130205021453 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

""" Representation of a procedure type for function or subroutine type definitions. """

import weakref

from loki.tools import LazyNodeLookup
from loki.types.datatypes import BasicType, DataType


__all__ = ['ProcedureType']


class ProcedureType(DataType):
    """
    Representation of a function or subroutine type definition.

    This serves also as the cross-link between the use of a procedure (e.g. in a
    :any:`CallStatement`) to the :any:`Subroutine` object that is the target of
    a call. If the corresponding object is not yet available when the
    :any:`ProcedureType` object is created, or its definition is transient and
    subject to IR rebuilds (e.g. :any:`StatementFunction`), the :any:`LazyNodeLookup`
    utility can be used to defer the actual instantiation. In that situation,
    :data:`name` should be provided in addition.

    Parameters
    ----------
    name : str, optional
        The name of the function or subroutine. Can be skipped if :data:`procedure`
        is provided (not in the form of a :any:`LazyNodeLookup`)
    is_function : bool, optional
        Indicate that this is a function
    is_generic : bool, optional
        Indicate that this is a generic function
    is_intrinsic : bool, optional
        Indicate that this is an intrinsic function
    procedure : :any:`Subroutine` or :any:`StatementFunction` or :any:`LazyNodeLookup`, optional
        The procedure this type represents
    return_type : :any:`SymbolAttributes`, optional
        The return type of a function. It is required for non-intrinsic functions
        when no :data:`procedure` is given. When a :any:`Subroutine` is provided,
        the return type is taken from it instead.
    """

    def __init__(
            self, name=None, is_function=None, is_generic=False,
            is_intrinsic=False, procedure=None, return_type=None
    ):
        # pylint: disable=import-outside-toplevel,cyclic-import
        from loki.subroutine import Subroutine
        from loki.types.symbol_table import SymbolAttributes

        super().__init__()
        assert name or isinstance(procedure, Subroutine)
        assert isinstance(return_type, SymbolAttributes) or procedure or not is_function or is_intrinsic
        self.is_generic = is_generic
        self.is_intrinsic = is_intrinsic
        if procedure is None or isinstance(procedure, LazyNodeLookup):
            self._procedure = procedure
            self._name = name
            self._is_function = is_function or False
            self._return_type = return_type
            # NB: not applying an assert on the procedure name for LazyNodeLookup as
            # the point of the lazy lookup is that we might not have the procedure
            # definition available at type instantiation time
        else:
            self._procedure = weakref.ref(procedure)
            # Cache all properties for when procedure link becomes inactive
            assert name is None or name.lower() == self.procedure.name.lower()
            self._name = self.procedure.name
            assert is_function is None or is_function == self.procedure.is_function
            self._is_function = self.procedure.is_function
            # TODO: compare return type once type comparison is more robust
            self._return_type = self.procedure.return_type if self.procedure.is_function else None

    @property
    def _canonical(self):
        return (self._name, self._procedure, self.is_function, self.is_generic, self.return_type)

    def __eq__(self, other):
        if isinstance(other, ProcedureType):
            return self._canonical == other._canonical
        return super().__eq__(other)

    def __hash__(self):
        return hash(self._canonical)

    @property
    def name(self):
        """
        The name of the procedure

        This looks up the name in the linked :attr:`procedure` if available, otherwise
        returns the name stored during instantiation of the :any:`ProcedureType` object.
        """
        return self._name if self.procedure is BasicType.DEFERRED else self.procedure.name

    @property
    def procedure(self):
        """
        The :any:`Subroutine` object of the procedure

        If not provided during instantiation or if the underlying :any:`weakref` is dead,
        this returns :any:`BasicType.DEFERRED`.
        """
        if self._procedure is None:
            return BasicType.DEFERRED
        if self._procedure() is None:
            return BasicType.DEFERRED
        return self._procedure()

    @property
    def is_function(self):
        """
        Return `True` if the procedure is a function, otherwise `False`
        """
        if self.procedure is BasicType.DEFERRED:
            return self._is_function
        return self.procedure.is_function

    @property
    def is_elemental(self):
        """
        Return ``True`` if the procedure has the ``elemental`` prefix, otherwise ``False``
        """
        if self.procedure is BasicType.DEFERRED:
            return False
        if not hasattr(self.procedure, 'prefix'):
            # StatementFunction objects have no prefix!
            # This will be fixed once procedures are unified
            return False
        return 'elemental' in tuple(pre.lower() for pre in self.procedure.prefix)

    @property
    def return_type(self):
        """
        The return type of the function (or `None`)
        """
        if not self.is_function:
            return None
        if self.procedure is BasicType.DEFERRED:
            return self._return_type
        return self.procedure.return_type

    def __str__(self):
        return self.name

    def __repr__(self):
        # Previously an empty string, which made debugging output useless
        return f'<ProcedureType {self.name}>'

    def __getstate__(self):
        # The weak reference to the procedure cannot be pickled
        _ignore = ('_procedure', )
        return {k: v for k, v in self.__dict__.items() if k not in _ignore}

    def __setstate__(self, s):
        self.__dict__.update(s)

        self._procedure = None
loki-ecmwf-0.3.6/loki/types/derived_type.py0000664000175000017500000000344715167130205021110 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

""" Representation of a derived type with local symbol table. """

from loki.types.datatypes import BasicType, DataType


__all__ = ['DerivedType']


class DerivedType(DataType):
    """
    Representation of derived data types that may have an associated :any:`TypeDef`

    Please note that the typedef attribute may be of :any:`TypeDef` or
    :any:`BasicType.DEFERRED`, if the associated type definition is not available.

    Parameters
    ----------
    name : str, optional
        The name of the derived type. Can be omitted if :data:`typedef` is provided
    typedef : :any:`TypeDef`, optional
        The definition of the derived type. Takes precedence over :data:`name`
    """

    def __init__(self, name=None, typedef=None):
        super().__init__()
        assert name or typedef
        self._name = name
        self.typedef = typedef if typedef is not None else BasicType.DEFERRED

    @property
    def name(self):
        """
        The name of the derived type, taken from :attr:`typedef` if it is available
        """
        return self._name if self.typedef is BasicType.DEFERRED else self.typedef.name

    def __str__(self):
        return self.name

    def __repr__(self):
        # Previously an empty string, which made debugging output useless
        return f'<DerivedType {self.name}>'

    @property
    def _canonical(self):
        return (self._name, self.typedef)

    def __eq__(self, other):
        if isinstance(other, DerivedType):
            return self._canonical == other._canonical
        return super().__eq__(other)

    def __hash__(self):
        return hash(self._canonical)
loki-ecmwf-0.3.6/loki/dimension.py0000664000175000017500000001536615167130205017251 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.tools import as_tuple

__all__ = ['Dimension']


class Dimension:
    """
    Dimension object that defines a one-dimensional data and iteration space.

    Parameters
    ----------
    name : string
        Name of the dimension to identify in configurations
    index : string or tuple of str
        String representation of the predominant loop index variable
        associated with this dimension; can be one or several.
    size : string or tuple of str
        String representation of the predominant size variable used
        to declare array shapes; can be one or several.
    lower : string or tuple of str
        String representation of the lower bound variable used to
        declare iteration spaces; can be one or several.
    upper : string or tuple of str
        String representation of the upper bound variable used to
        declare iteration spaces; can be one or several.
    step : string or tuple of str
        String representation of the step size variable used to
        declare iteration spaces; can be one or several.
    bounds : tuple of strings
        String representation of the variables usually used to denote
        the iteration bounds of this dimension.

        **WARNING:** This argument is deprecated, instead ``lower``
        and ``upper`` should be used.
    aliases : list or tuple of strings
        String representations of alternative size variables that are
        used to define arrays shapes of this dimension (eg. alternative
        names used in "driver" subroutines).

        **WARNING:** This argument is deprecated, instead a tuple of
        variables names should be provided for ``size``.
    bounds_aliases : list or tuple of strings
        String representations of alternative bounds variables that are
        used to define loop ranges.

        **WARNING:** This argument is deprecated, instead a tuple of
        variables names should be provided for ``lower`` and
        ``upper``.
    index_aliases : list or tuple of strings
        String representations of alternative loop index variables associated
        with this dimension.

        **WARNING:** This argument is deprecated, instead a tuple of
        variables names should be provided for ``index``.
    """

    def __init__(
            self, name=None, index=None, size=None, lower=None,
            upper=None, step=None, aliases=None, bounds=None,
            bounds_aliases=None, index_aliases=None
    ):
        self.name = name

        if bounds:
            # Backward compat for ``bounds`` constructor argument
            assert not lower and not upper and len(bounds) == 2
            lower = (bounds[0],)
            upper = (bounds[1],)

        # Store one or more strings for dimension variables
        self._index = as_tuple(index) or None
        self._size = as_tuple(size) or None
        self._lower = as_tuple(lower) or None
        self._upper = as_tuple(upper) or None
        self._step = as_tuple(step) or None

        # Keep backward-compatibility for constructor arguments. The primary
        # variables may be unset (None), hence ``as_tuple`` before concatenating.
        if aliases:
            self._size = as_tuple(self._size) + as_tuple(aliases)
        if index_aliases:
            self._index = as_tuple(self._index) + as_tuple(index_aliases)
        if bounds_aliases:
            self._lower = as_tuple(self._lower) + (bounds_aliases[0],)
            self._upper = as_tuple(self._upper) + (bounds_aliases[1],)

    def __repr__(self):
        """ Pretty-print dimension details, leaving unset entries blank """
        name = f'<{self.name}>' if self.name else ''
        index = self.index or ''
        size = self.size or ''
        has_bounds = any(b is not None for b in self.bounds)
        bounds = ','.join(str(b) for b in self.bounds) if has_bounds else ''
        return f'Dimension{name}[{index},{size},({bounds})]'

    @property
    def variables(self):
        return (self.index, self.size) + self.bounds

    @property
    def sizes(self):
        """
        Tuple of strings that match the primary size and all secondary size expressions.

        .. note::
            For derived expressions, like ``end - start + 1`` or
            ``1:size``, please use :any:`size_expressions`.
        """
        return self._size

    @property
    def size(self):
        """
        String that matches the primary size expression of a data space (variable allocation).
        """
        return self.sizes[0] if self.sizes else None

    @property
    def indices(self):
        """
        Tuple of strings that match the primary index and all secondary index expressions.
        """
        return self._index

    @property
    def index(self):
        """
        String that matches the primary index expression of an iteration space (loop).
        """
        return self.indices[0] if self.indices else None

    @property
    def lower(self):
        """
        String or tuple of strings that matches the lower bound of the iteration space.
        """
        return self._lower[0] if self._lower and len(self._lower) == 1 else self._lower

    @property
    def upper(self):
        """
        String or tuple of strings that matches the upper bound of the iteration space.
        """
        return self._upper[0] if self._upper and len(self._upper) == 1 else self._upper

    @property
    def step(self):
        """
        String or tuple of strings that matches the step size of the iteration space.
        """
        return self._step[0] if self._step and len(self._step) == 1 else self._step

    @property
    def bounds(self):
        """
        Tuple of expression string that represent the bounds of an iteration space.

        .. note::
            If multiple lower or upper bound strings have been provided,
            only the first pair will be used.
        """
        return (
            self.lower[0] if isinstance(self.lower, tuple) else self.lower,
            self.upper[0] if isinstance(self.upper, tuple) else self.upper
        )

    @property
    def range(self):
        """
        String that matches the range expression of an iteration space (loop).

        .. note::
            If multiple lower or upper bound strings have been provided,
            only the first pair will be used.
        """
        return f'{self.bounds[0]}:{self.bounds[1]}'

    @property
    def size_expressions(self):
        """
        A tuple of all expression strings representing the size of a data space.

        This includes generic aliases, like ``end - start + 1`` or ``1:size`` ranges.
        Derived expressions are only added when the variables they are built from
        are defined.
        """
        exprs = as_tuple(self.sizes)
        if self.size is not None:
            exprs += (f'1:{self.size}', )
        start, end = self.bounds
        # ``bounds`` is always a 2-tuple, so check the entries rather than its truthiness
        if start is not None and end is not None:
            exprs += (f'{end} - {start} + 1', )
            exprs += (f'{start}:{end}', )
        return exprs
loki-ecmwf-0.3.6/loki/tools/0000775000175000017500000000000015167130205016037 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tools/__init__.py0000664000175000017500000000106515167130205020152 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Collection of tools and utility methods used throughout Loki.
"""

from loki.tools.util import *  # noqa
from loki.tools.files import *  # noqa
from loki.tools.strings import *  # noqa
loki-ecmwf-0.3.6/loki/tools/tests/0000775000175000017500000000000015167130205017201 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/tools/tests/__init__.py0000664000175000017500000000057015167130205021314 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/tools/tests/test_tools.py0000664000175000017500000003723115167130205021760 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Unit tests for utility functions and classes in loki.tools.
"""

import platform
import sys
import operator as op
from contextlib import contextmanager
from pathlib import Path
from subprocess import CalledProcessError
from time import sleep, perf_counter
import pytest

try:
    import yaml
    HAVE_YAML = True
except ImportError:
    HAVE_YAML = False

from loki.config import config_override
from loki.tools import (
    JoinableStringList, truncate_string, binary_insertion_sort, is_subset,
    optional, yaml_include_constructor, execute, timeout, dict_override,
    LokiTempdir, stdchannel_is_captured, stdchannel_redirected, OrderedSet
)


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test module."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.mark.parametrize('a, b, ref', [
    ((1, 2), (0, 1, 0, 2, 3), True),
    ((1, 2), (0, 2, 0, 1, 3), False),
    ((1, 2), (1, 0, 2, 3), True),
    ((1, 2), (1, 2), True),
    ((2, 1), (1, 2), False),
    ((1, 2), (1, 0, 2), True),
    ((), (1,), False),
    ((1,), (), False),
    ((), (), False),
    ((0, 0), (0, 1, 0, 2, 0, 3), True),
    ((0, 0), (0, 1, 2, 3), False),
])
def test_is_subset_ordered(a, b, ref):
    """
    Check :any:`is_subset` with ``ordered=True``: the elements of ``a`` must
    appear in ``b`` in the same order, possibly with other elements in-between.
    """
    result = is_subset(a, b, ordered=True)
    assert result == ref


@pytest.mark.parametrize('a, b, ref', [
    ((1, 2), (0, 1, 2, 3), True),
    ((1, 2), (0, 1, 2), True),
    ((1, 2), (1, 2, 3), True),
    ((1, 2), (1, 2), True),
    ((0, 1, 2, 3), (1, 2), False),
    ((0, 1, 2), (1, 2), False),
    ((1, 2, 3), (1, 2), False),
    ([1], (0, 1, 2), True),
    ((0, 1), [0, 1, 2, 3], True),
    ((1, 0), (0, 1), False),
    ((1,), (1, 2), True),
    ((1, 2), (1, 0, 2), False),
    ((), (1,), False),
    ((1,), (), False),
    ((), (), False),
    ((0, 0), (0, 1, 0, 2, 0, 3), False),
])
def test_is_subset_ordered_subsequent(a, b, ref):
    """
    Check :any:`is_subset` with ``ordered=True, subsequent=True``: ``a`` must
    appear in ``b`` as one contiguous run.
    """
    result = is_subset(a, b, ordered=True, subsequent=True)
    assert result == ref


@pytest.mark.parametrize('a, b, ref', [
    ((1, 2), (0, 1, 2, 3), True),
    ((1, 2), (0, 1, 2), True),
    ((1, 2), (1, 2, 3), True),
    ((1, 2), (1, 2), True),
    ((0, 1, 2, 3), (1, 2), False),
    ((0, 1, 2), (1, 2), False),
    ((1, 2, 3), (1, 2), False),
    ([1], (0, 1, 2), True),
    ((0, 1), [0, 1, 2, 3], True),
    ((1, 0), (0, 1), True),
    ((1,), (1, 2), True),
    ((1, 2), (1, 0, 2), True),
])
def test_is_subset_not_ordered(a, b, ref):
    """
    Test :any:`is_subset` with ``ordered=False``, i.e. a plain set-inclusion
    check in which the order of elements does not matter.
    """
    assert is_subset(a, b, ordered=False) == ref


@pytest.mark.parametrize('a, b', [
    ({1, 2}, [1, 2]),
    ([1, 2], {1, 2}),
])
def test_is_subset_raises(a, b):
    """
    An ordered :any:`is_subset` check rejects arguments that are not sequences.
    """
    expected_error = ValueError
    with pytest.raises(expected_error):
        is_subset(a, b, ordered=True)


@pytest.mark.parametrize('items, sep, width, cont, ref', [
    ([''], ' ', 90, '\n', ''),
    ([], ' ', 90, '\n', ''),
    (('H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!'), '', 90, '\n', 'Hello world!'),
    (('H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!'), '', 7, '\n', 'Hello \nworld!'),
    (('Hello', 'world!'), ' ', 90, '\n', 'Hello world!'),
    (('Hello', 'world!'), ' ', 7, '\n', 'Hello \nworld!'),
    (('Hello', 'world!'), ' ', 5, '\n', 'Hello\n \nworld!'),
    ((JoinableStringList(['H', 'e', 'l', 'l', 'o'], '', 5, '\n'), 'world!'), ' ', 5, '\n',
     'Hell\no \nworld!'),
    (('Hello', JoinableStringList(['w', 'o', 'r', 'l', 'd', '!'], '', 8, '\n', separable=False)),
     ' ', 8, '\n', 'Hello \nworld!'),
    (('Hello', JoinableStringList(['w', 'o', 'r', 'l', 'd', '!'], '', 8, '\n', separable=True)),
     ' ', 8, '\n', 'Hello w\norld!'),
    (["'Long word   '", '"list with    "', '"several entries   "', "'that may have'",
      "' trailing whitespace '", "'characters     '"], ', ', 20, '\n',
     ("'Long word   ', \n\"list with    \", \n\"several entries   \"\n, 'that may have', \n"
      "' trailing whitespace '\n, 'characters     '"))
])
def test_joinable_string_list(items, sep, width, cont, ref):
    """
    Check the rendering of :any:`JoinableStringList` for common scenarios,
    including nested lists and line continuation at the given width.
    """
    rendered = str(JoinableStringList(items, sep, width, cont))
    assert rendered == ref


def test_joinable_string_list_long():
    """
    Check :any:`JoinableStringList` on long, realistic edge cases.
    """
    # Fortran declaration whose variable does not fit on the first line
    attr_list = JoinableStringList(['REAL(KIND=JPRB)', 'INTENT(IN)'], ', ', 132, ' &\n   & ')
    var_list = JoinableStringList(
        ['PDHTLS(KPROMA, YDMODEL%YRML_PHY_G%YRDPHY%NTILES, '
         'YDMODEL%YRML_DIAG%YRMDDH%NDHVTLS + YDMODEL%YRML_DIAG%YRMDDH%NDHFTLS)'],
        ', ', 132, ' &\n   & '
    )
    declaration = JoinableStringList(['  ', attr_list, ' :: ', var_list], '', 132, ' &\n  & ')
    expected = ('  REAL(KIND=JPRB), INTENT(IN) ::  &\n'
                '  & PDHTLS(KPROMA, YDMODEL%YRML_PHY_G%YRDPHY%NTILES, '
                'YDMODEL%YRML_DIAG%YRMDDH%NDHVTLS + YDMODEL%YRML_DIAG%YRMDDH%NDHFTLS)')
    assert str(declaration) == expected

    # Separable call statement with a nested, separable argument list
    call_args = JoinableStringList(
        ['"tensor_out"', 'tensor_out',
         'new DFEVectorType>(new DFEVectorType(dfeFloat(11, 53), m), n)'],
        sep=', ', width=90, cont='\n      ', separable=True
    )
    call_stmt = JoinableStringList(
        ['    ', 'io.output', '(', call_args, ');'],
        sep='', width=90, cont='\n      ', separable=True
    )
    expected = ('    io.output("tensor_out", tensor_out, \n'
                '      new DFEVectorType>(new DFEVectorType(dfeFloat(11, 53), m)\n'
                '      , n));')
    assert str(call_stmt) == expected

    # A single entry that is longer than the line width
    assignment = JoinableStringList(
        ['my_long_var = 5+3*tendency_loc(ibl)%T(jl,jk)'],
        sep=' ', width=40, cont=' &\n & '
    )
    expected = ('my_long_var =  &\n'
                ' & 5+3*tendency_loc(ibl)%T(jl,jk)')
    assert str(assignment) == expected


def test_joinable_string_list_exceeding_indentation():
    """
    Deep indentation must not trigger an assertion, and no rendered line may
    exceed the requested width.
    """
    items = ['a', 'b', 'c']
    for depth in range(50, 100):
        indent = '  ' * depth
        # This should not throw an assertion
        joined = JoinableStringList(items, sep=indent, width=132, cont=f' &\n{indent}&')
        assert all(len(line) <= 132 for line in str(joined).splitlines())


@pytest.mark.parametrize('string, length, continuation, ref', [
    ('short string', 16, '...', 'short string'),
    ('short string', 12, '...', 'short string'),
    ('short string', 11, '...', 'short st...'),
])
def test_truncate_string(string, length, continuation, ref):
    """
    Strings are kept as-is if they fit into ``length``; otherwise they are cut
    and end with ``continuation``.
    """
    truncated = truncate_string(string, length, continuation)
    assert truncated == ref


def test_binary_insertion_sort():
    """
    Compare :any:`binary_insertion_sort` against :func:`sorted` for a few cases.
    """
    values = [37, 23, 0, 17, 12, 72, 31, 46, 100, 88, 54]
    assert binary_insertion_sort(values) == sorted(values)
    assert binary_insertion_sort(values, lt=op.gt) == sorted(values, reverse=True)

    # Already sorted and reverse-sorted input
    ascending = list(range(20))
    assert binary_insertion_sort(list(ascending)) == ascending
    assert binary_insertion_sort(ascending[::-1]) == ascending

    # All elements equal
    assert binary_insertion_sort([1] * 5) == [1] * 5


def test_optional():
    """
    :any:`optional` enters the given context manager only if the condition holds.
    """
    @contextmanager
    def summing_manager(a, b, c):
        yield a + b + c

    with optional(True, summing_manager, 1, c=10, b=100) as result:
        assert result == 111

    with optional(False, summing_manager, 1, c=10, b=100) as result:
        assert result is None


@pytest.mark.skipif(not HAVE_YAML, reason="Pyyaml is not installed")
def test_yaml_include(here):
    """
    Test the ``!include`` YAML tag implemented by :any:`yaml_include_constructor`,
    including access to nested keys and transitively included files.
    """
    include_yaml = """
foo:
  bar:
  - abc
  - def

foobar:
  - baz:
      dummy: value
  - 42:
      dummy: other_value
    """.strip()

    include_path = here/'include.yml'
    main_path = here/'main.yml'

    main_yaml = f"""
include: !include {include_path}

nested_foo: !include {include_path}:["foo"]

nested_foo_list: !include {include_path}:["foo"]["bar"][1]

nested_foobar: !include {include_path}:["foobar"][0]['baz']["dummy"]
    """.strip()

    nested_yaml = f"""
main: !include {main_path}
    """.strip()

    yaml.add_constructor('!include', yaml_include_constructor, yaml.SafeLoader)

    # The files are written next to this test module, so make sure they are
    # removed again even if an assertion fails
    try:
        include_path.write_text(include_yaml)
        main_path.write_text(main_yaml)

        included = yaml.safe_load(include_yaml)
        main = yaml.safe_load(main_yaml)

        assert main['include'] == included
        assert main['nested_foo'] == included['foo']
        assert main['nested_foo_list'] == included['foo']['bar'][1]
        assert main['nested_foobar'] == included['foobar'][0]['baz']['dummy']

        nested = yaml.safe_load(nested_yaml)
        assert nested['main'] == main
    finally:
        include_path.unlink(missing_ok=True)
        main_path.unlink(missing_ok=True)


def test_execute(here, capsys):
    """
    Test :any:`execute` for a failing command without output, a failing
    command with output, and a successful command.
    """
    log_file = here/'test_execute.txt'
    if log_file.is_file():
        log_file.unlink()

    def run_and_expect_failure(cmd):
        """
        Run ``cmd`` and expect a :any:`CalledProcessError`. If pytest does not
        capture the std channels, they are redirected to ``log_file`` and its
        content is returned; otherwise ``None`` is returned.
        """
        if stdchannel_is_captured(capsys):
            with pytest.raises(CalledProcessError):
                execute(cmd)
            return None
        with capsys.disabled():
            with stdchannel_redirected(sys.stdout, log_file):
                with stdchannel_redirected(sys.stderr, log_file):
                    with pytest.raises(CalledProcessError):
                        execute(cmd)
        return log_file.read_text()

    # Failure with no output
    log = run_and_expect_failure('false')
    if log is not None:
        assert 'Execution of false failed' in log
        assert 'Full command: false' in log
        assert 'Output of the command:' not in log
        log_file.unlink()

    # Failure with output
    cmd = ['cat', '/not/a/file']
    log = run_and_expect_failure(cmd)
    if log is not None:
        assert 'Execution of cat failed' in log
        assert f'Full command: {" ".join(cmd)}' in log
        assert 'Output of the command:' in log
        assert 'No such file or directory' in log
        log_file.unlink()

    # Success
    execute('true')


@pytest.mark.skipif(platform.system() == 'Darwin',
    reason='Timeout utility test sporadically fails on MacOS CI runners.'
)
def test_timeout():
    """
    Test the :any:`timeout` context manager: it has no effect when the body
    finishes in time or when it is disabled (``0``), and it raises a
    :any:`RuntimeError` once the timeout expires.
    """
    # Should not trigger:
    start = perf_counter()
    with timeout(5):
        sleep(.3)
    stop = perf_counter()
    assert .2 < stop - start < .45

    # Timeout disabled:
    start = perf_counter()
    with timeout(0):
        sleep(.3)
    stop = perf_counter()
    assert .2 < stop - start < .45

    # Default exception
    # Checks on the raised exception must come after the ``pytest.raises`` block:
    # nothing after the raising statement inside that block is executed
    start = perf_counter()
    with pytest.raises(RuntimeError) as exc:
        with timeout(1):
            sleep(5)
    stop = perf_counter()
    assert .9 < stop - start < 1.15
    assert "Timeout reached after 1 second(s)" in str(exc.value)

    # Custom message
    start = perf_counter()
    with pytest.raises(RuntimeError) as exc:
        with timeout(1, message="My message"):
            sleep(5)
    stop = perf_counter()
    assert .9 < stop - start < 1.15
    assert "My message" in str(exc.value)


def test_dict_override():
    """
    :any:`dict_override` temporarily updates a dict and restores it on exit,
    removing keys that were only added by the override.
    """
    kwargs = {'rick' : 42, 'dave' : 'yeah'}
    overrides = {'dave' : 'nope', 'joe' : 'huh?'}
    with dict_override(kwargs, overrides):
        assert kwargs['dave'] == 'nope'
        assert kwargs['rick'] == 42
        assert kwargs['joe'] == 'huh?'
    assert kwargs == {'rick': 42, 'dave': 'yeah'}


def test_loki_tempdir(here):
    """
    :any:`LokiTempdir` creates its directory lazily below the configured
    ``tmp-dir`` and removes only that directory on cleanup.
    """
    parent_dir = here/'loki_tempdir'
    assert not parent_dir.exists()

    # Instantiation alone must not create anything
    tempdir = LokiTempdir()
    assert not parent_dir.exists()

    # The first call to get() creates the directory below the configured tmp-dir
    with config_override({'tmp-dir': str(parent_dir)}):
        created = tempdir.get()

    assert created.exists()
    assert parent_dir.exists()
    assert created.parent == parent_dir

    # Put a file into the temporary directory
    dummy_file = created/'myfile'
    dummy_file.write_text('Hello world')
    assert dummy_file.exists()

    # Cleanup removes the temporary directory and its content ...
    tempdir.cleanup()
    assert not dummy_file.exists()
    assert not created.exists()

    # ... but leaves the parent directory in place
    assert parent_dir.exists()
    parent_dir.rmdir()


@pytest.mark.parametrize('a,b', [
    (None, None),
    ([], []),
    (['a'], []),
    ([1], [2]),
    ([1, 2], [2]),
    (list(range(100)), list(range(10, 200))),
])
def test_ordered_set(a, b):
    """
    Test :any:`OrderedSet` against the builtin :class:`set` and against the
    insertion order of the list it was built from.

    ``None`` parameters exercise construction without arguments.
    """
    # Reference builtin sets and the ordered sets under test
    a_set = set(a) if a is not None else set()
    b_set = set(b) if b is not None else set()
    a_oset = OrderedSet(a) if a is not None else OrderedSet()
    b_oset = OrderedSet(b) if b is not None else OrderedSet()

    # From here on, treat missing inputs as empty lists
    if a is None:
        a = []
    if b is None:
        b = []

    # repr shows the elements as a list, or nothing for an empty set
    if a:
        assert repr(a_oset) == f'OrderedSet({a})'
    else:
        assert repr(a_oset) == 'OrderedSet()'

    # Size and equality with builtin sets and the source lists
    assert len(a_oset) == len(a_set)
    assert len(b_oset) == len(b_set)

    assert a_oset == a_set
    assert b_oset == b_set
    assert a_oset == a
    assert b_oset == b

    # Membership
    assert all(value in a_oset for value in a)
    assert all(value in b_oset for value in b)
    assert not any(value in a_oset for value in b if value not in a)
    assert all(value in a for value in a_oset)
    assert all(value in b for value in b_oset)

    # Iteration follows insertion order, forwards and reversed
    assert all(v1 == v2 for v1, v2 in zip(a, a_oset))
    assert all(v1 == v2 for v1, v2 in zip(b, b_oset))

    assert all(v1 == v2 for v1, v2 in zip(reversed(a), reversed(a_oset)))
    assert all(v1 == v2 for v1, v2 in zip(reversed(b), reversed(b_oset)))

    assert all(v1 == v2 for v1, v2 in zip(OrderedSet(reversed(a)), reversed(a)))
    assert all(v1 == v2 for v1, v2 in zip(OrderedSet(reversed(b)), reversed(b)))

    # Subset/superset comparisons agree with the builtin set
    assert (a_oset <= b_oset) == (a_set <= b_set)
    assert (a_oset < b_oset) == (a_set < b_set)
    assert (a_oset >= b_oset) == (a_set >= b_set)
    assert (a_oset > b_oset) == (a_set > b_set)

    # Subset relation ignores order, but equality of two OrderedSets does not
    if len(a) > 1:
        assert (a_oset <= OrderedSet(reversed(a))) and a_oset != OrderedSet(reversed(a))

    # Set algebra keeps elements in order of first appearance
    assert (a_oset | b_oset) == OrderedSet(a + b)
    assert (a_oset | b_oset) == OrderedSet(a + [value for value in b if value not in a])
    assert (a_oset & b_oset) == OrderedSet(
        [value for value in a if value in b] + [value for value in b if value in a]
    )
    assert (a_oset - b_oset) == OrderedSet([value for value in a if value not in b])
    assert (a_oset ^ b_oset) == OrderedSet(
        [value for value in a if value not in b] + [value for value in b if value not in a]
    )

    # Unbound-method call form of union
    assert (a_oset | b_oset) == OrderedSet.union(a_oset, b_oset)

    # In-place update matches the union operator
    oset = OrderedSet(a)
    oset.update(b_oset)
    assert oset == (a_oset | b_oset)

    # copy() returns an equal but distinct object
    assert a_oset.copy() == a_oset
    assert a_oset.copy() is not a_oset

    # add() of an existing element is a no-op; remove() raises KeyError for missing elements
    if a:
        assert a[0] in a_oset
        a_oset.add(a[0])
        assert a_oset == OrderedSet(a)

        a_oset.remove(a[0])
        assert a[0] not in a_oset

        with pytest.raises(KeyError):
            a_oset.remove(a[0])

    # discard() removes an element and silently ignores missing ones
    if b:
        assert b[0] in b_oset
        b_oset.discard(b[0])
        assert b[0] not in b_oset
        assert len(b_oset) == len(b) - 1

        b_oset.discard(b[0])
        assert len(b_oset) == len(b) - 1

    # pop() returns the last element in insertion order and raises KeyError when empty
    if a_oset:
        assert a_oset.pop() == a[-1]
    else:
        with pytest.raises(KeyError):
            a_oset.pop()
loki-ecmwf-0.3.6/loki/tools/util.py0000664000175000017500000006732315167130205017401 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import OrderedDict, defaultdict
from collections.abc import MutableSet, Sequence
from contextlib import contextmanager
from functools import lru_cache
import io
from itertools import groupby
import operator as op
import os
from pathlib import Path
from shlex import split
import signal
from subprocess import run, PIPE, STDOUT, CalledProcessError
import sys
import weakref

from more_itertools import replace

try:
    import yaml
    HAVE_YAML = True
except ImportError:
    HAVE_YAML = False

from loki.logging import debug, error


__all__ = [
    'as_tuple', 'is_iterable', 'is_subset', 'flatten', 'chunks',
    'execute', 'CaseInsensitiveDict', 'CaseInsensitiveDefaultDict',
    'strip_inline_comments',
    'binary_insertion_sort', 'cached_func', 'optional',
    'LazyNodeLookup', 'yaml_include_constructor',
    'auto_post_mortem_debugger', 'set_excepthook', 'timeout',
    'WeakrefProperty', 'group_by_class', 'replace_windowed',
    'dict_override', 'stdchannel_redirected',
    'stdchannel_is_captured', 'graphviz_present',
    'OrderedSet'
]


def as_tuple(item, type=None, length=None):
    """
    Force :data:`item` to a tuple, even if `None` is provided.

    Parameters
    ----------
    item :
        The object to convert. `None` becomes an empty tuple, a string becomes
        a 1-tuple, and any other iterable is converted with :class:`tuple`.
        An object that cannot be iterated is repeated :data:`length` times
        (once if no length is given).
    type : type or tuple of types, optional
        If given, every element of the result must be an instance of it.
    length : int, optional
        If given (and non-zero), the result must have exactly this length.

    Returns
    -------
    tuple

    Raises
    ------
    ValueError
        If :data:`length` is given and the result has a different length.
    TypeError
        If :data:`type` is given and an element is not an instance of it.
    """
    # Stop complaints about `type` in this function
    # pylint: disable=redefined-builtin

    # Empty tuple if we get passed None
    if item is None:
        t = ()
    elif isinstance(item, str):
        t = (item,)
    else:
        # Convert iterable to tuple...
        try:
            t = tuple(item)
        # ... or repeat the single item (this also covers objects whose
        # __iter__ deliberately raises to prevent iteration)
        except (TypeError, NotImplementedError):
            t = (item,) * (length or 1)
    if length and len(t) != length:
        raise ValueError(f'Tuple needs to be of length {length}')
    if type and not all(isinstance(i, type) for i in t):
        raise TypeError(f'Items need to be of type {type}')
    return t


def is_iterable(o):
    """
    Duck-typing check whether :data:`o` can be iterated over.

    :class:`pymbolic.primitives.Expression` defines an ``__iter__`` method that
    raises to prevent iteration, which makes ``isinstance(o, collections.abc.Iterable)``
    report `True` for it. Calling :func:`iter` and catching :any:`TypeError` is
    therefore the more reliable test.
    """
    try:
        _ = iter(o)
        iterable = True
    except TypeError:
        iterable = False
    return iterable


def is_subset(a, b, ordered=True, subsequent=False):
    """
    Check if all items in iterable :data:`a` are contained in iterable :data:`b`.

    Parameters
    ----------
    a : iterable
        The iterable whose elements are searched in :data:`b`. Must be a
        :any:`collections.abc.Sequence` if :data:`ordered` is `True`.
    b : iterable
        The iterable of which :data:`a` is tested to be a subset. Must be a
        :any:`collections.abc.Sequence` if :data:`ordered` is `True`.
    ordered : bool, optional
        Require elements to appear in the same order in :data:`a` and :data:`b`.
    subsequent : bool, optional
        If `True`, :data:`a` must appear as one contiguous run somewhere in
        :data:`b`. If `False`, other elements are allowed to sit in :data:`b`
        in-between the elements of :data:`a`. Only relevant when using
        :data:`ordered`.

    Returns
    -------
    bool :
        `True` if all elements of :data:`a` are found in :data:`b`, `False`
        otherwise. In the ordered case, an empty :data:`a` yields `False`.

    Raises
    ------
    ValueError
        If :data:`ordered` is `True` and :data:`a` or :data:`b` is not a Sequence.
    """
    if not ordered:
        return set(a) <= set(b)

    if not isinstance(a, Sequence):
        raise ValueError('a is not a Sequence')
    if not isinstance(b, Sequence):
        raise ValueError('b is not a Sequence')
    if not a:
        return False

    if subsequent:
        # Try every candidate start position in b, not only the first occurrence
        # of a[0]: a later occurrence may start the contiguous match,
        # e.g. a=(1, 2), b=(1, 0, 1, 2)
        n = len(a)
        return any(
            all(a[k] == b[start + k] for k in range(n))
            for start in range(len(b) - n + 1)
        )

    # When allowing intermediate elements, greedily match each element of a
    # to its first occurrence in b after the previous match
    try:
        idx = b.index(a[0])
        for item in a[1:]:
            idx = b.index(item, idx + 1)
    except ValueError:
        return False
    return True


def flatten(l, is_leaf=None):
    """
    Flatten a hierarchy of nested iterables into a plain list.

    Strings and bytes are never descended into.

    :param callable is_leaf: Optional predicate called for each iterable element;
                             if it returns `True`, the element is kept as-is instead
                             of being flattened further.
    """
    leaf_check = is_leaf if is_leaf is not None else (lambda _el: False)
    flat = []
    for element in l:
        descend = is_iterable(element) and not (
            isinstance(element, (str, bytes)) or leaf_check(element)
        )
        if descend:
            flat.extend(flatten(element, leaf_check))
        else:
            flat.append(element)
    return flat


def filter_ordered(elements, key=None):
    """
    Remove duplicates from :data:`elements` while preserving the order of first occurrence.

    :param elements: Iterable of elements to filter.
    :param key: Optional conversion key used during equality comparison; it is
                evaluated exactly once per element and must return a hashable value.
    :return list: The elements whose key has not been seen before, in input order.
    """
    seen = set()
    unique = []
    for element in elements:
        marker = element if key is None else key(element)
        if marker not in seen:
            seen.add(marker)
            unique.append(element)
    return unique


def chunks(l, n):
    """Generate consecutive slices of ``l`` of length ``n``; the last slice may be shorter."""
    yield from (l[start:start + n] for start in range(0, len(l), n))


def execute(command, silent=True, **kwargs):
    """
    Execute a single command within a given directory or environment

    Parameters
    ----------
    command : str or list of str
        The command to execute. A list is joined with spaces and then, like a
        string, split into arguments with :any:`shlex.split` using
        ``posix=False``. Individual list entries must therefore not contain
        whitespace, and quote characters are kept in the resulting arguments.
    silent : bool, optional
        Capture the output instead of printing it (default: `True`). This sets
        ``stdout=subprocess.PIPE`` and ``stderr=subprocess.STDOUT`` unless
        :data:`stdout` / :data:`stderr` are given explicitly.
    stdout : file object, optional
        Redirect stdout to this file object (takes precedence over :data:`silent`)
    stderr : file object, optional
        Redirect stderr to this file object (takes precedence over :data:`silent`)
    cwd : str or :class:`pathlib.Path`
        Directory in which to execute :data:`command` (will be stringified)
    **kwargs :
        Any other keyword arguments are passed on to :any:`subprocess.run`.

    Returns
    -------
    :any:`subprocess.CompletedProcess`

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status. The full command line and
        any captured output are logged as errors before the exception is re-raised.
    """

    cwd = kwargs.pop('cwd', None)
    cwd = cwd if cwd is None else str(cwd)

    if silent:
        # Capture output, unless the caller explicitly redirected the streams
        kwargs['stdout'] = kwargs.pop('stdout', PIPE)
        kwargs['stderr'] = kwargs.pop('stderr', STDOUT)

    # Some string mangling to support lists and strings
    if isinstance(command, list):
        command = ' '.join(command)
    if isinstance(command, str):
        command = split(command, posix=False)

    debug('[Loki] Executing: %s', ' '.join(command))
    try:
        return run(command, check=True, cwd=cwd, **kwargs)
    except CalledProcessError as e:
        def as_text(data):
            # Undecodable output must not mask the CalledProcessError itself
            return data.decode(errors='replace') if isinstance(data, bytes) else data

        command_str = ' '.join(command)
        error(f'Error: Execution of {command[0]} failed:')
        error(f'  Full command: {command_str}')
        output_str = ''
        if e.stdout:
            output_str += as_text(e.stdout)
        if e.stderr:
            output_str += '\n'
            output_str += as_text(e.stderr)
        if output_str:
            error(f'  Output of the command:\n\n{output_str}')
        raise


class CaseInsensitiveDict(OrderedDict):
    """
    Dict that ignores the casing of string keys.

    String keys are stored lower-cased; all other keys are used unchanged.
    Lookup, membership tests, deletion and :meth:`pop` normalize the key in
    the same way.

    Basic idea from:
    https://stackoverflow.com/questions/2082152/case-insensitive-dictionary
    """

    @staticmethod
    def _normalize(key):
        # Lower-case string keys, leave any other key untouched
        return key.lower() if isinstance(key, str) else key

    def __setitem__(self, key, value):
        super().__setitem__(self._normalize(key), value)

    def __getitem__(self, key):
        return super().__getitem__(self._normalize(key))

    def __delitem__(self, key):
        super().__delitem__(self._normalize(key))

    def get(self, key, default=None):
        return super().get(self._normalize(key), default)

    def pop(self, key, *args):
        # ``*args`` forwards the optional default of :meth:`OrderedDict.pop`
        return super().pop(self._normalize(key), *args)

    def __contains__(self, key):
        return super().__contains__(self._normalize(key))


class CaseInsensitiveDefaultDict(defaultdict):
    """
    Variant of :any:`collections.defaultdict` that ignores the casing of string keys.

    String keys are stored lower-cased; all other keys are used unchanged.
    As for :any:`collections.defaultdict`, the first positional constructor
    argument is the ``default_factory``. Any further arguments are inserted
    as initial items with normalized keys.
    """

    @staticmethod
    def _normalize(key):
        # Lower-case string keys, leave any other key untouched
        return key.lower() if isinstance(key, str) else key

    def __init__(self, *args, **kwargs):
        # dict.__init__ bypasses __setitem__, so initial items go through update()
        super().__init__(*args[:1])
        self.update(*args[1:], **kwargs)

    def __setitem__(self, key, value):
        super().__setitem__(self._normalize(key), value)

    def __getitem__(self, key):
        return super().__getitem__(self._normalize(key))

    def __delitem__(self, key):
        super().__delitem__(self._normalize(key))

    def get(self, key, default=None):
        return super().get(self._normalize(key), default)

    def pop(self, key, *args):
        # ``*args`` forwards the optional default of :meth:`dict.pop`
        return super().pop(self._normalize(key), *args)

    def setdefault(self, key, default=None):
        return super().setdefault(self._normalize(key), default)

    def update(self, *args, **kwargs):
        # dict.update bypasses __setitem__, so insert the items one by one
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def __contains__(self, key):
        return super().__contains__(self._normalize(key))


def strip_inline_comments(source, comment_char='!', str_delim='"\''):
    """
    Strip inline comments from a source string and return the modified string.

    An occurrence of :data:`comment_char` inside an open string (opened by any
    character in :data:`str_delim`) does not start a comment. The open/closed
    string state is carried over from one line to the next.

    Note: this does only work reliably for Fortran strings at the moment (where quotation
    marks are escaped by double quotes and thus the string status is kept correct automatically).

    :param str source: the source line(s) to be stripped.
    :param str comment_char: the character that marks the beginning of a comment.
    :param str str_delim: one or multiple characters that are valid string delimiters.
    :return str: :data:`source` unchanged if it does not contain :data:`comment_char`
        at all. Otherwise the processed lines joined with newline characters: lines
        with a comment are cut at the comment and right-stripped, other lines are kept
        as-is. A trailing newline of :data:`source` is not preserved in this case.
    """
    if comment_char not in source:
        # No comment, we can bail out early
        return source

    # Split the string into lines and look for the start of comments
    source_lines = source.splitlines()

    def update_str_delim(open_str_delim, string):
        """Run through the string and update the string status."""
        for ch in string:
            if ch in str_delim:
                if open_str_delim == '':
                    # This opens a string
                    open_str_delim = ch
                elif open_str_delim == ch:
                    # TODO: Handle escaping of quotes in general. Fortran just works (TM)
                    # This closes a string
                    open_str_delim = ''
                # else: character is string delimiter but we are inside an open string
                # with a different character used => ignored
        return open_str_delim

    # If we are inside a string this holds the delimiter character that was used
    # to open the current string environment:
    #  '': if not inside a string
    #  'x':  inside a string with x being the opening string delimiter
    open_str_delim = ''

    # Run through lines to strip inline comments
    clean_lines = []
    for line in source_lines:
        end = line.find(comment_char)
        # NOTE: if no comment char is found, end == -1 and line[:end] omits only the
        # last character; that character is scanned later via line[end:] in the else-branch
        open_str_delim = update_str_delim(open_str_delim, line[:end])

        while end != -1:
            if not open_str_delim:
                # We have found the start of the inline comment, add the line up until there.
                # The comment text itself is not scanned for string delimiters.
                clean_lines += [line[:end].rstrip()]
                break
            # We are inside an open string, so `end` does not mark the start of a comment:
            # look for the next comment char and scan the text in-between
            start, end = end, line.find(comment_char, end + 1)
            open_str_delim = update_str_delim(open_str_delim, line[start:end])
        else:
            # Loop finished without `break`: no comment char outside of a string in
            # the current line, keep original line and scan its last character
            clean_lines += [line]
            open_str_delim = update_str_delim(open_str_delim, line[end:])

    return '\n'.join(clean_lines)


def binary_search(items, val, start, end, lt=op.lt):
    """
    Find the insertion position of a value within a range of sorted items.

    :param list items: the list of items to search.
    :param val: the value for which to seek the position.
    :param int start: first index for search range in ``items``.
    :param int end: last index (inclusive) for search range in ``items``.
    :param lt: the "less than" comparison operator to use. Default is the
        standard ``<`` operator (``operator.lt``).

    :return int: the insertion position for the value. If an item comparing
        equal to ``val`` is hit during the search, its position is returned.

    This implementation was adapted from
    https://www.geeksforgeeks.org/binary-insertion-sort/.
    """
    low, high = start, end
    while True:
        if low == high:
            # Single candidate left: insert before it if val is smaller, else after
            return low if lt(val, items[low]) else low + 1
        if low > high:
            # The range collapsed: low is the first position holding a larger item
            return low

        mid = (low + high) // 2
        if lt(items[mid], val):
            low = mid + 1
        elif lt(val, items[mid]):
            high = mid - 1
        else:
            return mid


def binary_insertion_sort(items, lt=op.lt):
    """
    Sort the given items using binary insertion sort.

    A binary search is used to find the insertion position, which limits the
    number of comparison operations to O(n log n). Hence, this sorting function
    is particularly useful when comparisons are expensive. Moving elements into
    place costs O(n) for already sorted input and O(n*n) in the worst case
    (sorted in reverse order).

    The input is not modified. The sort is not stable: an element may be placed
    before an earlier element that compares equal to it.

    :param items: the items to be sorted (any sequence or iterable).
    :param lt: the "less than" comparison operator to use. Default is the
        standard ``<`` operator (``operator.lt``).

    :return list: a new list with the items sorted in ascending order.

    This implementation was adapted from
    https://www.geeksforgeeks.org/binary-insertion-sort/.
    """
    result = list(items)
    for i in range(1, len(result)):
        pos = binary_search(result, result[i], 0, i-1, lt=lt)
        if pos != i:
            # Move the element from position i to its insertion position in the
            # sorted prefix, without rebuilding the whole list
            result.insert(pos, result.pop(i))
    return result


def cached_func(func):
    """
    Decorator that memoizes (caches) the results of :data:`func`.

    Results are kept in an unbounded :func:`functools.lru_cache` keyed on the
    call arguments, so all arguments must be hashable.
    """
    memoize = lru_cache(maxsize=None, typed=False)
    return memoize(func)


@contextmanager
def optional(condition, context_manager, *args, **kwargs):
    """
    Enter a context manager only if a condition is fulfilled.

    Based on https://stackoverflow.com/a/41251962.

    Parameters
    ----------
    condition : bool
        Whether to apply the context manager.
    context_manager :
        Callable returning the context manager to apply; it is called with
        ``*args`` and ``**kwargs``.

    Yields
    ------
    The value produced by the context manager, or `None` if :data:`condition`
    is falsy.
    """
    if not condition:
        yield
        return
    with context_manager(*args, **kwargs) as managed:
        yield managed


class LazyNodeLookup:
    """
    Utility class for indirect, :any:`weakref`-style lookups

    References to IR nodes are usually not stable as the IR may be
    rebuilt at any time. This class offers a way to refer to a node
    in an IR by encoding how it can be found instead.

    .. note::
       **Example:**
       Reference a declaration node that contains variable "a"

       .. code-block::

          from loki import LazyNodeLookup, FindNodes, VariableDeclaration
          # Assume this has been initialized before
          # routine = ...

          # Create the reference
          query = lambda x: [d for d in FindNodes(VariableDeclaration).visit(x.spec) if 'a' in d.symbols][0]
          decl_ref = LazyNodeLookup(routine, query)

          # Use the reference (this carries out the query)
          decl = decl_ref()

    Parameters
    ----------
    anchor :
        The "stable" anchor object to which :attr:`query` is applied to find the object.
        This is stored internally as a :any:`weakref`, so the anchor must support
        weak references and is not kept alive by this object.
    query :
        A function object that accepts a single argument and should return the lookup
        result. To perform the lookup, :attr:`query` is called with :attr:`anchor`
        as argument.

    Once the anchor has been garbage-collected, :attr:`anchor` returns `None`
    and calling the lookup passes `None` to :attr:`query`.
    """

    def __init__(self, anchor, query):
        self._anchor = weakref.ref(anchor)
        self.query = query

    @property
    def anchor(self):
        """The anchor object, or `None` if it no longer exists"""
        return self._anchor()

    def __call__(self):
        """Carry out the lookup by applying :attr:`query` to :attr:`anchor`"""
        return self.query(self.anchor)


def yaml_include_constructor(loader, node):
    """
    Add support for ``!include`` tags to YAML load

    Activate via ``yaml.add_constructor("!include", yaml_include_constructor)``
    or ``yaml.add_constructor("!include", yaml_include_constructor, yaml.SafeLoader)``
    (for use with ``yaml.safe_load``).

    Adapted from JUBE2 (https://fz-juelich.de/jsc/jube) and
    http://code.activestate.com/recipes/577612-yaml-include-support/

    This allows to include other YAML files or parts of them inside a YAML file:

    .. code-block:: yaml

        # include.yml
        tag0:
          foo: bar

        tag1:
          baz: bar

    .. code-block:: yaml

        # main.yml
        nested: !include include.yml

        nested_filtered: !include include.yml:["tag0"]

    which is equivalent to the following:

    .. code-block:: yaml

        nested:
          tag0:
            foo: bar
          tag1:
            baz: bar
        nested_filtered:
          foo: bar

    The tag value is split at ``:`` into the file path and an optional chain of
    subscripts (e.g. ``["tag0"]["foo"]`` or ``[0]``). Numeric subscripts index
    into lists; other subscripts are used as (optionally quoted) keys. Because of
    this, the file path itself must not contain a ``:``. Relative paths are
    resolved against the current working directory.

    If the included file cannot be opened, an error is logged and the unresolved
    tag text ``!include <value>`` is returned instead.

    Raises
    ------
    RuntimeError
        If PyYAML is not available.
    KeyError or IndexError
        If a subscript cannot be resolved in the included content.
    """
    if not HAVE_YAML:
        error('Pyyaml is not installed')
        raise RuntimeError('Pyyaml is not installed')

    # Load the content of the included file
    yaml_node_data = node.value.split(":")
    file = Path(yaml_node_data[0])
    try:
        with file.open() as inputfile:
            # Re-use the caller's loader class, so that an include processed by a
            # SafeLoader is itself loaded safely
            content = yaml.load(inputfile.read(), type(loader))
    except OSError:
        error(f'Cannot open include file {file}')
        return f'!include {node.value}'

    # Filter included content if subscripts given
    if len(yaml_node_data) > 1:
        try:
            # '["a"][0]' -> ['"a"', '0']
            subscripts = yaml_node_data[1].strip().lstrip('[').rstrip(']').split('][')

            for subscript in subscripts:
                if subscript.isnumeric():
                    content = content[int(subscript)]
                elif subscript[0] == subscript[-1] and subscript[0] in '"\'':
                    content = content[subscript.strip('"\'')]
                else:
                    content = content[subscript]
        except (KeyError, IndexError):
            # IndexError covers out-of-range list indices and empty subscripts
            error(f'Cannot extract {yaml_node_data[1]} from {file}')
            raise

    return content


def auto_post_mortem_debugger(type, value, tb):  # pylint: disable=redefined-builtin
    """
    Exception hook that automatically attaches a debugger

    Activate by calling ``set_excepthook(hook=auto_post_mortem_debugger)``.

    The debugger is only started when Python runs non-interactively with stdin,
    stdout and stderr all attached to a terminal, and the exception is not a
    :any:`SyntaxError` (including subclasses such as :any:`IndentationError`).
    In all other cases the default hook ``sys.__excepthook__`` is called.

    Adapted from https://code.activestate.com/recipes/65287/

    Parameters
    ----------
    type : type
        The exception class
    value : BaseException
        The exception instance
    tb : traceback
        The traceback of the exception
    """
    is_interactive = hasattr(sys, 'ps1')
    no_tty = not sys.stderr.isatty() or not sys.stdin.isatty() or not sys.stdout.isatty()
    # Treat SyntaxError subclasses (IndentationError, TabError) like SyntaxError itself
    is_syntax_error = issubclass(type, SyntaxError)
    if is_interactive or no_tty or is_syntax_error:
        # we are in interactive mode or we don't have a tty-like
        # device, so we call the default hook
        sys.__excepthook__(type, value, tb)
    else:
        import traceback # pylint: disable=import-outside-toplevel
        import pdb # pylint: disable=import-outside-toplevel
        # we are NOT in interactive mode, print the exception...
        traceback.print_exception(type, value, tb)
        # ...then start the debugger in post-mortem mode.
        pdb.post_mortem(tb)   # pylint: disable=no-member


def set_excepthook(hook=None):
    """
    Set an exception hook that is called for uncaught exceptions

    This can be called with :meth:`auto_post_mortem_debugger` to automatically
    attach a debugger (Pdb or, if installed, Pdb++) when exceptions occur.

    With :data:`hook` set to `None`, this will restore the default exception
    hook ``sys.__excepthook__``.

    Parameters
    ----------
    hook : callable, optional
        The function to install as ``sys.excepthook``; it is called with the
        exception type, value and traceback.
    """
    sys.excepthook = sys.__excepthook__ if hook is None else hook


@contextmanager
def timeout(time_in_s, message=None):
    """
    Context manager that specifies a timeout for the code section in its body

    This is implemented by installing a signal handler for :any:`signal.SIGALRM`
    and scheduling that signal for :data:`time_in_s` in the future.
    For that reason, this context manager cannot be nested, and (as with any
    use of :any:`signal.signal`) it must be used from the main thread.
    The previously installed handler is restored on exit. If that handler was
    not installed from Python (:any:`signal.getsignal` returns `None`), the
    default action :any:`signal.SIG_DFL` is restored instead.

    A value of 0 (or less) for :data:`time_in_s` will not install any timeout.

    The following example illustrates the usage, which will result in a
    :any:`RuntimeError` being raised.

    .. code-block::

       with timeout(5):
           sleep(10)

    Parameters
    ----------
    time_in_s : int
        Timeout in seconds after which to interrupt the code
    message : str
        A custom error message to use if a timeout occurs

    Raises
    ------
    RuntimeError
        If the body does not complete within :data:`time_in_s` seconds.
    """
    if message is None:
        message = f"Timeout reached after {time_in_s} second(s)"

    def timeout_handler(signum, frame): # pylint: disable=unused-argument
        raise RuntimeError(message)

    if time_in_s <= 0:
        # No timeout requested: run the body unguarded
        yield
        return

    previous_handler = signal.getsignal(signal.SIGALRM)
    if previous_handler is None:
        # signal.signal() does not accept None, so fall back to the default action
        previous_handler = signal.SIG_DFL
    signal.signal(signal.SIGALRM, timeout_handler)
    try:
        signal.alarm(time_in_s)
        yield
    finally:
        # Cancel any pending alarm before restoring the old handler
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)


class WeakrefProperty:
    """
    Descriptor that keeps only a :any:`weakref` to the object assigned to it.

    The weak reference is stored on the owning instance under the attribute
    name prefixed with an underscore. Reading the attribute dereferences it;
    if nothing (or `None`) was assigned, :data:`default` is returned instead.

    Parameters
    ----------
    default : optional
        Value returned when no reference is stored, and on class-level access.
    frozen : bool, optional
        Write directly into the instance ``__dict__``, bypassing ``__setattr__``
        (e.g. for classes that forbid regular attribute assignment).
    """

    def __init__(self, *, default=None, frozen=False):
        self._default = default
        self._frozen = frozen

    def __set_name__(self, owner, name):
        # Storage attribute on the instance, e.g. 'parent' -> '_parent'
        self._name = f"_{name}"

    def __get__(self, obj, _type):
        stored = None if obj is None else getattr(obj, self._name, None)
        if stored is None:
            return self._default
        return stored()

    def __set__(self, obj, value):
        stored = None if value is None else weakref.ref(value)
        if not self._frozen:
            setattr(obj, self._name, stored)
            return
        obj.__dict__[self._name] = stored


def group_by_class(iterable, klass):
    """
    Find groups of consecutive instances of the same type with more
    than one element.

    Only elements whose ``__class__`` is exactly :data:`klass` are considered
    (instances of subclasses are not).

    Parameters
    ----------
    iterable : iterable
        Input iterable from which to extract groups
    klass : type
        Type by which to group elements in the given iterable

    Returns
    -------
    tuple of tuple
        Each inner tuple is a run of at least two consecutive elements of
        type :data:`klass`, in order of appearance.
    """
    runs = []
    for cls, members in groupby(iterable, key=lambda obj: obj.__class__):
        if cls != klass:
            continue
        run = tuple(members)
        if len(run) > 1:
            runs.append(run)
    return tuple(runs)


def replace_windowed(iterable, group, subs):
    """
    Replace every occurrence of a run of consecutive elements in a larger iterable.

    Parameters
    ----------
    iterable : iterable
        Input iterable in which to replace elements
    group : iterable
        Group of elements to replace in ``iterable``
    subs : any
        Replacement for ``group`` in ``iterable``

    Returns
    -------
    tuple
        The elements of :data:`iterable` with each matching window replaced.
    """
    pattern = as_tuple(group)

    def _matches_pattern(*window):
        # A window matches when it equals the (tuple-normalized) group
        return window == pattern

    replaced = replace(
        iterable, pred=_matches_pattern,
        substitutes=as_tuple(subs), window_size=len(pattern)
    )
    return tuple(replaced)


@contextmanager
def dict_override(base, override):
    """
    Contextmanager to temporarily override a set of dictionary values.

    On exit, including when the body raises an exception, values that existed
    in :data:`base` before are restored and keys that were newly added are
    removed again.

    Parameters
    ----------
    base : dict
        The base dictionary in which to override values
    override : dict
        Replacement mapping to temporarily insert

    Yields
    ------
    dict
        The (temporarily modified) :data:`base` dictionary itself
    """
    original_values = tuple((k, base[k]) for k in override.keys() if k in base)
    added_keys = tuple(k for k in override.keys() if k not in base)
    base.update(override)

    try:
        yield base
    finally:
        # Restore the previous state even if the body raised an exception
        base.update(original_values)
        for k in added_keys:
            # Tolerate keys that the body already removed itself
            base.pop(k, None)


@contextmanager
def stdchannel_redirected(stdchannel, dest_filename):
    """
    A context manager to temporarily redirect stdout or stderr

    The redirection happens at file-descriptor level (via :any:`os.dup2`), so
    it also captures output written by compiled extensions. Buffered Python-level
    output of :data:`stdchannel` is flushed before redirecting and before
    restoring, so that it ends up in the right destination.

    e.g.:

    .. code-block:: python

        with stdchannel_redirected(sys.stderr, os.devnull):
            if compiler.has_function('clock_gettime', libraries=['rt']):
                libraries.append('rt')

    Source: https://stackoverflow.com/a/17753573

    Note that this only works when pytest is invoked with ``--capture=no`` (or ``-s``).
    This can be checked using `stdchannel_is_captured(capsys)`.
    Additionally, capturing of sys.stdout/sys.stderr needs to be disabled explicitly,
    i.e., use the fixture `capsys` and wrap the above:

    .. code-block:: python

        with capsys.disabled():
            with stdchannel_redirected(sys.stdout, 'stdout.log'):
                function()

    Parameters
    ----------
    stdchannel : file object
        The stream to redirect, e.g. ``sys.stdout``
    dest_filename : str or path-like
        The file to which the output is written (truncated on open)
    """

    def try_dup(fd):
        # Keep a duplicate of the original descriptor to restore it later;
        # streams without a real file descriptor yield None
        try:
            oldfd = os.dup(fd.fileno())
        except io.UnsupportedOperation:
            oldfd = None
        return oldfd

    def try_dup2(fd, fd2, fd_fileno=True):
        try:
            if fd_fileno:
                os.dup2(fd.fileno(), fd2.fileno())
            else:
                os.dup2(fd, fd2.fileno())
        except io.UnsupportedOperation:
            pass

    def try_flush(stream):
        # Push Python-level buffered data to the descriptor's current target
        flush = getattr(stream, 'flush', None)
        if flush is None:
            return
        try:
            flush()
        except ValueError:
            pass  # stream already closed: nothing left to flush

    oldstdchannel, dest_file = None, None
    try:
        oldstdchannel = try_dup(stdchannel)
        dest_file = open(dest_filename, 'w')
        try_flush(stdchannel)
        try_dup2(dest_file, stdchannel)

        yield
    finally:
        try_flush(stdchannel)
        if oldstdchannel is not None:
            try_dup2(oldstdchannel, stdchannel, fd_fileno=False)
            # The duplicate is no longer needed; close it to avoid leaking descriptors
            os.close(oldstdchannel)
        if dest_file is not None:
            dest_file.close()


def stdchannel_is_captured(capsys):
    """
    Check whether pytest is currently capturing stdout/stderr.

    Active capturing prevents redirecting stdout/stderr for f2py/f90wrap functions.

    Parameters
    ----------
    capsys :
        The capsys fixture of the test.

    Returns
    -------
    bool
        `True` if pytest captures output, otherwise `False`.
    """
    plugin_manager = capsys.request.config.pluginmanager
    capture_manager = plugin_manager.getplugin("capturemanager")
    # Relies on pytest internals: the stdout capture object is presumably None
    # when capturing is disabled (e.g. ``-s``) -- confirm on pytest upgrades
    global_capturing = capture_manager._global_capturing  # pylint: disable=protected-access
    return global_capturing.out is not None


def graphviz_present():
    """
    Check whether graphviz is usable

    Importing the ``graphviz`` Python package only proves that the wrapper is
    installed; the underlying Graphviz executables may still be missing. This
    therefore also pipes an empty graph through the layout engine.

    Returns
    -------
    bool
        `True` if both the Python wrapper and the executables are available,
        `False` otherwise.
    """
    try:
        import graphviz as gviz # pylint: disable=import-outside-toplevel
    except ImportError:
        gviz = None

    if gviz is None:
        return False

    try:
        gviz.Graph().pipe()
    except gviz.ExecutableNotFound:
        return False
    return True


class OrderedSet(MutableSet):
    """
    A :any:`set` implementation that remembers the insertion order of its items

    Implementation is based on a dictionary without using its values, inspired by
    the recipe linked in the Python documentation:
    https://code.activestate.com/recipes/576694/

    Comparison with another :class:`OrderedSet` is order-sensitive, whereas
    comparison with any other iterable ignores the order (like :any:`set`).

    Parameters
    ----------
    iterable :
        An iterable to initialize the OrderedSet from. `None` is treated like
        an empty iterable.
    """

    def __init__(self, iterable=()):
        if iterable is None:
            iterable = ()
        # Consume the iterable directly instead of testing its truthiness first,
        # which fails for objects with an ambiguous truth value (e.g. numpy arrays)
        self._storage = dict.fromkeys(iterable)

    def __len__(self):
        """Return the number of items in the set"""
        return len(self._storage)

    def __contains__(self, value):
        """Return `True` if :data:`value` is in the set, otherwise return `False`"""
        return value in self._storage

    def add(self, value):
        """Add :data:`value` to the set"""
        self._storage.setdefault(value)

    def remove(self, value):
        """Remove :data:`value` from the set. Raises :any:`KeyError` if not contained in the set."""
        self._storage.pop(value)

    def discard(self, value):
        """Remove :data:`value` from the set if present."""
        self._storage.pop(value, None)

    def __iter__(self):
        yield from self._storage

    def __reversed__(self):
        yield from reversed(self._storage)

    def pop(self):
        """Remove and return an element from the set. Elements are returned in LIFO order."""
        return self._storage.popitem()[0]

    def __repr__(self):
        if not self:
            return f'{self.__class__.__name__}()'
        return f'{self.__class__.__name__}({list(self)})'

    def __eq__(self, other):
        if isinstance(other, OrderedSet):
            # Order matters when comparing two ordered sets
            return list(self) == list(other)
        try:
            return set(self) == set(other)
        except TypeError:
            # `other` is not iterable (or holds unhashable items): let Python
            # fall back to its default comparison instead of raising
            return NotImplemented

    def copy(self):
        """Return a shallow copy of the set."""
        return self.__class__(list(self))

    def union(self, *others):
        """Return the union of the set with :data:`others`"""
        oset = self.copy()
        for o in others:
            oset |= o
        return oset

    def update(self, *others):
        """Update the set, adding elements from all :data:`others`"""
        for o in others:
            self |= o
loki-ecmwf-0.3.6/loki/tools/files.py0000664000175000017500000002236515167130205017523 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import atexit
import fnmatch
from hashlib import md5
from importlib import import_module, reload, invalidate_caches
import os
from pathlib import Path
import re
import shutil
import sys
import tempfile

from loki.logging import debug
from loki.tools.util import as_tuple, flatten
from loki.config import config


__all__ = [
    'LokiTempdir', 'gettempdir', 'filehash', 'delete', 'find_paths',
    'find_files', 'load_module', 'write_env_launch_script',
    'local_loki_setup', 'local_loki_cleanup'
]


class LokiTempdir:
    """
    Lazily created, Loki-specific temporary directory

    Wraps a :class:`tempfile.TemporaryDirectory` instance that is created on
    first use (see :meth:`get`) and removed by :meth:`cleanup`. The cleanup
    method is registered with :mod:`atexit`, so it also runs when the Python
    interpreter terminates.

    This class provides the temporary directory creation that :any:`gettempdir`
    relies upon.
    """

    def __init__(self):
        self.tmp_dir = None
        # Remove the directory on interpreter shutdown
        atexit.register(self.cleanup)

    def create(self):
        """
        Create the temporary directory, unless it already exists

        The directory is placed under the ``tmp-dir`` config value if that is
        set, or under a ``loki`` subdirectory of the platform temp directory
        otherwise. Its name is prefixed with the current process id.
        """
        if self.tmp_dir is not None:
            return

        configured_dir = config['tmp-dir']
        if configured_dir:
            basedir = Path(configured_dir)
        else:
            basedir = Path(tempfile.gettempdir())/'loki'
        basedir.mkdir(parents=True, exist_ok=True)

        pid_prefix = f'{os.getpid()!s}_'
        self.tmp_dir = tempfile.TemporaryDirectory(prefix=pid_prefix, dir=basedir) # pylint: disable=consider-using-with
        debug(f'Created temporary directory {self.tmp_dir.name}')

    def get(self):
        """
        Return the path of the temporary directory, creating it on first use

        Returns
        -------
        pathlib.Path
        """
        if self.tmp_dir is None:
            self.create()
        return Path(self.tmp_dir.name)

    def cleanup(self):
        """
        Remove the temporary directory if it has been created
        """
        if self.tmp_dir is None:
            return
        dirname = self.tmp_dir.name
        self.tmp_dir.cleanup()
        self.tmp_dir = None
        debug(f'Cleaned up temporary directory {dirname}')


# Module-level instance used by gettempdir(): all callers share this single,
# lazily created temporary directory
TMP_DIR = LokiTempdir()
"""
An instance of :class:`LokiTempdir` representing the
temporary directory that the current Loki instance uses.
"""


def gettempdir():
    """
    Return the Loki-specific temporary directory as a :any:`pathlib.Path`

    Repeated calls return the same directory for as long as it has not been
    cleaned up, which normally happens when the Python interpreter exits.

    The base directory under which the temporary directory is created is
    taken from the ``tmp-dir`` entry of the Loki configuration, which can be
    set via the environment variable ``LOKI_TMP_DIR``. Otherwise the platform
    default is used, following the rules of :any:`tempfile.gettempdir`.

    Creation, management and cleanup are handled by a :any:`LokiTempdir`
    instance. Its directory name is specific to the current process, which
    avoids race conditions between concurrently running Loki instances. The
    directory is created lazily, on the first call to this function.
    """
    tmp_path = TMP_DIR.get()
    return tmp_path


def filehash(source, prefix=None, suffix=None):
    """
    Generate a filename from the MD5 hash of ``source``.

    Parameters
    ----------
    source : str
        Text to hash (encoded with the default UTF-8 codec)
    prefix : str, optional
        String prepended to the hex digest
    suffix : str, optional
        String appended to the hex digest, e.g. a file extension

    Returns
    -------
    str
    """
    head = prefix if prefix is not None else ''
    tail = suffix if suffix is not None else ''
    digest = md5(source.encode()).hexdigest()
    return f'{head}{digest}{tail}'


def delete(filename, force=False):
    """
    Delete a file, or with :data:`force` any file, symlink or directory tree.

    Parameters
    ----------
    filename : str or :any:`pathlib.Path`
        The path to delete.
    force : bool, optional
        If `True`, delete :data:`filename` on a best-effort basis. Directories are
        removed recursively, files and symlinks are unlinked, and errors (e.g. a
        missing path) are ignored. If `False` (the default), an existing path is
        removed with :any:`os.remove`; errors (e.g. for a directory) are raised.
    """
    filepath = Path(filename)
    debug(f'Deleting {filepath}')
    if force:
        if filepath.is_dir() and not filepath.is_symlink():
            shutil.rmtree(f'{filepath}', ignore_errors=True)
        else:
            # shutil.rmtree cannot remove plain files or symlinks (and with
            # ignore_errors=True it would silently do nothing), so unlink them
            try:
                filepath.unlink()
            except OSError:
                pass  # best-effort, consistent with ignore_errors=True above
    elif filepath.exists():
        os.remove(f'{filepath}')


def find_paths(directory, pattern, ignore=None, sort=True):
    """
    Utility function to generate a list of file paths based on include
    and exclude patterns applied to a root directory.

    Parameters
    ----------
    directory : str or :any:`pathlib.Path`
        Root directory from which to glob files.
    pattern : str or list of str
        One or more glob patterns generating files to include in the list.
        Patterns are applied recursively (via :any:`pathlib.Path.rglob`).
    ignore : str or list of str, optional
        One or more glob patterns generating files to exclude from the list.
    sort : bool, optional
        Flag to indicate alphabetic ordering of files

    Returns
    -------
    list :
        The list of file paths. A file matched by several include patterns
        appears once per matching pattern.
    """
    directory = Path(directory)
    # A set gives O(1) membership tests when filtering the included files
    excludes = set(flatten(directory.rglob(e) for e in as_tuple(ignore)))

    files = []
    for incl in as_tuple(pattern):
        files += [f for f in directory.rglob(incl) if f not in excludes]

    return sorted(files) if sort else files


def find_files(pattern, srcdir='.'):
    """
    Case-insensitive alternative for glob patterns that recursively
    walks all sub-directories and matches a case-insensitive regex pattern.

    The glob :data:`pattern` is matched against file names only (not
    directory names), in the order produced by :any:`os.walk`.

    Basic idea from:
    http://stackoverflow.com/questions/8151300/ignore-case-in-glob-on-linux

    Returns
    -------
    list of pathlib.Path
    """
    matcher = re.compile(fnmatch.translate(pattern), re.IGNORECASE).match
    matches = []
    for root, _subdirs, filenames in os.walk(str(srcdir)):
        matches.extend(Path(root)/name for name in filenames if matcher(name))
    return matches


def load_module(module, path=None):
    """
    Import (or re-import) a module, optionally from an additional search path

    Parameters
    ----------
    module : str
        The name of the module to load
    path : str or :any:`pathlib.Path`, optional
        Directory prepended to ``sys.path`` (if not yet present)

    Returns
    -------
    module
        The loaded module. If it had been imported before, it is reloaded.
    """
    if path:
        entry = str(path)
        if entry not in sys.path:
            sys.path.insert(0, entry)

    already_loaded = module in sys.modules
    if already_loaded:
        # Refresh a previously imported module from disk
        reload(sys.modules[module])
        return sys.modules[module]

    try:
        loaded = import_module(module)
    except ModuleNotFoundError:
        # Import-system caches may not know about freshly written files yet
        invalidate_caches()
        loaded = import_module(module)
    return loaded


def write_env_launch_script(here, binary, args):
    """
    Utility method that is used for regression tests that require
    activating an environment file before running :data:`binary`.

    This writes a simple script ``build/run_<binary>.sh`` of the form

    .. code-block::

       #!/bin/bash

       source env.sh >&2
       bin/<binary> <args>
       exit $?

    and makes it executable (mode ``0o750``). Arguments are joined with
    spaces without any shell quoting.

    Parameters
    ----------
    here : pathlib.Path or str
        The directory in which the script is created (the script is placed
        in its existing ``build`` subdirectory)
    binary : str
        The name of the binary
    args : list of str
        List of arguments to pass to the binary

    Returns
    -------
    pathlib.Path
        The path to the created script file
    """

    # Path() first, so that ``here`` may also be given as a plain string
    script = Path(here)/f'build/run_{binary}.sh'
    script.write_text(f"""
#!/bin/bash

source env.sh >&2
bin/{binary} {' '.join(args)}
exit $?
    """.strip())
    script.chmod(0o750)

    return script


def local_loki_setup(here):
    """
    Prepare an ``ecbundle``-based worktree for injecting a local Loki copy

    This determines the paths needed to inject the currently running Loki
    source code into the bundle worktree. It is used by regression tests that
    build with a local Loki source copy. Any existing Loki source copy in the
    bundle worktree is moved to a backup location.

    .. warning:: If a backup copy already exists at the backup location, it
       is removed before the existing Loki copy is moved there.

    Injecting the currently running Loki installation only works if it has
    been installed in editable mode. This utility does not perform the
    injection itself, though. It is therefore also useful when the bundle
    create mechanism should download Loki instead.

    The companion utility :any:`local_loki_cleanup` can be used to
    revert these changes.

    Parameters
    ----------
    here : pathlib.Path
        The root path of the bundle worktree.

    Returns
    -------
    tuple of (str, pathlib.Path, pathlib.Path)
        The absolute path to the base directory of the currently running
        Loki installation, the ``target`` path (``source/loki``) where Loki
        needs to be injected in the bundle directory, and the ``backup`` path
        (``source/loki.bak``) where an existing Loki copy has been moved to.
    """
    # Root of the Loki source tree: three directory levels above this file
    lokidir = Path(__file__).parent.parent.parent
    source_dir = here/'source'
    target = source_dir/'loki'
    backup = source_dir/'loki.bak'

    # Never overwrite an existing Loki copy: move it aside, replacing a stale backup
    existing_copy = target.exists()
    if existing_copy and backup.exists():
        shutil.rmtree(backup)
    if existing_copy:
        shutil.move(target, backup)

    loki_root = str(lokidir.resolve())
    return loki_root, target, backup


def local_loki_cleanup(target, backup):
    """
    Revert the changes made by :any:`local_loki_setup`.

    A symlink at :data:`target` is removed, if present. Afterwards, if
    nothing occupies :data:`target` and :data:`backup` exists, the backup is
    moved back to its original location.

    Parameters
    ----------
    target : pathlib.Path
        The target injection path as returned by :any:`local_loki_setup`
    backup : pathlib.Path
        The backup path as created by :any:`local_loki_setup`
    """
    # Drop the injected symlink, if any
    if target.is_symlink():
        target.unlink()

    # Restore the backup only when the target location is free
    can_restore = backup.exists() and not target.exists()
    if can_restore:
        shutil.move(backup, target)
loki-ecmwf-0.3.6/loki/tools/strings.py0000664000175000017500000002205115167130205020102 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re
from copy import deepcopy

from loki.tools.util import is_iterable


__all__ = ['truncate_string', 'JoinableStringList']


def truncate_string(string, length=16, continuation='...'):
    """
    Truncates a string to have a maximum given number of characters and indicates the
    truncation by appending :data:`continuation` (``'...'`` by default).

    This is used, for example, in the representation strings of IR nodes.

    Parameters
    ----------
    string : str
        The string to truncate
    length : int, optional
        Maximum length of the result. If it is shorter than :data:`continuation`,
        a truncated result consists of :data:`continuation` only.
    continuation : str, optional
        Marker appended to indicate truncation

    Returns
    -------
    str
    """
    if len(string) > length:
        # Clamp at 0: a negative slice bound would count from the end and keep
        # most of the string instead of truncating it
        keep = max(length - len(continuation), 0)
        return string[:keep] + continuation
    return string


class JoinableStringList:
    """
    Helper class that takes a list of :data:`items` and joins them into a long
    string, when converting the object to a string using custom separator :data:`sep`.
    Long lines are wrapped automatically.

    The behaviour is essentially the same as ``sep.join(items)`` but with the
    automatic wrapping of long lines. :data:`items` can contain strings as well
    as other instances of :class:`JoinableStringList`. Empty items are skipped
    when converting to a string, and `None` entries are dropped on construction.

    Parameters
    ----------
    items : list of str or :any:`JoinableStringList`
        The list (or tuple) of items that should be joined into a string.
    sep : str
        The separator to be inserted between consecutive items.
    width : int
        The line width after which long lines should be wrapped.
    cont : (str, str) or str
        The line continuation string to be inserted on a line break, optionally
        separated as end-of-line and beginning-of-next-line strings
    separable : bool
        An indicator whether this object can be split up to fill
        lines or should stay as a unit (this is for cosmetic
        purposes only, as too long lines will be wrapped in any case).
    """

    # Single- or double-quoted string literals; these are kept intact as a single
    # chunk when an over-long item has to be split (see `_add_item_to_line`)
    _pattern_quoted_string = re.compile(r'(?:\'.*?\')|(?:".*?")')
    # Possible split points for over-long items: whitespace, newlines, or a closing
    # parenthesis not followed by '%' (presumably to avoid breaking Fortran
    # derived-type member access such as ``a(1)%b`` -- TODO confirm). The capture
    # group makes `re.split` keep the separators in the resulting chunk list.
    _pattern_chunk_separator = re.compile(r'(\s|\)(?!%)|\n)')

    def __init__(self, items, sep, width, cont, separable=True):
        super().__init__()

        assert is_iterable(items)
        assert isinstance(sep, str)
        if isinstance(cont, str):
            # Split e.g. ' &\n & ' into the end-of-line part (' &\n') and the
            # beginning-of-next-line part (' & '); a single-line string only
            # provides the end-of-line part
            cont = cont.splitlines(keepends=True)
            if len(cont) == 1:
                cont += ['']
        assert is_iterable(cont) and len(cont) == 2
        # Reset indentation if we exceed the line length by just having
        # both continuation parts on the same line
        if len(cont[0] + cont[1]) >= width:
            cont = [c.strip(' ') for c in cont]
        assert all(width > len(c) for c in cont)

        # `None` entries are dropped
        self.items = [item for item in items if item is not None]
        self.sep = sep
        self.width = width
        self.cont = cont
        self.separable = separable

    def _add_item_to_line(self, line, item):
        """
        Append the given item to the line.

        :param str line: the line to which the item is appended.
        :param item: the item that is appended.
        :type item: str or `JoinableStringList`

        :return: the updated line and a list of preceding lines that have
                 been wrapped in the process.
        :rtype: (str, list)
        """
        # Let's see if we can fit the current item plus separator
        # onto the line and have enough space left for a line break
        new_line = f'{line!s}{item!s}'
        if len(new_line) + len(self.cont[0]) <= self.width:
            return new_line, []

        # Putting the current item plus separator and potential line break
        # onto the current line exceeds the allowed width: we need to break.
        item_line = f'{self.cont[1]!s}{item!s}'
        item_fits_in_line = len(item_line) + len(self.cont[0]) <= self.width

        # First, let's see if we have a JoinableStringList object that we can split up.
        # However, we'll split this up only if allowed or if the item won't fit
        # on a line
        if (isinstance(item, type(self)) and (item.separable or not item_fits_in_line) and
                len(item.items) > 1):
            # Fill the current line with as many of the item's entries as fit;
            # `new_item` holds the entries that did not fit.
            # NOTE(review): `_to_str` returns None as second value if all entries
            # fit; this appears unreachable here since the item did not fit above
            # (assuming item and self share width/cont) -- TODO confirm
            line_, new_item = item._to_str(line=line, stop_on_continuation=True)
            if len(new_item.items) < len(item.items):
                # If we have been able to put at least one entry from item on the line, we
                # continue recursively:
                new_line, lines = self._add_item_to_line(self.cont[1], new_item)
                return new_line, [line_ + self.cont[0], *lines]

        # Otherwise, let's put it on a new line if the item as a whole fits on the next line
        if item_fits_in_line:
            return item_line, [line + self.cont[0]]

        # The new item does not fit onto a line at all and it is not a JoinableStringList
        # where the first item fits onto a line, or for which we know how to split it:
        # let's try our best by splitting the string
        if isinstance(item, str):
            item_str = item
        elif isinstance(item, type(self)):
            # We simply join up the items here to avoid that any line continuations are introduced
            item_str = item.sep.join(str(i) for i in item.items)
        else:
            item_str = str(item)

        # Split the string into chunks at separator positions, keeping quoted
        # string literals intact; concatenating all chunks yields item_str again
        chunk_list = []
        offset = 0
        for string_match in self._pattern_quoted_string.finditer(item_str):
            if string_match.start() > offset:
                chunk_list += self._pattern_chunk_separator.split(item_str[offset:string_match.start()])
            chunk_list += [string_match[0]]
            offset = string_match.end()
        if offset < len(item_str):
            chunk_list += self._pattern_chunk_separator.split(item_str[offset:])

        # First, add as much as possible to the previous line
        # NOTE(review): if every chunk fits, next_chunk stays 0 and the loop below
        # re-appends all chunks; this looks unreachable because the item did not
        # fit onto the line above -- TODO confirm
        next_chunk = 0
        for idx, chunk in enumerate(chunk_list):
            new_line = line + chunk
            if len(new_line) + len(self.cont[0]) > self.width:
                next_chunk = idx
                break
            line = new_line

        # Now put the rest on new lines
        lines = []
        if line != self.cont[1]:
            lines += [line + self.cont[0]]
            line = self.cont[1]
        for chunk in chunk_list[next_chunk:]:
            new_line = line + chunk
            # Break before a chunk that does not fit, unless the line holds only the
            # continuation prefix: then an over-long chunk is placed anyway
            if len(new_line) + len(self.cont[0]) > self.width and line != self.cont[1]:
                lines += [line + self.cont[0]]
                line = self.cont[1] + chunk
            else:
                line = new_line

        return line, lines

    def _to_str(self, line='', stop_on_continuation=False):
        """
        Join all items into a long string using the given separator and wrap lines if
        necessary.

        :param str line: the line this should be appended to.
        :param bool stop_on_continuation: if True, only items up to the line width are
            appended

        :return: the joined string and a `JoinableStringList` object with the remaining
                 items, if any, or None.
        :rtype: (str, JoinableStringList or NoneType)
        """
        if not self.items:
            return '', None
        lines = []
        # Add all items one after another
        for idx, item in enumerate(self.items):
            if str(item) == '':
                # Skip empty items
                continue
            # The separator follows every item except the last one
            sep = self.sep if idx + 1 < len(self.items) else ''
            old_line = line
            line, _lines = self._add_item_to_line(line, item + sep)
            if stop_on_continuation and _lines:
                # A line break would be needed: return the line before this item
                # together with the items that have not been placed yet
                return old_line, type(self)(self.items[idx:], sep=self.sep, width=self.width,
                                            cont=self.cont, separable=self.separable)
            lines += _lines
        return ''.join([*lines, line]), None

    def __add__(self, other):
        """
        Concatenate this object and a string or another :py:class:`JoinableStringList`.

        A string is appended to the last item of a copy of this object; another
        :py:class:`JoinableStringList` is combined into a new, non-separable object.

        :param other: the object to append.
        :type other: str or JoinableStringList
        """
        if isinstance(other, type(self)):
            return type(self)([self, other], sep='', width=self.width, cont=self.cont,
                              separable=False)
        if isinstance(other, str):
            obj = deepcopy(self)
            if obj.items:
                obj.items[-1] += other
            else:
                obj.items = [other]
            return obj
        raise TypeError('Concatenation only for strings or items of same type.')

    def __radd__(self, other):
        """
        Concatenate a string and this object.

        The string is prepended to the first item of a copy of this object.

        :param other: the str to prepend.
        :type other: str
        """
        if isinstance(other, str):
            obj = deepcopy(self)
            if obj.items:
                obj.items[0] = other + obj.items[0]
            else:
                obj.items = [other]
            return obj
        raise TypeError('Concatenation only for strings.')

    def __str__(self):
        """
        Convert to a string.
        """
        return self._to_str()[0]
loki-ecmwf-0.3.6/loki/expression/0000775000175000017500000000000015167130205017076 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/expression/mixins.py0000664000175000017500000000401415167130205020756 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.config import config
from loki.expression.mappers import ExpressionRetriever


__all__ = ['loki_make_stringifier', 'StrCompareMixin']


def loki_make_stringifier(self, originating_stringifier=None):  # pylint: disable=unused-argument
    """
    Create a new :any:`LokiStringifyMapper` for rendering :data:`self` as a
    human-readable string.

    Shared implementation of the :meth:`make_stringifier` method of Pymbolic
    expression nodes. Neither :data:`self` nor :data:`originating_stringifier`
    is used; a fresh mapper instance is returned on every call.
    """
    # NOTE(review): loki.expression.mappers is already imported at module level,
    # so this local import may no longer be necessary.
    from loki.expression.mappers import LokiStringifyMapper  # pylint: disable=import-outside-toplevel
    stringifier = LokiStringifyMapper()
    return stringifier


class StrCompareMixin:
    """
    Mixin to enable comparing expressions to strings.

    The purpose of the string comparison override is to reliably and flexibly
    identify expression symbols from equivalent strings. Comparison is based on
    a canonical string form (see :meth:`_canonical`), so an expression whose
    string form is ``a + b`` compares equal to ``'A+B'``, unless the
    ``case-sensitive`` config option is set.
    """

    @staticmethod
    def _canonical(s):
        """
        Return the canonical string representation of :data:`s`.

        All space characters are removed from ``str(s)``; the result is
        additionally lower-cased unless ``config['case-sensitive']`` is set.
        """
        if config['case-sensitive']:
            return str(s).replace(' ', '')
        return str(s).lower().replace(' ', '')

    def __hash__(self):
        # Hash the canonical form so that expressions comparing equal hash equally.
        # NOTE: a plain ``str`` hashes its raw value, so dict/set lookups by string
        # only match when the string is already in canonical form.
        return hash(self._canonical(self))

    def __eq__(self, other):
        if isinstance(other, (str, type(self))):
            # Do comparison based on canonical string representations
            return self._canonical(self) == self._canonical(other)

        return super().__eq__(other)

    def __contains__(self, other):
        # Assess containment via a retriever with node-wise string comparison:
        # true if any node retrieved from ``self`` compares equal to ``other``
        return len(ExpressionRetriever(lambda x: x == other).retrieve(self)) > 0

    make_stringifier = loki_make_stringifier
loki-ecmwf-0.3.6/loki/expression/__init__.py0000664000175000017500000000153515167130205021213 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Expression layer of the two-level Loki IR based on `Pymbolic
<https://github.com/inducer/pymbolic>`_.
"""

from loki.expression.evaluation import *  # noqa
from loki.expression.literals import *  # noqa
from loki.expression.mappers import *  # noqa
from loki.expression.mixins import *  # noqa
from loki.expression.operations import *  # noqa
from loki.expression.parser import *  # noqa
from loki.expression.symbolic import *  # noqa
from loki.expression.symbols import *  # noqa
loki-ecmwf-0.3.6/loki/expression/tests/0000775000175000017500000000000015167130205020240 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/expression/tests/__init__.py0000664000175000017500000000057015167130205022353 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/expression/tests/test_mapper.py0000664000175000017500000001053415167130205023140 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Scope
from loki.expression import symbols as sym, parse_expr
from loki.expression.mappers import (
    ExpressionRetriever, LokiIdentityMapper, SubstituteExpressionsMapper
)
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_retriever(frontend):
    """
    Test for :any:`ExpressionRetriever` (a :any:`LokiWalkMapper`), both on an
    expression taken from a parsed routine and on the same expression built
    with :any:`parse_expr` without any declarations.
    """

    fcode = """
subroutine test_expr_retriever(n, a, b, c)
  integer, intent(inout) :: n, a, b(n), c

  a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assignment = FindNodes(ir.Assignment).visit(routine.body)[0]

    def retrieve_of_type(node_type, expression):
        """Retrieve all nodes of ``expression`` that are instances of ``node_type``"""
        return ExpressionRetriever(lambda node: isinstance(node, node_type)).retrieve(expression)

    # With declarations, 'a' and 'c' are scalars and 'b' is an array
    declared = assignment.rhs
    assert retrieve_of_type(sym.TypedSymbol, declared) == ['a', 'b', 'c', 'a']
    assert retrieve_of_type(sym.Array, declared) == ['b(c)']
    assert retrieve_of_type(sym.Scalar, declared) == ['a', 'c', 'a']
    assert retrieve_of_type(sym.IntLiteral, declared) == [5, 4]

    scope = Scope()
    parsed = parse_expr('5 * a + 4 * b(c) + a', scope=scope)

    assert retrieve_of_type(sym.TypedSymbol, parsed) == ['a', 'b', 'c', 'a']
    assert retrieve_of_type(sym.Array, parsed) == ['b(c)']
    # Cannot determine Scalar without declarations, so check for deferred
    assert retrieve_of_type(sym.DeferredTypeSymbol, parsed) == ['a', 'c', 'a']
    assert retrieve_of_type(sym.IntLiteral, parsed) == [5, 4]


@pytest.mark.parametrize('frontend', available_frontends())
def test_identity_mapper(frontend):
    """
    Test for :any:`LokiIdentityMapper`, in particular deep-copying
    expression nodes.

    The mapped expression must contain exactly as many symbols and literals
    as the original, each comparing equal to its counterpart while being a
    distinct object.
    """

    fcode = """
subroutine test_expr_retriever(n, a, b, c)
  integer, intent(inout) :: n, a, b(n), c

  a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

    # Run the identity mapper over the expression
    new_expr = LokiIdentityMapper()(expr)

    # Check that symbols and literals are equivalent, but distinct objects!
    get_symbols = ExpressionRetriever(lambda e: isinstance(e, sym.TypedSymbol)).retrieve
    get_literals = ExpressionRetriever(lambda e: isinstance(e, sym.IntLiteral)).retrieve

    for retrieve in (get_symbols, get_literals):
        old_nodes, new_nodes = retrieve(expr), retrieve(new_expr)
        # ``zip`` silently truncates, so make sure no node was lost or added
        # (and that there is something to compare at all)
        assert len(old_nodes) == len(new_nodes) > 0
        for old, new in zip(old_nodes, new_nodes):
            assert old == new
            assert old is not new


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_expression_mapper(frontend):
    """
    Test for :any:`SubstituteExpressionsMapper`.

    Substitutes every occurrence of ``a`` with ``d`` and checks that the
    resulting expression renders correctly, and that the two inserted
    ``d`` symbols compare equal but are distinct objects.
    """

    fcode = """
subroutine test_expr_retriever(n, a, b, c, d)
  integer, intent(inout) :: n, a, b(n), c, d

  a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

    retriever = ExpressionRetriever(lambda e: isinstance(e, sym.TypedSymbol))
    symbols = retriever.retrieve(expr)
    assert symbols == ['a', 'b', 'c', 'a']
    # Both occurrences of 'a' are equal but separate objects in the IR
    assert symbols[0] == symbols[3]
    assert symbols[0] is not symbols[3]
    a = symbols[0]
    d = routine.variable_map['d']

    new_expr = SubstituteExpressionsMapper(expr_map={a: d})(expr)

    assert new_expr == '5*d + 4*b(c) + d'
    new_symbols = retriever.retrieve(new_expr)
    assert new_symbols == ['d', 'b', 'c', 'd']
    assert new_symbols[0] == new_symbols[3]
    # Ensure multiple inserted symbols are still unique
    assert new_symbols[0] is not new_symbols[3]
loki-ecmwf-0.3.6/loki/expression/tests/test_parser.py0000664000175000017500000006727215167130205023163 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

import pymbolic.primitives as pmbl
import pymbolic.mapper as pmbl_mapper

from loki import Subroutine, Module, Scope
from loki.expression import symbols as sym, parse_expr, LokiEvaluationMapper
from loki.frontend import (
    available_frontends, HAVE_FP, parse_fparser_expression
)
from loki.ir import FindVariables
from loki.tools.util import CaseInsensitiveDict

# utility function to test parse_expr with different case
def convert_to_case(_str, mode='upper'):
    """
    Return :data:`_str` converted according to :data:`mode`.

    ``'upper'`` and ``'lower'`` convert the whole string. ``'random'`` produces
    a deterministic mixed-case string: the characters at index 0 and 2 are
    upper-cased, all others lower-cased. Any other mode behaves like ``'upper'``.
    """
    if mode == 'lower':
        return _str.lower()
    if mode == 'random':
        # this is obviously not random, but fulfils its purpose ...
        return ''.join(
            char.upper() if idx in (0, 2) else char.lower()
            for idx, char in enumerate(_str)
        )
    return _str.upper()


@pytest.mark.parametrize('source, ref', [
    ('1 + 1', '1 + 1'),
    ('1+2+3+4', '1 + 2 + 3 + 4'),
    ('5*4 - 3*2 - 1', '5*4 - 3*2 - 1'),
    ('1*(2 + 3)', '1*(2 + 3)'),
    ('5*a +3*7**5 - 4/b', '5*a + 3*7**5 - 4 / b'),
    ('5 + (4 + 3) - (2*1)', '5 + (4 + 3) - (2*1)'),
    ('a*(b*(c+(d+e)))', 'a*(b*(c + (d + e)))'),
])
@pytest.mark.parametrize('parse', (
    parse_expr,
    pytest.param(parse_fparser_expression,
        marks=pytest.mark.skipif(not HAVE_FP, reason='parse_fparser_expression not available!'))
))
def test_parse_expression(parse, source, ref):
    """
    Parse simple arithmetic expressions with each available parser utility
    and check that the result is a Pymbolic expression whose string form
    matches the expected normalised representation (including any explicit
    parentheses from the source).
    """
    scope = Scope()
    parsed = parse(source, scope)
    assert isinstance(parsed, pmbl.Expression)
    assert str(parsed) == ref


@pytest.mark.parametrize('case', ('upper', 'lower', 'random'))
@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_parser(frontend, case, tmp_path):
    fcode = """
subroutine some_routine()
  implicit none
  integer :: i1, i2, i3, len1, len2, len3
  real :: a, b
  real :: arr(len1, len2, len3)
end subroutine some_routine
    """.strip()

    fcode_mod = """
module external_mod
  implicit none
contains
  function my_func(a)
    integer, intent(in) :: a
    integer :: my_func
    my_func = a
  end function my_func
end module external_mod
    """.strip()

    def to_str(_parsed):
        return str(_parsed).lower().replace(' ', '')

    routine = Subroutine.from_source(fcode, frontend=frontend)
    module = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])

    parsed = parse_expr(convert_to_case('a + b', mode=case))
    assert isinstance(parsed, sym.Sum)
    assert all(isinstance(_parsed,  sym.DeferredTypeSymbol) for _parsed in parsed.children)
    assert to_str(parsed) == 'a+b'

    parsed = parse_expr(convert_to_case('a + b', mode=case), scope=routine)
    assert isinstance(parsed, sym.Sum)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in parsed.children)
    assert all(_parsed.scope == routine for _parsed in parsed.children)
    assert to_str(parsed) == 'a+b'

    parsed = parse_expr(convert_to_case('a + b + 2 + 10', mode=case), scope=routine)
    assert isinstance(parsed, sym.Sum)
    assert to_str(parsed) == 'a+b+2+10'

    parsed = parse_expr(convert_to_case('a - b', mode=case), scope=routine)
    assert isinstance(parsed, sym.Sum)
    assert isinstance(parsed.children[0], sym.Scalar)
    assert isinstance(parsed.children[1], sym.Product)
    assert to_str(parsed) == 'a-b'

    parsed = parse_expr(convert_to_case('a * b', mode=case), scope=routine)
    assert isinstance(parsed, sym.Product)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in parsed.children)
    assert all(_parsed.scope == routine for _parsed in parsed.children)
    assert to_str(parsed) == 'a*b'

    parsed = parse_expr(convert_to_case('a / b', mode=case), scope=routine)
    assert isinstance(parsed, sym.Quotient)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.numerator, parsed.denominator])
    assert all(_parsed.scope == routine for _parsed in [parsed.numerator, parsed.denominator])
    assert to_str(parsed) == 'a/b'

    parsed = parse_expr(convert_to_case('a ** b', mode=case), scope=routine)
    assert isinstance(parsed, sym.Power)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.base, parsed.exponent])
    assert all(_parsed.scope == routine for _parsed in [parsed.base, parsed.exponent])
    assert to_str(parsed) == 'a**b'

    parsed = parse_expr(convert_to_case(':', mode=case))
    assert isinstance(parsed, sym.RangeIndex)
    assert to_str(parsed) == ':'

    parsed = parse_expr(convert_to_case('a:b', mode=case), scope=routine)
    assert isinstance(parsed, sym.RangeIndex)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.lower, parsed.upper])
    assert all(_parsed.scope == routine for _parsed in [parsed.lower, parsed.upper])
    assert to_str(parsed) == 'a:b'

    parsed = parse_expr(convert_to_case('a:b:5', mode=case), scope=routine)
    assert isinstance(parsed, sym.RangeIndex)
    assert all(isinstance(_parsed,  (sym.Scalar, sym.IntLiteral))
            for _parsed in [parsed.lower, parsed.upper, parsed.step])
    assert to_str(parsed) == 'a:b:5'

    parsed = parse_expr(convert_to_case('a == b', mode=case), scope=routine)
    assert parsed.operator == '=='
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a==b'
    parsed = parse_expr(convert_to_case('a.eq.b', mode=case), scope=routine)
    assert parsed.operator == '=='
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a==b'

    parsed = parse_expr(convert_to_case('a!=b', mode=case), scope=routine)
    assert parsed.operator == '!='
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a!=b'
    parsed = parse_expr(convert_to_case('a.ne.b', mode=case), scope=routine)
    assert parsed.operator == '!='
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a!=b'

    parsed = parse_expr(convert_to_case('a>b', mode=case), scope=routine)
    assert parsed.operator == '>'
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a>b'
    parsed = parse_expr(convert_to_case('a.gt.b', mode=case), scope=routine)
    assert parsed.operator == '>'
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a>b'

    parsed = parse_expr(convert_to_case('a>=b', mode=case), scope=routine)
    assert parsed.operator == '>='
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a>=b'
    parsed = parse_expr(convert_to_case('a.ge.b', mode=case), scope=routine)
    assert parsed.operator == '>='
    assert isinstance(parsed, sym.Comparison)
    assert all(isinstance(_parsed,  sym.Scalar) for _parsed in [parsed.left, parsed.right])
    assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
    assert to_str(parsed) == 'a>=b'

    parsed = parse_expr(convert_to_case('a 1.) then
      out(i) = 3.
    else
      out(i) = 1.
    end if
  end do
end subroutine logical_array
"""
    filepath = tmp_path/(f'expression_logical_array_{frontend}.f90')
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(routine, filepath=filepath, objname='logical_array')

    out = np.zeros(6)
    function(6, [0., 2., -1., 3., 0., 2.], out)
    assert (out == [1., 1., 1., 3., 1., 3.]).all()
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_array_constructor(tmp_path, frontend):
    """
    Test various array constructor formats: ``[...]`` and ``(/.../)`` syntax,
    implied-do loops, type-spec prefixes and constructors passed to an
    intrinsic. Checks both the parsed :any:`LiteralList` nodes and the values
    produced by the compiled routine.
    """
    fcode = """
subroutine array_constructor(dim, zarr1, zarr2, narr1, narr2, narr3, narr4, narr5)
    implicit none
    integer, intent(in) :: dim
    real(8), intent(inout) :: zarr1(dim+1)
    real(8), intent(inout) :: zarr2(3)
    integer, intent(inout) :: narr1(dim)
    integer, intent(inout) :: narr2(10)
    integer, intent(inout) :: narr3(3)
    integer, intent(inout) :: narr4(2,2)
    integer, intent(inout) :: narr5(10)
    integer :: i

    zarr1 = [ 3.6, (3.6 / I, I = 1, dim) ]
    narr1 = (/ (I, I = 1, DIM) /)
    narr2 = (/1, 0, (I, I = -1, -6, -1), -7, -8 /)
    narr3 = [integer :: 1, 2., 3d0]    ! A default integer array
    zarr2 = [real(8) :: 1, 2, 3._8]  ! A real(8) array
    narr4 = RESHAPE([1,2,3,4], shape=[2,2])
    narr5 = (/(I, I=30, 48, 2)/)
end subroutine array_constructor
    """.strip()

    filepath = tmp_path/f'array_constructor_{frontend}.f90'
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(routine, filepath=filepath, objname='array_constructor')

    # Every constructor, including both inside RESHAPE, is a LiteralList
    literal_lists = [
        expr for expr in FindExpressions().visit(routine.body)
        if isinstance(expr, sym.LiteralList)
    ]
    expected_literal_lists = {
        '[ 3.6, ( 3.6 / i, i = 1:dim ) ]',
        '[ ( i, i = 1:dim ) ]',
        '[ 1, 0, ( i, i = -1:-6:-1 ), -7, -8 ]',
        '[  :: 1, 2., 3d0 ]',
        '[  :: 1, 2, 3._8 ]',
        '[ 1, 2, 3, 4 ]',
        '[ 2, 2 ]',
        '[ ( i, i = 30:48:2 ) ]'
    }
    assert len(literal_lists) == 8
    assert {str(lst).lower() for lst in literal_lists} == expected_literal_lists

    dim = 13
    zarr1 = np.zeros(dim + 1, dtype=np.float64)
    zarr2 = np.zeros(3, dtype=np.float64)
    narr1 = np.zeros(dim, dtype=np.int32)
    narr2 = np.zeros(10, dtype=np.int32)
    narr3 = np.zeros(3, dtype=np.int32)
    narr4 = np.zeros((2, 2), dtype=np.int32, order='F')
    narr5 = np.zeros(10, dtype=np.int32)
    function(dim, zarr1, zarr2, narr1, narr2, narr3, narr4, narr5)

    assert np.isclose(zarr1, [3.6] + [3.6 / i for i in range(1, dim + 1)]).all()
    assert np.isclose(zarr2, [1., 2., 3.]).all()
    assert (narr1 == range(1, dim + 1)).all()
    assert (narr2 == range(1, -9, -1)).all()
    assert (narr3 == [1, 2, 3]).all()
    # RESHAPE fills column-major, hence the transposed-looking reference
    assert (narr4 == np.array([[1, 3], [2, 4]], order='F')).all()
    assert (narr5 == range(30, 49, 2)).all()

    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'Precedence not honoured')]))
def test_parenthesis(frontend):
    """
    Test explicit parentheses in provided source code.

    Note that this test is very niche, as it ensures that mathematically
    insignificant (and hence sort of wrong) bracketing is still honoured.
    The reason is that, if sub-expressions are sufficiently complex,
    this can still cause round-off deviations and hence destroy
    bit-reproducibility.

    Also note that the OMNI-frontend parser will resolve precedence and
    hence we cannot honour these precedence cases (for now).
    """

    fcode = """
subroutine parenthesis(v1, v2, v3, i)
  integer, parameter :: jprb = selected_real_kind(13,300)
  real(kind=jprb), intent(in) :: v1(:), v2
  real(kind=jprb), intent(out) :: v3
  integer, intent(in) :: i

  v3 = (v1(i-1)**1.23_jprb) * 1.3_jprb + (1_jprb - v2**1.26_jprb)

  v3 = min(5._jprb - 3._jprb*v1(i), 3._jprb*exp(5._jprb*(v1(i) - v2) / (v1(i) - v3)) / 2._jprb*exp(5._jprb*(v1(i) - v2) / (v1(i) -  &
  & v3)))

  v3 = v1(i)*(1.0_jprb / (v2*v3))
end subroutine parenthesis
""".strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    stmts = FindNodes(ir.Assignment).visit(routine.body)

    # Check that the redundant bracket around the minus
    # and the first exponential are still there.
    assert fgen(stmts[0]) == 'v3 = (v1(i - 1)**1.23_jprb)*1.3_jprb + (1_jprb - v2**1.26_jprb)'

    # Now perform a simple substitution on the expression
    # and make sure we are still parenthesising as we should!
    v2 = [v for v in FindVariables().visit(stmts[0]) if v.name == 'v2'][0]
    v4 = v2.clone(name='v4')
    stmt2 = SubstituteExpressions({v2: v4}).visit(stmts[0])
    assert fgen(stmt2) == 'v3 = (v1(i - 1)**1.23_jprb)*1.3_jprb + (1_jprb - v4**1.26_jprb)'

    # Make sure there are no additional brackets in the exponentials or numerators/denominators.
    # The reference is taken verbatim from the (continued) source lines of ``fcode``.
    assert '\n'.join(l.lstrip() for l in fcode.splitlines()[-5:-3]) == fgen(stmts[1]).lower()
    assert fgen(stmts[2]) == fcode.splitlines()[-2].lstrip()


@pytest.mark.parametrize('frontend', available_frontends())
def test_commutativity(frontend):
    """
    Check that the order of commutative terms in an expression is kept
    exactly as written, since reordering them can change round-off.
    """
    fcode = """
subroutine commutativity(v1, v2, v3)
  integer, parameter :: jprb = selected_real_kind(13,300)
  real(kind=jprb), pointer, intent(in) :: v1(:), v2
  real(kind=jprb), pointer, intent(out) :: v3(:)

  v3(:) = 1._jprb + v2*v1(:) - v2 - v3(:)
end subroutine commutativity
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assignment = FindNodes(ir.Assignment).visit(routine.body)[0]

    # The real literal may be rendered as either '1.0_jprb' or '1._jprb'
    accepted = {
        'v3(:) = 1.0_jprb + v2*v1(:) - v2 - v3(:)',
        'v3(:) = 1._jprb + v2*v1(:) - v2 - v3(:)',
    }
    assert fgen(assignment) in accepted


@pytest.mark.parametrize('frontend', available_frontends())
def test_index_ranges(frontend):
    """
    Test index range expressions, both in array shape declarations and in
    array accesses in the routine body.
    """
    fcode = """
subroutine index_ranges(dim, v1, v2, v3, v4, v5)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: dim
  real(kind=jprb), intent(in) :: v1(:), v2(0:), v3(0:4), v4(dim)
  real(kind=jprb), intent(out) :: v5(1:dim)

  v5(:) = v2(1:dim)*v1(::2) - v3(0:4:2)
end subroutine index_ranges
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Declared shapes
    declared = routine.variable_map
    assert str(declared['v1']) == 'v1(:)'
    assert str(declared['v2']) == 'v2(0:)'
    assert str(declared['v3']) == 'v3(0:4)'
    # OMNI will insert implicit lower=1 into shape declarations,
    # we simply have to live with it... :(
    assert str(declared['v4']) in ('v4(dim)', 'v4(1:dim)')
    assert str(declared['v5']) in ('v5(1:dim)', 'v5(dim)')

    # Accesses in the body
    used = {var.name: var for var in FindVariables().visit(routine.body)}
    assert str(used['v1']) == 'v1(::2)'
    assert str(used['v2']) in ('v2(dim)', 'v2(1:dim)')
    assert str(used['v3']) == 'v3(0:4:2)'
    assert str(used['v5']) == 'v5(:)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_strings(tmp_path, frontend, capsys):
    """
    Test recognition of literal strings by compiling a routine that prints
    a single-quoted and a double-quoted string and checking its output.
    """

    # This test works only if stdout/stderr is not captured by pytest
    if stdchannel_is_captured(capsys):
        pytest.skip('pytest executed without "--show-capture"/"-s"')

    fcode = """
subroutine strings()
  print *, 'Hello world!'
  print *, "42!"
end subroutine strings
"""
    filepath = tmp_path/f'expression_strings_{frontend}.f90'
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(routine, filepath=filepath, objname='strings')

    # Redirect the compiled routine's stdout into a log file
    output_file = tmp_path/filehash(str(filepath), prefix='', suffix='.log')
    with capsys.disabled(), stdchannel_redirected(sys.stdout, output_file):
        function()

    # Each line is expected to start with a blank, as emitted by ``print *``
    assert output_file.read_text() == ' Hello world!\n 42!\n'


@pytest.mark.parametrize('frontend', available_frontends())
def test_very_long_statement(tmp_path, frontend):
    """
    Test a long statement split over several lines with continuation
    characters, by compiling it and checking the computed result.
    """
    fcode = """
subroutine very_long_statement(scalar, res)
  integer, intent(in) :: scalar
  integer, intent(out) :: res

  res = 5 * scalar + scalar - scalar + scalar - scalar + (scalar - scalar &
      & + scalar - scalar) - 1 + 2 - 3 + 4 - 5 + 6 - 7 + 8 - (9 + 10      &
        - 9) + 10 - 8 + 7 - 6 + 5 - 4 + 3 - 2 + 1
end subroutine very_long_statement
"""
    filepath = tmp_path/f'expression_very_long_statement_{frontend}.f90'
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(routine, filepath=filepath, objname='very_long_statement')

    # For scalar = 1 the whole expression evaluates to 5
    assert function(1) == 5
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_output_intrinsics(frontend):
    """
    Some collected intrinsics or other edge cases that failed in cloudsc:
    a labelled ``format`` statement and a ``write`` using it.
    """
    fcode = """
subroutine output_intrinsics
     integer, parameter :: jprb = selected_real_kind(13,300)
     integer :: numomp, ngptot
     real(kind=jprb) :: tdiff

     numomp = 1
     ngptot = 2
     tdiff = 1.2

1002 format(1x, 2i10, 1x, i4, ' : ', i10)
     write(0, 1002) numomp, ngptot, - 1, int(tdiff * 1000.0_jprb)
end subroutine output_intrinsics
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    format_ref = "format(1x, 2i10, 1x, i4, ' : ', i10)"
    write_ref = 'write(0, 1002) numomp, ngptot, - 1, int(tdiff * 1000.0_jprb)'

    if frontend == OMNI:
        # OMNI normalises quotes, adds keyword arguments and drops spaces
        format_ref = format_ref.replace("'", '"')
        write_ref = (
            write_ref.replace('0, 1002', 'unit=0, fmt=1002')
            .replace(' * ', '*')
            .replace('- 1', '-1')
        )

    intrinsics = FindNodes(ir.Intrinsic).visit(routine.body)
    assert len(intrinsics) == 2
    assert intrinsics[0].text.lower() == format_ref
    assert intrinsics[1].text.lower() == write_ref
    assert fgen(intrinsics).lower() == f'1002 {format_ref}\n{write_ref}'


@pytest.mark.parametrize('frontend', available_frontends())
def test_nested_call_inline_call(tmp_path, frontend):
    """
    The purpose of this test is to highlight the differences between calls in expression
    (such as `InlineCall`, `Cast`) and call nodes in the IR: inline calls and casts
    appear as arguments of subroutine calls, and the compiled result is checked.
    """
    fcode = """
subroutine simple_expr(v1, v2, v3, v4, v5, v6)
  ! simple floating point arithmetic
  integer, parameter :: jprb = selected_real_kind(13,300)
  real(kind=jprb), intent(in) :: v1, v2, v3, v4
  real(kind=jprb), intent(out) :: v5, v6

  v5 = (v1 + v2) * (v3 - v4)
  v6 = (v1 ** v2) - (v3 / v4)
end subroutine simple_expr

subroutine very_long_statement(scalar, res)
  integer, intent(in) :: scalar
  integer, intent(out) :: res

  res = 5 * scalar + scalar - scalar + scalar - scalar + (scalar - scalar &
        + scalar - scalar) - 1 + 2 - 3 + 4 - 5 + 6 - 7 + 8 - (9 + 10      &
        - 9) + 10 - 8 + 7 - 6 + 5 - 4 + 3 - 2 + 1
end subroutine very_long_statement

subroutine nested_call_inline_call(v1, v2, v3)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: v1
  real(kind=jprb), intent(out) :: v2
  integer, intent(out) :: v3
  real(kind=jprb) :: tmp1, tmp2

  tmp1 = real(1, kind=jprb)
  call simple_expr(tmp1, abs(-2.0_jprb), 3.0_jprb, real(v1, jprb), v2, tmp2)
  v2 = abs(tmp2 - v2)
  call very_long_statement(int(v2), v3)
end subroutine nested_call_inline_call
"""
    filepath = tmp_path/f'expression_nested_call_inline_call_{frontend}.f90'
    source = Sourcefile.from_source(fcode, frontend=frontend)
    function = jit_compile(source, filepath=filepath, objname='nested_call_inline_call')

    # For v1 = 1: simple_expr yields v5 = 6 and v6 = -2, so v2 = |-2 - 6| = 8,
    # and very_long_statement(8) = 40
    v2, v3 = function(1)
    assert v2 == 8.
    assert v3 == 40
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_no_arg_inline_call(frontend, tmp_path):
    """
    Make sure that no-argument function calls are recognized as such,
    both when the function's implementation is unknown (deferred type)
    and when the defining module is provided as a definition.
    """
    fcode_mod = """
module external_mod
  implicit none
contains
  function my_func()
    integer :: my_func
    my_func = 2
  end function my_func
end module external_mod
    """.strip()

    fcode_routine = """
subroutine my_routine(var)
  use external_mod, only: my_func
  implicit none
  integer, intent(out) :: var
  var = my_func()
end subroutine my_routine
    """

    def assert_inline_call(parsed_routine, function_type):
        """Check that ``var = my_func()`` is an inline call to a ``function_type`` symbol"""
        assignment = FindNodes(ir.Assignment).visit(parsed_routine.body)[0]
        assert assignment.lhs == 'var'
        assert isinstance(assignment.rhs, sym.InlineCall)
        assert isinstance(assignment.rhs.function, function_type)

    # Without module definitions; skipped for OMNI (presumably because it
    # cannot process the ``use`` statement without the module's xmod file)
    if frontend != OMNI:
        undefined_routine = Subroutine.from_source(fcode_routine, frontend=frontend)
        assert undefined_routine.symbol_attrs['my_func'].dtype is BasicType.DEFERRED
        assert_inline_call(undefined_routine, sym.DeferredTypeSymbol)

    # With the module attached as definition, the procedure type is known
    module = Module.from_source(fcode_mod, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode_routine, frontend=frontend, definitions=module, xmods=[tmp_path])
    assert isinstance(routine.symbol_attrs['my_func'].dtype, ProcedureType)
    assert_inline_call(routine, sym.ProcedureSymbol)


@pytest.mark.parametrize('frontend', available_frontends())
def test_kwargs_inline_call(frontend, tmp_path):
    """
    Test inline call with kwargs and correct sorting as well
    as correct conversion to args. After each transformation the
    routine is compiled and must still produce the same result.
    """
    fcode_routine = """
subroutine my_kwargs_routine(var, v_a, v_b, v_c, v_d)
  implicit none
  integer, intent(out) :: var
  integer, intent(in) :: v_a, v_b, v_c, v_d
  var = my_kwargs_func(c=v_c, b=v_b, a=v_a, d=v_d)
contains
  function my_kwargs_func(a, b, c, d)
    integer, intent(in) :: a, b, c, d
    integer :: my_kwargs_func
    my_kwargs_func = a - b - c - d
  end function my_kwargs_func
end subroutine my_kwargs_routine
    """
    routine = Subroutine.from_source(fcode_routine, frontend=frontend, xmods=[tmp_path])

    def compile_and_run(prefix):
        """Compile the current state of ``routine`` and evaluate it for fixed inputs"""
        filepath = tmp_path/f'{prefix}_expression_kwargs_call_{frontend}.f90'
        function = jit_compile(routine, filepath=filepath, objname='my_kwargs_routine')
        return function(100, 10, 5, 2)

    # Test the original implementation: 100 - 10 - 5 - 2
    assert compile_and_run('orig') == 83

    # Sort the kwargs and test the transformed code
    call = list(FindInlineCalls().visit(routine.body))[0]
    routine.body = SubstituteExpressions({call: call.clone_with_sorted_kwargs()}).visit(routine.body)
    call = list(FindInlineCalls().visit(routine.body))[0]
    assert call.is_kwargs_order_correct()
    assert not call.arguments
    assert call.kwarguments == (('a', 'v_a'), ('b', 'v_b'), ('c', 'v_c'), ('d', 'v_d'))
    assert compile_and_run('sorted') == 83

    # Convert kwargs to args and test the transformed code
    routine.body = SubstituteExpressions({call: call.clone_with_kwargs_as_args()}).visit(routine.body)
    call = list(FindInlineCalls().visit(routine.body))[0]
    assert not call.kwarguments
    assert compile_and_run('converted') == 83


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_call_derived_type_arguments(frontend, tmp_path):
    """
    Check that derived type arguments are correctly represented in
    function calls that include keyword parameters.

    This is due to fparser's habit of sometimes representing function calls
    wrongly as structure constructors, which are handled differently in
    Loki's frontend.

    The fixture's ``check`` compares against the local ``eff_thr`` rather than
    the optional dummy argument ``thr``: ``check`` is called without ``thr``
    below, and referencing an absent optional argument is not valid Fortran.
    """
    fcode = """
module inline_call_mod
    implicit none

    type mytype
        integer :: val
        integer :: arr(3)
    contains
        procedure :: some_func
    end type mytype

contains

    function check(val, thr) result(is_bad)
        integer, intent(in) :: val
        integer, intent(in), optional :: thr
        integer :: eff_thr
        logical :: is_bad
        if (present(thr)) then
            eff_thr = thr
        else
            eff_thr = 10
        end if
        is_bad = val > eff_thr
    end function check

    function some_func(this) result(is_bad)
        class(mytype), intent(in) :: this
        logical :: is_bad

        is_bad = check(this%val, thr=10) &
            &   .or. check(this%arr(1)) .or. check(val=this%arr(2)) .or. check(this%arr(3))
    end function some_func
end module inline_call_mod
    """.strip()
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    some_func = module['some_func']
    # Only the body of ``some_func`` is searched, so ``present(thr)`` in ``check`` is not counted
    inline_calls = FindInlineCalls().visit(some_func.body)
    assert len(inline_calls) == 4
    assert {fgen(c) for c in inline_calls} == {
        'check(this%val, thr=10)', 'check(this%arr(1))', 'check(val=this%arr(2))', 'check(this%arr(3))'
    }


@pytest.mark.parametrize('frontend', available_frontends())
def test_character_concat(tmp_path, frontend):
    """
    Character concatenation via the ``//`` operator.

    The routine assembles ``'Hello world!'`` piecewise from literals and
    ``trim`` results; it is regenerated, compiled and executed, and the
    returned string is compared byte-wise.
    """
    fcode = """
subroutine character_concat(string)
  character(10) :: tmp_str1, tmp_str2
  character(len=12), intent(out) :: string

  tmp_str1 = "Hel" // "lo"
  tmp_str2 = "wor" // "l" // "d"
  string = trim(tmp_str1) // " " // trim(tmp_str2)
  string = trim(string) // "!"
end subroutine character_concat
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'expression_character_concat_{frontend}.f90'
    compiled = jit_compile(routine, filepath=filepath, objname='character_concat')

    assert compiled() == b'Hello world!'
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_masked_statements(tmp_path, frontend):
    """
    Masked array assignments (``WHERE(...) ... [ELSEWHERE ...] ENDWHERE``).

    Covers a plain block, a chained ``ELSEWHERE (...)`` / ``ELSEWHERE`` block
    and the single-line ``WHERE`` statement. The routine is compiled and run,
    and the in-place modified arrays are compared to hand-built references.
    """
    fcode = """
subroutine expression_masked_statements(length, vec1, vec2, vec3)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: length
  real(kind=jprb), intent(inout), dimension(length) :: vec1, vec2, vec3

  where (vec1(:) > 5.0_jprb)
    vec1(:) = 7.0_jprb
    vec1(:) = 5.0_jprb
  endwhere

  where (vec2(:) < -0.d1)
    vec2(:) = -1.0_jprb
  elsewhere (vec2(:) > 0.d1)
    vec2(:) = 1.0_jprb
  elsewhere
    vec2(:) = 0.0_jprb
  endwhere

  where (0.0_jprb < vec3(:) .and. vec3(:) < 3.0_jprb) vec3(:) = 1.0_jprb
end subroutine expression_masked_statements
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'{routine.name}_{frontend}.f90'
    compiled = jit_compile(routine, filepath=filepath, objname=routine.name)

    length = 11
    # Expected results, assembled segment by segment
    expected1 = np.concatenate((np.arange(6, dtype=np.float64), np.full(length - 6, 5.0)))
    expected2 = np.concatenate((np.full(5, -1.0), [0.0], np.ones(5)))
    expected3 = np.concatenate((
        np.arange(-2, 1, dtype=np.float64),
        np.ones(2),
        np.arange(3, length - 2, dtype=np.float64),
    ))

    vec1 = np.arange(0, length, dtype=np.float64)
    vec2 = np.arange(-5, length - 5, dtype=np.float64)
    vec3 = np.arange(-2, length - 2, dtype=np.float64)
    compiled(length, vec1, vec2, vec3)
    for expected, actual in ((expected1, vec1), (expected2, vec2), (expected3, vec3)):
        assert np.all(expected == actual)
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_masked_statements_nested(tmp_path, frontend):
    """
    ``WHERE`` constructs nested inside both branches of an outer
    ``WHERE ... ELSEWHERE`` block, checked by executing the compiled routine.
    """
    fcode = """
subroutine expression_nested_masked_statements(length, vec1)
    integer, parameter :: jprb = selected_real_kind(13,300)
    integer, intent(in) :: length
    real(kind=jprb), intent(inout), dimension(length) :: vec1

    where (vec1(:) >= 4.0_jprb)
        where (vec1(:) > 6.0_jprb)
            vec1(:) = 6.0_jprb
        elsewhere
            vec1(:) = 4.0_jprb
        endwhere
    elsewhere
        where (vec1(:) < 2.0_jprb)
            vec1(:) = 0.0_jprb
        elsewhere
            vec1(:) = 2.0_jprb
        endwhere
    endwhere
end subroutine expression_nested_masked_statements
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'{routine.name}_{frontend}.f90'
    compiled = jit_compile(routine, filepath=filepath, objname=routine.name)

    length = 11
    vec1 = np.arange(0, length, dtype=np.float64)
    # Build the expected result from the input before the call modifies it in place;
    # the first matching condition wins, mirroring the nested masks
    expected = np.select(
        [vec1 > 6.0, vec1 >= 4.0, vec1 >= 2.0],
        [6.0, 4.0, 2.0],
        default=0.0,
    )
    compiled(length, vec1)
    assert np.all(expected == vec1)
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_pointer_nullify(tmp_path, frontend):
    """
    POINTERS and their nullification via '=> NULL()'

    Checks the pointer attribute and ``NULL()`` initializer of every declared
    variable, the ``NULLIFY`` statement, and the number of pointer assignments.
    The regenerated routine is then compiled and run to confirm it is valid Fortran.
    """
    fcode = """
subroutine pointer_nullify()
  implicit none
  character(len=64), dimension(:), pointer :: charp => NULL()
  character(len=64), pointer :: pp => NULL()
  allocate(charp(3))
  charp(:) = "_ptr_"
  pp => charp(1)
  pp = "_other_ptr_"
  nullify(pp)
  deallocate(charp)
  charp => NULL()
end subroutine pointer_nullify
"""
    filepath = tmp_path/(f'expression_pointer_nullify_{frontend}.f90')
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Use the builtin ``all``: ``np.all`` does not iterate a generator. It tests the
    # generator object itself, which is always truthy, so the checks would be vacuous.
    assert all(v.type.pointer for v in routine.variables)
    assert all(
        isinstance(v.type.initial, sym.InlineCall) and v.type.initial.name.lower() == 'null'
        for v in routine.variables
    )
    nullify_stmts = FindNodes(ir.Nullify).visit(routine.body)
    assert len(nullify_stmts) == 1
    assert nullify_stmts[0].variables[0].name == 'pp'
    # ``pp => charp(1)`` and ``charp => NULL()`` are the two pointer assignments
    assert [stmt.ptr for stmt in FindNodes(ir.Assignment).visit(routine.body)].count(True) == 2

    # Execute the generated identity (to verify it is valid Fortran)
    function = jit_compile(routine, filepath=filepath, objname='pointer_nullify')
    function()
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_parameter_stmt(tmp_path, frontend):
    """
    A constant defined through a separate ``PARAMETER(...)`` statement,
    rather than the ``parameter`` attribute, is read back correctly after
    compiling and running the regenerated routine.
    """
    fcode = """
subroutine parameter_stmt(out1)
  implicit none
  integer, parameter :: jprb = selected_real_kind(13,300)
  real(kind=jprb) :: param
  parameter(param=2.0)
  real(kind=jprb), intent(out) :: out1

  out1 = param
end subroutine parameter_stmt
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path/f'expression_parameter_stmt_{frontend}.f90'
    compiled = jit_compile(routine, filepath=filepath, objname='parameter_stmt')

    assert compiled() == 2.0
    clean_test(filepath)


def test_string_compare():
    """
    Test that we can identify symbols and expressions by equivalent strings.

    Note that this only captures comparison of a canonical string representation,
    not full symbolic equivalence.
    """
    # Utility objects for manual expression creation
    scope = Scope()
    type_int = SymbolAttributes(dtype=BasicType.INTEGER)
    type_real = SymbolAttributes(dtype=BasicType.REAL)

    i = sym.Variable(name='i', scope=scope, type=type_int)
    j = sym.Variable(name='j', scope=scope, type=type_int)

    # Test a scalar variable. The negative checks use ``not any`` so that every
    # mismatching string is verified; ``not all`` would pass if just one differed.
    u = sym.Variable(name='u', scope=scope, type=type_real)
    assert all(u == exp for exp in ['u', 'U', 'u ', 'U '])
    assert not any(u == exp for exp in ['u()', '_u', 'U()', '_U'])

    # Test an array variable
    v = sym.Variable(name='v', dimensions=(i, j), scope=scope, type=type_real)
    assert all(v == exp for exp in ['v(i,j)', 'v(i, j)', 'v (i , j)', 'V(i,j)', 'V(I, J)'])
    assert not any(v == exp for exp in ['v(i,j())', 'v(i,_j)', '_V(i,j)'])

    # Test a standard array dimension range
    r = sym.RangeIndex(children=(i, j))
    w = sym.Variable(name='w', dimensions=(r,), scope=scope, type=type_real)
    assert all(w == exp for exp in ['w(i:j)', 'w (i : j)', 'W(i:J)', ' w( I:j)'])

    # Test simple arithmetic expressions
    assert all(sym.Sum((i, u)) == exp for exp in ['i+u', 'i + u', 'i +  U', ' I + u'])
    assert all(sym.Product((i, u)) == exp for exp in ['i*u', 'i * u', 'i *  U', ' I * u'])
    assert all(sym.Quotient(i, u) == exp for exp in ['i/u', 'i / u', 'i /  U', ' I / u'])
    assert all(sym.Power(i, u) == exp for exp in ['i**u', 'i ** u', 'i **  U', ' I ** u'])
    assert all(sym.Comparison(i, '==', u) == exp for exp in ['i==u', 'i == u', 'i ==  U', ' I == u'])
    assert all(sym.LogicalAnd((i, u)) == exp for exp in ['i AND u', 'i and u', 'i and  U', ' I and u'])
    assert all(sym.LogicalOr((i, u)) == exp for exp in ['i OR u', 'i or u', 'i or  U', ' I oR u'])
    assert all(sym.LogicalNot(u) == exp for exp in ['not u', ' nOt u', 'not  U', ' noT u'])

    # Test literal behaviour
    assert sym.Literal(41) == 41
    assert sym.Literal(41) == '41'
    assert sym.Literal(41) != sym.Literal(41, kind='jpim')
    assert sym.Literal(66.6) == 66.6
    assert sym.Literal(66.6) == '66.6'
    assert sym.Literal(66.6) != sym.Literal(66.6, kind='jprb')
    assert sym.Literal('u') == 'u'
    assert sym.Literal('u') != 'U'
    assert sym.Literal('u') != u  # The ``Variable(name='u', ...)`` from above
    assert sym.Literal('.TrUe.') == 'true'
    # Specific test for constructor checks
    assert sym.LogicLiteral(value=True) == 'true'


@pytest.mark.parametrize('expr, string, ref', [
    ('a + 1', 'a', True),
    ('u(a)', 'a', True),
    ('u(a + 1)', 'a', True),
    ('u(a + 1) + 2', 'u(a + 1)', True),
    ('ansatz(a + 1)', 'a', True),
    ('ansatz(b + 1)', 'a', False),  # Ensure no false positives
])
@pytest.mark.parametrize('parse', (
    parse_expr,
    pytest.param(parse_fparser_expression,
        marks=pytest.mark.skipif(not HAVE_FP, reason='parse_fparser_expression not available!'))
))
def test_subexpression_match(parse, expr, string, ref):
    """
    The ``in`` operator on a parsed expression finds a symbol or
    sub-expression given as a string via canonical string matching, and
    reports ``False`` when it is absent.
    """
    scope = Scope()
    parsed = parse(expr, scope)
    found = string in parsed
    assert found == ref


@pytest.mark.parametrize('kwargs,reftype', [
    ({}, sym.DeferredTypeSymbol),
    ({'type': SymbolAttributes(BasicType.DEFERRED)}, sym.DeferredTypeSymbol),
    ({'type': SymbolAttributes(BasicType.INTEGER)}, sym.Scalar),
    ({'type': SymbolAttributes(BasicType.REAL)}, sym.Scalar),
    ({'type': SymbolAttributes(DerivedType('t'))}, sym.Scalar),
    ({'type': SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(3),))}, sym.Array),
    ({'type': SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(3),)),
      'dimensions': (sym.Literal(1),)}, sym.Array),
    ({'type': SymbolAttributes(BasicType.INTEGER), 'dimensions': (sym.Literal(1),)}, sym.Array),
    ({'type': SymbolAttributes(BasicType.DEFERRED), 'dimensions': (sym.Literal(1),)}, sym.Array),
    ({'type': SymbolAttributes(ProcedureType('routine'))}, sym.ProcedureSymbol),
])
def test_variable_factory(kwargs, reftype):
    """
    The :any:`Variable` factory dispatches to the symbol class that matches
    the given type and dimensions.
    """
    scope = Scope()
    created = sym.Variable(name='var', scope=scope, **kwargs)
    assert isinstance(created, reftype)


def test_variable_factory_invalid():
    """
    Calling the :any:`Variable` factory without any arguments raises ``KeyError``.
    """
    with pytest.raises(KeyError):
        sym.Variable()


@pytest.mark.parametrize('initype,inireftype,newtype,newreftype', [
    # From deferred type to other type
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.REAL), sym.Scalar),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(DerivedType('t')), sym.Scalar),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(ProcedureType('routine')), sym.ProcedureSymbol),
    (None, sym.DeferredTypeSymbol, SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    # From Scalar to other type
    (SymbolAttributes(BasicType.INTEGER), sym.Scalar,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(BasicType.INTEGER), sym.Scalar,
     SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(3),)), sym.Array),
    (SymbolAttributes(BasicType.INTEGER), sym.Scalar,
     SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol),
    # From Array to other type
    (SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array,
     SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    (SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array,
     SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol),
    # From ProcedureSymbol to other type
    (SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol,
     SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    (SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol,
     SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(5),)), sym.Array),
])
def test_variable_rebuild(initype, inireftype, newtype, newreftype):
    """
    Updating a symbol's type in its scope leaves the existing object's class
    untouched. A plain ``clone()`` afterwards rebuilds it into the class that
    matches the new type.
    """
    scope = Scope()
    symbol = sym.Variable(name='var', scope=scope, type=initype)
    assert isinstance(symbol, inireftype)
    assert 'var' in scope.symbol_attrs

    # Swapping the scope-stored type does not morph the existing object...
    scope.symbol_attrs['var'] = newtype
    assert isinstance(symbol, inireftype)

    # ...but rebuilding it picks up the new type and symbol class
    rebuilt = symbol.clone()  # pylint: disable=no-member
    assert isinstance(rebuilt, newreftype)


@pytest.mark.parametrize('initype,inireftype,newtype,newreftype', [
    # From deferred type to other type
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.REAL), sym.Scalar),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(DerivedType('t')), sym.Scalar),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array),
    (SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol,
     SymbolAttributes(ProcedureType('routine')), sym.ProcedureSymbol),
    (None, sym.DeferredTypeSymbol, SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    # From Scalar to other type
    (SymbolAttributes(BasicType.INTEGER), sym.Scalar,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(BasicType.INTEGER), sym.Scalar,
     SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(3),)), sym.Array),
    (SymbolAttributes(BasicType.INTEGER), sym.Scalar,
     SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol),
    # From Array to other type
    (SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array,
     SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    (SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(4),)), sym.Array,
     SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol),
    # From ProcedureSymbol to other type
    (SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol,
     SymbolAttributes(BasicType.DEFERRED), sym.DeferredTypeSymbol),
    (SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol,
     SymbolAttributes(BasicType.INTEGER), sym.Scalar),
    (SymbolAttributes(ProcedureType('foo')), sym.ProcedureSymbol,
     SymbolAttributes(BasicType.INTEGER, shape=(sym.Literal(5),)), sym.Array),
])
def test_variable_clone_class(initype, inireftype, newtype, newreftype):
    """
    ``clone(type=...)`` yields an object of the symbol class that matches
    the newly supplied type.
    """
    scope = Scope()
    original = sym.Variable(name='var', scope=scope, type=initype)
    assert isinstance(original, inireftype)
    assert 'var' in scope.symbol_attrs

    cloned = original.clone(type=newtype)  # pylint: disable=no-member
    assert isinstance(cloned, newreftype)

@pytest.mark.parametrize('initype,newtype,reftype', [
    # Preserve existing type info if type=None is given
    (SymbolAttributes(BasicType.REAL), None, SymbolAttributes(BasicType.REAL)),
    (SymbolAttributes(BasicType.INTEGER), None, SymbolAttributes(BasicType.INTEGER)),
    (SymbolAttributes(BasicType.DEFERRED), None, SymbolAttributes(BasicType.DEFERRED)),
    (SymbolAttributes(BasicType.DEFERRED, intent='in'), None,
     SymbolAttributes(BasicType.DEFERRED, intent='in')),
    # Update from deferred to known type
    (SymbolAttributes(BasicType.DEFERRED), SymbolAttributes(BasicType.INTEGER),
     SymbolAttributes(BasicType.INTEGER)),
    (SymbolAttributes(BasicType.DEFERRED), SymbolAttributes(BasicType.REAL),
     SymbolAttributes(BasicType.REAL)),
    (SymbolAttributes(BasicType.DEFERRED), SymbolAttributes(BasicType.DEFERRED, intent='in'),
     SymbolAttributes(BasicType.DEFERRED, intent='in')),  # Special case: Add attribute only
    # Invalidate type by setting to DEFERRED
    (SymbolAttributes(BasicType.INTEGER), SymbolAttributes(BasicType.DEFERRED),
     SymbolAttributes(BasicType.DEFERRED)),
    (SymbolAttributes(BasicType.REAL), SymbolAttributes(BasicType.DEFERRED),
     SymbolAttributes(BasicType.DEFERRED)),
    (SymbolAttributes(BasicType.DEFERRED, intent='in'), SymbolAttributes(BasicType.DEFERRED),
     SymbolAttributes(BasicType.DEFERRED)),
])
def test_variable_clone_type(initype, newtype, reftype):
    """
    ``clone(type=...)`` keeps the existing type when given ``None``, and
    otherwise adopts (or invalidates to ``DEFERRED``) the supplied type, so
    the resulting type is never ``None``.
    """
    scope = Scope()
    symbol = sym.Variable(name='var', scope=scope, type=initype)
    assert 'var' in scope.symbol_attrs

    cloned = symbol.clone(type=newtype)  # pylint: disable=no-member
    assert cloned.type == reftype


def test_variable_without_scope():
    """
    Variables can be created without a scope, and scopes can later be
    attached and detached via ``clone``/``rescope``.

    While attached, the scope's symbol table is the source of type
    information. While detached, the variable carries a local type that
    does not leak back into the scope.
    """
    # pylint: disable=no-member
    # A plain variable without type or scope is a deferred-type symbol
    var = sym.Variable(name='var')
    assert isinstance(var, sym.DeferredTypeSymbol)
    assert var.type and var.type.dtype is BasicType.DEFERRED

    # Registering a type in a scope does not affect the unattached variable
    scope = Scope()
    scope.symbol_attrs['var'] = SymbolAttributes(BasicType.INTEGER)
    assert isinstance(var, sym.DeferredTypeSymbol)
    assert var.type and var.type.dtype is BasicType.DEFERRED

    # Attaching the scope picks up the scope-stored INTEGER type
    var = var.clone(scope=scope)
    assert var.scope is scope
    assert isinstance(var, sym.Scalar)
    assert var.type.dtype is BasicType.INTEGER

    # Changing the data type via clone updates the scope-stored type as well
    var = var.clone(type=SymbolAttributes(BasicType.REAL))
    assert isinstance(var, sym.Scalar)
    assert var.type.dtype is BasicType.REAL
    assert scope.symbol_attrs['var'].dtype is BasicType.REAL

    def _check_scalar(variable, owner, dtype):
        # From here on the scope must keep recording REAL, whatever the variable carries
        assert variable.scope is owner
        assert isinstance(variable, sym.Scalar)
        assert variable.type.dtype is dtype
        assert scope.symbol_attrs['var'].dtype is BasicType.REAL

    # Detach the scope (type remains)
    var = var.clone(scope=None)
    _check_scalar(var, None, BasicType.REAL)

    # Assign a data type locally, leaving the scope untouched
    var = var.clone(type=SymbolAttributes(BasicType.LOGICAL))
    _check_scalar(var, None, BasicType.LOGICAL)

    # Re-attach the scope without specifying a type
    var = var.clone(scope=scope, type=None)
    _check_scalar(var, scope, BasicType.REAL)

    # Detach the scope and specify a new type in one go
    var = var.clone(scope=None, type=SymbolAttributes(BasicType.LOGICAL))
    _check_scalar(var, None, BasicType.LOGICAL)

    # Rescope (doesn't overwrite scope-stored type with local type)
    _check_scalar(var.rescope(scope), scope, BasicType.REAL)

    # Re-attach the scope (uses scope-stored type over local type)
    var = var.clone(scope=scope)
    _check_scalar(var, scope, BasicType.REAL)


@pytest.mark.parametrize('expr', [
    '1.8 - 3.E-03*ztp1',
    '1.8 - 0.003*ztp1',
    '(a / b) + 3.0_jprb',
    'a / b*3.0_jprb',
    '-5*3 + (-(5*3))',
    '5 + (-1)',
    '5 - 1'
])
@pytest.mark.parametrize('parse', (
    parse_expr,
    pytest.param(parse_fparser_expression,
        marks=pytest.mark.skipif(not HAVE_FP, reason='parse_fparser_expression not available!'))
))
def test_standalone_expr_parenthesis(expr, parse):
    """
    Parsing a standalone expression and regenerating it with :any:`fgen`
    reproduces the input string exactly, i.e. parentheses are preserved as written.
    """
    scope = Scope()
    parsed = parse(expr, scope)
    assert isinstance(parsed, pmbl.Expression)
    assert fgen(parsed) == expr


@pytest.mark.parametrize('parse', (
    parse_expr,
    pytest.param(parse_fparser_expression,
        marks=pytest.mark.skipif(not HAVE_FP, reason='parse_fparser_expression not available!'))
))
def test_array_to_inline_call_rescope(parse):
    """
    An expression parsed as an array subscript becomes an :any:`InlineCall`
    once the scope declares the symbol as a function and scopes are re-attached.

    This is the mechanism used to mop up frontends that misclassify function
    calls as array accesses.
    """
    scope = Scope()
    # Without type information the expression comes back as an array subscript
    expr = parse('FLUX%OUT_OF_PHYSICAL_BOUNDS(KIDIA, KFDIA)', scope=scope)
    assert isinstance(expr, sym.Array)

    # Detach from the scope, then declare the member as an integer function there
    detached = expr.clone(scope=None)
    func_type = ProcedureType(
        'out_of_physical_bounds', is_function=True,
        return_type=SymbolAttributes(BasicType.INTEGER)
    )
    scope.symbol_attrs['flux%out_of_physical_bounds'] = SymbolAttributes(func_type)

    # Re-attaching the scope triggers the rebuild into an inline call
    call = AttachScopesMapper()(detached, scope=scope)
    assert isinstance(call, sym.InlineCall)
    assert call.function.type.dtype is func_type
    assert call.function == 'flux%out_of_physical_bounds'
    assert call.parameters == ('kidia', 'kfdia')


@pytest.mark.parametrize('frontend', available_frontends())
def test_recursive_substitution(frontend):
    """
    :any:`SubstituteExpressions` handles a mapping whose replacement
    contains the key itself (``j -> j + 1``) without recursing into the result.
    """
    fcode = """
subroutine my_routine(var, n)
    real, intent(inout) :: var(:)
    integer, intent(in) :: n
    integer j
    do j=1,n
        var(j) = 1.
    end do
end subroutine my_routine
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _first_assignment():
        return FindNodes(ir.Assignment).visit(routine.body)[0]

    assert _first_assignment().lhs == 'var(j)'

    # Substitute the loop index by ``j + 1``; the key reappears in its own replacement
    loop_var = routine.variable_map['j']
    shifted = sym.Sum((loop_var, sym.Literal(1)))
    expr_map = {loop_var: shifted}
    assert loop_var in FindVariables().visit(list(expr_map.values()))
    routine.body = SubstituteExpressions(expr_map).visit(routine.body)
    assert _first_assignment().lhs == 'var(j + 1)'


def test_nested_derived_type_substitution():
    """
    :any:`SubstituteExpressions` replaces the scalar parent of a derived-type
    member (``ydphy3%n_spband`` -> ``ydml_phy_mf%yrphy3%n_spband``) when the
    member's type stays the same.
    """
    int_type = SymbolAttributes(dtype=BasicType.INTEGER)
    old_parent = sym.Scalar(name='ydphy3')
    new_parent = sym.Scalar(name='yrphy3', parent=sym.Scalar(name='ydml_phy_mf'))
    # Build the member with its own parent instance so matching relies on equality
    member = sym.Scalar(name='n_spband', type=int_type, parent=sym.Scalar(name='ydphy3'))

    substituted = SubstituteExpressions({old_parent: new_parent}).visit(member)
    assert fgen(substituted) == 'ydml_phy_mf%yrphy3%n_spband'


@pytest.mark.parametrize('frontend', available_frontends())
def test_variable_in_declaration_initializer(frontend):
    """
    A parameter whose initializer refers to itself (``ZEXPLIMIT = LOG(HUGE(ZEXPLIMIT))``)
    is processed without infinite recursion. The self-reference is scoped to the
    routine, both after parsing and after another scope attachment.
    """
    fcode = """
subroutine some_routine(var)
implicit none
INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
REAL(KIND=JPRB), PARAMETER :: ZEXPLIMIT = LOG(HUGE(ZEXPLIMIT))
real(kind=jprb), intent(inout) :: var
var = var + ZEXPLIMIT
end subroutine some_routine
    """.strip()

    def _verify(owner):
        # The parameter itself is known to, and scoped by, the routine...
        assert 'zexplimit' in owner.variable_map
        param = owner.variable_map['zexplimit']
        assert param.scope is owner
        # ...and so is its occurrence inside its own initializer
        init = param.type.initial
        assert 'zexplimit' in str(init).lower()
        found = FindVariables().visit(init)
        assert 'zexplimit' in found
        assert found[found.index('zexplimit')].scope is owner

    routine = Subroutine.from_source(fcode, frontend=frontend)
    _verify(routine)
    # A second scope attachment must not change the outcome
    routine.rescope_symbols()
    _verify(routine)


@pytest.mark.parametrize('frontend', available_frontends())
def test_variable_in_dimensions(frontend, tmp_path):
    """
    A derived-type member whose allocation shape refers to the same member
    (``levels(jscale)%data(size(levels(jscale-1)%data, ...))``) is processed
    without infinite recursion, and the ``size`` calls survive as inline calls.
    """
    fcode = """
module some_mod
    implicit none

    type multi_level
        real, allocatable :: data(:, :)
    end type multi_level
contains
    subroutine some_routine(levels, num_levels)
        type(multi_level), intent(inout) :: levels(:)
        integer, intent(in) :: num_levels
        integer jscale

        do jscale = 2,num_levels
            allocate(levels(jscale)%data(size(levels(jscale-1)%data,1), size(levels(jscale-1)%data,2)))
        end do
    end subroutine some_routine
end module some_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['some_routine']
    assert 'levels%data' in routine.symbol_attrs

    allocation = FindNodes(ir.Allocation).visit(routine.body)[0]
    assert len(allocation.variables) == 1
    dims = allocation.variables[0].dimensions
    assert len(dims) == 2
    for axis, dim in enumerate(dims, start=1):
        assert isinstance(dim, sym.InlineCall)
        assert str(dim).lower() == f'size(levels(jscale - 1)%data, {axis})'


def test_expression_container_matching():
    """
    Expressions and their string equivalents match each other as members
    or keys of tuples, lists, sets, dicts and ``defaultdict``s, in both
    directions (expression in string container and vice versa).
    """
    scope = Scope()
    t_real = SymbolAttributes(BasicType.REAL)
    t_int = SymbolAttributes(BasicType.INTEGER)

    i = sym.Variable(name='i', scope=scope, type=t_int)
    a = sym.Variable(name='a', scope=scope, type=t_real)
    b = sym.Variable(name='b', scope=scope, type=t_real, dimensions=(i,))

    def _containers(first, second):
        # Every container flavour under test, holding ``first`` (as key where applicable)
        return (
            (first, second),
            [first, second],
            {first, second},
            {first: second},
            defaultdict(list, ((first, [second]),)),
        )

    # Scalars in containers of expressions
    for container in _containers(a, b):
        assert a in container

    # Scalars in containers of strings
    assert a == 'a'
    for container in _containers('a', 'b(i)'):
        assert a in container

    # Arrays in containers of strings
    assert b == 'b(i)'
    for container in _containers('b(i)', 'a'):
        assert b in container

    # Strings in containers of expressions
    for container in _containers(b, a):
        assert 'b(i)' in container


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_c_de_reference(frontend):
    """
    ``Reference`` and ``Dereference`` wrappers are dropped by the Fortran
    backend but rendered as ``&`` / ``*`` by the C backend. Mappers still
    reach the wrapped variables, checked here by renaming them.
    """
    fcode = """
subroutine some_routine()
implicit none
  integer :: var_reference
  integer :: var_dereference

  var_reference = 1
  var_dereference = 2
end subroutine some_routine
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)

    def _check_backends(ref_name, deref_name):
        # Fortran ignores the (de)reference wrappers entirely...
        f_str = fgen(routine).replace(' ', '')
        assert f'{ref_name}=1' in f_str
        assert f'{deref_name}=2' in f_str
        assert '*' not in f_str
        assert '&' not in f_str
        # ...while C renders them as address-of and dereference operators
        c_str = cgen(routine).replace(' ', '')
        assert f'(&{ref_name})=1' in c_str
        assert f'(*{deref_name})=2' in c_str

    ref_var = routine.variable_map['var_reference']
    deref_var = routine.variable_map['var_dereference']
    routine.body = SubstituteExpressions({
        ref_var: sym.Reference(ref_var),
        deref_var: sym.Dereference(deref_var),
    }).visit(routine.body)
    _check_backends('var_reference', 'var_dereference')

    # Rename the wrapped variables to check that mappers recurse into the wrappers
    old_ref = routine.variable_map['var_reference']
    old_deref = routine.variable_map['var_dereference']
    rename_map = {
        old_ref: old_ref.clone(name='renamed_var_reference'),
        old_deref: old_deref.clone(name='renamed_var_dereference'),
    }
    routine.spec = SubstituteExpressions(rename_map).visit(routine.spec)
    routine.body = SubstituteExpressions(rename_map).visit(routine.body)
    _check_backends('renamed_var_reference', 'renamed_var_dereference')


@pytest.mark.parametrize('expr', [
    'a', 'a%b', 'a%b%c', 'a%b%c%d', 'a%b%c%d%e'
])
def test_typebound_resolution(expr):
    """
    A (possibly nested) derived-type member is resolved from its root
    variable via ``get_derived_type_member`` and keeps the root's scope.
    """
    scope = Scope()
    root_name, separator, member_path = expr.partition('%')
    var = sym.Variable(name=root_name, scope=scope)

    if separator:
        var = var.get_derived_type_member(member_path)  # pylint: disable=no-member

    assert var == expr
    assert var.scope == scope


@pytest.mark.parametrize('frontend', available_frontends())
def test_typebound_resolution_type_info(frontend, tmp_path):
    """
    Check that derived-type members created on demand via
    ``get_derived_type_member`` carry the expected type information, both for
    locally defined derived types and for an imported type whose definition
    is unknown (members of the latter become deferred-type symbols).
    """
    fcode = """
module typebound_resolution_type_info_mod
    use some_mod, only: tt
    implicit none
    type t_a
        logical :: a
    end type t_a

    type t_b
        type(t_a) :: b_a
        integer :: b
    end type t_b

    type t_c
        type(t_b) :: c_b
        real :: c
    end type t_c
contains
    subroutine sub ()
        type(t_c) :: var_c
        type(tt) :: var_tt
    end subroutine sub
end module typebound_resolution_type_info_mod
    """.strip()

    # OMNI needs the imported module to have been processed first
    if frontend == OMNI:
        stub_module = "module some_mod\ntype tt\nend type\nend module"
        Module.from_source(stub_module, frontend=frontend, xmods=[tmp_path])

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    routine = module['sub']
    var_c, var_tt = (routine.variable_map[name] for name in ('var_c', 'var_tt'))
    type_a, type_b = module['t_a'], module['t_b']

    expected_c_members = {
        'c': BasicType.REAL,
        'c_b': type_b.dtype,
        'c_b%b': BasicType.INTEGER,
        'c_b%b_a': type_a.dtype,
        'c_b%b_a%a': BasicType.LOGICAL,
    }
    expected_tt_members = {
        'some': BasicType.DEFERRED,
        'some%member': BasicType.DEFERRED
    }

    # No member symbol may exist in the symbol table before it is requested
    for parent_name, members in (('var_c', expected_c_members), ('var_tt', expected_tt_members)):
        for member in members:
            assert f'{parent_name}%{member}' not in routine.symbol_attrs

    assert 'var_c%c_b%b_a%a' == routine.resolve_typebound_var('var_c%c_b%b_a%a')

    # Members of a known derived type resolve to typed scalars
    for member, dtype in expected_c_members.items():
        member_var = var_c.get_derived_type_member(member)
        assert member_var == f'var_c%{member}'
        assert member_var.scope is routine
        assert isinstance(member_var, sym.Scalar)
        assert member_var.type.dtype == dtype

    # Members of an unknown (imported) type stay deferred
    for member, dtype in expected_tt_members.items():
        member_var = var_tt.get_derived_type_member(member)
        assert member_var == f'var_tt%{member}'
        assert member_var.scope is routine
        assert isinstance(member_var, sym.DeferredTypeSymbol)
        assert member_var.type.dtype == dtype


@pytest.mark.parametrize('frontend', available_frontends(
    skip={OMNI: "OMNI fails on missing module"}
))
def test_stmt_func_heuristic(frontend, tmp_path):
    """
    The Fparser translation uses a heuristic to detect statement function
    declarations. Because shape information was missing, it used to misread
    some assignments as statement functions (reported in #326).
    """
    fcode = """
SUBROUTINE SOME_ROUTINE(YDFIELDS,YDMODEL,YDCST)
USE FIELDS_MOD         , ONLY : FIELDS
USE TYPE_MODEL         , ONLY : MODEL
USE VAR_MOD            , ONLY : ARR, FNAME
IMPLICIT NONE
TYPE(FIELDS)        ,INTENT(INOUT) :: YDFIELDS
TYPE(MODEL)         ,INTENT(IN)    :: YDMODEL
TYPE(TOMCST)        ,INTENT(IN)    :: YDCST
CHARACTER(LEN=20)                  :: CLFILE
REAL                               :: ZALFA
REAL                               :: ZALFAG(3)
REAL                               :: FOEALFA
REAL                               :: PTARE
FOEALFA(PTARE) = MIN(1.0, PTARE)
#include "fcttre.func.h"

ASSOCIATE(YDSURF=>YDFIELDS%YRSURF,RTT=>YDCST%RTT)
ASSOCIATE(SD_VN=>YDSURF%SD_VN,YSD_VN=>YDSURF%YSD_VN, &
 & LEGBRAD=>YDMODEL%YRML_PHY_EC%YREPHY%LEGBRAD)
IF(LEGBRAD)SD_VN(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:)
IF(LEGBRAD)ARR(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:)
CLFILE(1:20)=FNAME
ZALFA=FOEDELTA(RTT)
ZALFAG(1)=FOEDELTA(RTT-1.)
ZALFAG(2)=FOEALFA(RTT)
ZALFAG(3)=FOEALFA(RTT-1.)
END ASSOCIATE
END ASSOCIATE
END SUBROUTINE SOME_ROUTINE
    """.strip()
    routine = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['some_routine']

    assigns = FindNodes(ir.Assignment).visit(routine.body)
    lhs_names = [assign.lhs.name.lower() for assign in assigns]
    assert lhs_names == ['sd_vn', 'arr', 'clfile', 'zalfa', 'zalfag', 'zalfag', 'zalfag']

    # Array-section assignments must remain array assignments
    first_lhs, second_lhs = assigns[0].lhs, assigns[1].lhs
    assert isinstance(first_lhs, sym.Array)
    assert isinstance(second_lhs, sym.Array)
    assert second_lhs.type.imported

    # FOEDELTA's declaration is hidden in the external header, so it cannot be
    # recognised as a statement function
    for idx in (3, 4):
        assert isinstance(assigns[idx].rhs, sym.Array)

    # FOEALFA is declared locally and must be detected as a statement function
    stmt_funcs = FindNodes(ir.StatementFunction).visit(routine.ir)
    assert len(stmt_funcs) == 1
    assert stmt_funcs[0].name.lower() == 'foealfa'
    for idx in (5, 6):
        assert isinstance(assigns[idx].rhs, sym.InlineCall)
loki-ecmwf-0.3.6/loki/expression/tests/test_symbolic.py0000664000175000017500000002740615167130205023503 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
A selection of tests for symbolic computations using expression tree nodes.
"""
import itertools
import operator as op
from math import floor

import pytest
import pymbolic.primitives as pmbl

from loki import Scope, is_dimension_constant, Subroutine
from loki.expression import symbols as sym, simplify, Simplification, symbolic_op, parse_expr
from loki.expression import iteration_number, iteration_index, get_pyrange
from loki.expression.evaluation import LokiEvaluationMapper
from loki.frontend import available_frontends


@pytest.mark.parametrize('a, b, lt, eq', [
    (sym.Literal(1),      sym.Literal(0),   False, False),
    (sym.Literal(0),      sym.Literal(1),   True,  False),
    (sym.Literal(1),      sym.Literal(1),   False, True),
    (sym.Literal(-3),     sym.Literal(-1),  True,  False),
    (sym.Literal(-3),     sym.Literal(3),   True,  False),
    (sym.Literal(-2),     sym.Literal(-4),  False, False),
    (sym.Literal(3.0),    sym.Literal(5.0), True,  False),
    (sym.Literal(7.0),    sym.Literal(2.0), False, False),
    (sym.Literal(4.0),    sym.Literal(4.0), False, True),
    (sym.Literal(3.9999), sym.Literal(4.0), True,  False),
    (sym.Literal(2),      sym.Literal(4.0), None,  False),
    (sym.Literal(5.0),    sym.Literal(8),   None,  False),
    (sym.Literal(3.0),    sym.Literal(3),   None,  False),
    (sym.Literal(3),      sym.Literal(3.0), None,  False),
    (sym.Literal(2),      5,                True,  False),
    (sym.Literal(5),      2,                False, False),
    (sym.Literal(1),      3.1,              None,  False),
    (sym.Literal(4),      2.2,              None,  False),
    (3,                   sym.Literal(4),   True,  False),
    (4,                   sym.Literal(3),   False, False),
    (3.1,                 sym.Literal(1),   None,  False),
    (2.2,                 sym.Literal(4),   None,  False),
    (sym.Literal(9.1),    13,               True,  False),
    (sym.Literal(7.4),    9.1,              True,  False),
    (sym.Literal(8.2),    4,                False, False),
    (sym.Literal(6.5),    3.7,              False, False),
    (13,                  sym.Literal(9.1), False, False),
    (9.1,                 sym.Literal(7.4), False, False),
    (4,                   sym.Literal(8.2), True,  False),
    (3.7,                 sym.Literal(6.5), True,  False),
    (sym.Literal(3.1),    3.1,              False, True),
    (3.1,                 sym.Literal(3.1), False, True),
])
def test_symbolic_literal_comparison(a, b, lt, eq):
    """
    Test correct evaluation of a<b, a<=b, a>b, a>=b and a==b for literals.

    ``lt`` and ``eq`` are the expected results of ``a < b`` and ``a == b``.
    ``lt=None`` marks operand pairs for which every ordering comparison is
    expected to raise :class:`TypeError`.
    """
    if lt is None:
        with pytest.raises(TypeError):
            _ = a < b
        with pytest.raises(TypeError):
            _ = a <= b
        with pytest.raises(TypeError):
            _ = a > b
        with pytest.raises(TypeError):
            _ = a >= b
    else:
        at_most = lt or eq
        assert (a < b) is lt
        assert (a <= b) is at_most
        # Compare against the negated bool explicitly: ``x is not y`` is an
        # identity test that would also pass for any non-bool result
        assert (a > b) is (not at_most)
        assert (a >= b) is (not lt)
    assert (a == b) is eq


@pytest.mark.parametrize('a, _op, b, ref', [
    ('1', op.eq, '2', False),
    ('1', op.ne, '2', True),
    ('1', op.lt, '2', True),
    ('1', op.le, '2', True),
    ('1', op.gt, '2', False),
    ('1', op.ge, '2', False),
    ('a', op.eq, 'a', True),
    ('a', op.ne, 'a', False),
    ('a', op.lt, 'a', False),
    ('a', op.le, 'a', True),
    ('a', op.gt, 'a', False),
    ('a', op.ge, 'a', True),
    ('a', op.eq, 'a+1', False),
    ('a', op.ne, 'a+1', True),
    ('a', op.lt, 'a+1', True),
    ('a', op.le, 'a+1', True),
    ('a', op.gt, 'a+1', False),
    ('a', op.ge, 'a+1', False),
    ('a', op.sub, 'a+1', '-1'),
])
def test_symbolic_op(a, _op, b, ref):
    """
    Check that :any:`symbolic_op` applies a Python operator to two parsed
    expressions. Results that are still expressions are simplified before
    they are compared with the reference.
    """
    scope = Scope()
    lhs, rhs = (parse_expr(text, scope) for text in (a, b))
    result = symbolic_op(lhs, _op, rhs)
    if not isinstance(result, pmbl.Expression):
        assert result == ref
        return
    assert simplify(result) == ref


@pytest.mark.parametrize('source, ref', [
    ('1 + 1', '1 + 1'),
    ('1 + (2 + (3 + (4 + 5) + 6)) + 7', '1 + 2 + 3 + 4 + 5 + 6 + 7'),
    ('1 - (2 + (3 + (4 - 5) - 6)) - 7', '1 - 2 - 3 - 4 + 5 + 6 - 7'),
    ('1 - (-1 - (-1 - (-1 - (-1 - 1) - 1) - 1) - 1) - 1', '1 + 1 - 1 + 1 - 1 - 1 + 1 - 1 + 1 - 1'),
    ('a + (b - (c + d))', 'a + b - c - d'),
    ('5 * (4 + 3 * (2 + 1) )', '5*4 + 5*3*2 + 5*3'),
    ('5 + a * (3 - b * (2 + c) / 7) * 5 - 4', '5 + a*3*5 - a*b*2*5 / 7 - a*b*c*5 / 7 - 4'),
    ('(((0)))', '0'),
    ('0*0', '0'),
    ('1*1', '1'),
    ('(-1)*(-1)', '1'),
    ('1*(1*(1*1))', '1'),
    ('(6 + 4) / 3', '6 / 3 + 4 / 3'),
    ('6 * (5/3) * 2', '6*5*2 / 3'),
    ('(3 + 4) * (5/3) * 2', '3*5*2 / 3 + 4*5*2 / 3'),
    ('a * (b + c/d) * e', 'a*b*e + a*c*e / d'),
])
def test_simplify_flattened(source, ref):
    """
    Check that the ``Flatten`` simplification removes nested parentheses and
    expands products over sums, without evaluating the resulting terms.
    """
    parsed = parse_expr(source, Scope())
    flattened = simplify(parsed, enabled_simplifications=Simplification.Flatten)
    assert str(flattened) == ref


@pytest.mark.parametrize('source, ref', [
    ('1 + 1', '2'),
    ('2 - 1', '1'),
    ('1 - 1', '0'),
    ('0 + 1 - 0 - 1 + 0', '0'),
    ('1 + 1 + 1 + 1', '4'),
    ('1 + 1 - 1 + 1 - 1 + 1', '2'),
    ('(1 + 1) - (1 + 1)', '0'),
    ('5*4', '20'),
    ('-3*7', '-21'),
    ('3*7*0*10', '0'),
    ('1/1', '1'),
    ('0/1', '0'),
    ('4/2', '2'),
    ('-1/1', '-1'),
    ('7/(-1)', '-7'),
    ('10*a/5', '2*a'),
    ('2*(-2)/(-4)', '1'),
    ('(-8)/4', '-2'),
    ('(5 + 3) * a - 8 * a / 2 + a * ((7 - 1) / 3)', '8*a - 4*a + 2*a')
])
def test_simplify_integer_arithmetic(source, ref):
    """
    Check that ``IntegerArithmetic`` folds arithmetic on integer literals,
    including literal coefficients of symbols, without collecting like terms.
    """
    simplified = simplify(
        parse_expr(source, Scope()),
        enabled_simplifications=Simplification.IntegerArithmetic
    )
    assert ref == str(simplified)


@pytest.mark.parametrize('source, ref', [
    ('a + a + a', '3*a'),
    ('2*a + 1*a + a*3', '6*a'),
    ('(a + a)*(b + b)', '2*a*2*b'),
    ('(a + b) + a + b', 'a + b + a + b'),  # We lose the parenthesis but it does not reduce without flattening
    ('a - a', '0'),
    ('(a + a)*(b - b)', '2*a*0'),
    ('3*a + (-2)*a', 'a'),
    ('3*a - 2*a', 'a'),
    ('1*a + 0*a', 'a'),
    ('1*a*b + 0*a*b', '1*a*b + 0*a*b'),  # Note that this does not reduce without flattening
    ('5*5 + 3*3', '34'),
    ('5 + (-1)', '4'),
    ('(5 + 3) * a - 8 * a / 2 + a * ((7 - 1) / 3)', '8*a - 8*a / 2 + 6 / 3*a')
])
def test_simplify_collect_coefficients(source, ref):
    """
    Check that ``CollectCoefficients`` merges the coefficients of identical
    terms within a sum.
    """
    enabled = Simplification.CollectCoefficients
    expr = parse_expr(source, Scope())
    assert str(simplify(expr, enabled_simplifications=enabled)) == ref


@pytest.mark.parametrize('source, ref', [
    ('1 == 1', 'True'),
    ('2 == 1', 'False'),
    ('1 > 1', 'False'),
    ('1 > 0', 'True'),
    ('1 >= 1', 'True'),
    ('5 >= 1', 'True'),
    ('-1 >= 1', 'False'),
    ('1 != 1', 'False'),
    ('3 != 5', 'True'),
    ('1 < 1', 'False'),
    ('1 < 5', 'True'),
    ('-1 < 10', 'True'),
    ('-1 <= -2', 'False'),
    ('-2 <= -2', 'True'),
    ('-3 <= -2', 'True'),
    ('1 <= 1', 'True'),
    ('4 <= 3', 'False'),
    ('1 + 1 == 2', '1 + 1 == 2'),  # Not true without integer arithmetic
    ('.true. .and. .true.', 'True'),
    ('.true. .and. .false.', 'False'),
    ('.true. .or. .false.', 'True'),
    ('.false. .or. .false.', 'False'),
    ('2 == 1 .and. 1 == 1', 'False'),
    ('2 == 1 .or. 1 == 1', 'True'),
    ('.true. .or. a', 'True'),
    ('.false. .or. a', 'a'),
    ('.false. .and. a', 'False'),
    ('.true. .and. a', 'a'),
])
def test_simplify_logic_evaluation(source, ref):
    """
    Check that ``LogicEvaluation`` evaluates comparisons between literals and
    reduces ``.and.``/``.or.`` when an operand is a logical literal.
    """
    parsed = parse_expr(source, Scope())
    evaluated = simplify(parsed, enabled_simplifications=Simplification.LogicEvaluation)
    assert str(evaluated) == ref


@pytest.mark.parametrize('source, ref', [
    ('5 * (4 + 3 * (2 + 1) )', '65'),
    ('1 - (-1 - (-1 - (-1 - (-1 - 1) - 1) - 1) - 1) - 1', '0'),
    ('5 + a * (3 - b * (2 + c)) * 5 - 4', '1 + 15*a - 10*a*b - 5*a*b*c'),
    ('(a + b) + a + b', '2*a + 2*b'),
    ('(a+b)*(a+b)', 'a*a + 2*a*b + b*b'),
    ('(a-b)*(a-b)', 'a*a - 2*a*b + b*b'),
    ('-(a+b)*(a-b)', '-a*a + b*b'),
    ('a*a + b*(a - b) - a*(b + a) + b*b', '0'),
    ('0*(a + b - a - b)', '0'),
    ('(a + b) * c - c*a - c*b + 1', '1'),
    ('1*a*b + 0*a*b', 'a*b'),
    ('n+(((-1)*1)*n)', '0'),
    ('5 + a * (3 - b * (2 + c) / 7) * 5 - 4', '1 + 15*a - 10*a*b / 7 - 5*a*b*c / 7'),
    ('(5 + 3) * a - 8 * a / 2 + a * ((7 - 1) / 3)', '6*a'),
    ('(5 + 3) == 8', 'True'),
    ('42 == 666', 'False'),
])
def test_simplify(source, ref):
    """
    Check :any:`simplify` with its default simplifications, which combine
    flattening, integer arithmetic, coefficient collection and logic evaluation.
    """
    simplified = simplify(parse_expr(source, Scope()))
    assert str(simplified) == ref


@pytest.mark.parametrize('frontend', available_frontends())
def test_is_dimension_constant(frontend):
    """
    Check :any:`is_dimension_constant` on each dimension of an array whose
    extents mix a parameter, literals, a dummy argument and an initialised
    local variable.
    """
    fcode = """
    subroutine kernel(a, n)
    implicit none
    integer, parameter :: m = 2
    integer :: k = 3
    integer, intent(in) :: n
    real, intent(inout) :: a(n,m,2:5,n+1:n+5,k)
    end subroutine kernel
    """

    # Pass the frontend by keyword (as everywhere else in this file) so the
    # parametrised frontend is guaranteed to be the one used, instead of being
    # bound to whatever the second positional parameter of ``from_source`` is
    routine = Subroutine.from_source(fcode, frontend=frontend)
    is_const = [is_dimension_constant(d) for d in routine.variable_map['a'].shape]

    assert is_const[1]
    assert is_const[2]
    assert is_const[4]
    assert not is_const[0]
    assert not is_const[3]


def test_normalized_loop_range():
    """
    Check :attr:`LoopRange.normalized`: the normalised range must have no
    step, start at 1, and end at the trip count of the original range.
    """
    triples = (
        (start, stop, step)
        for start in range(-10, 11)
        for stop in range(start + 1, 51, 4)
        for step in range(1, stop - start)
    )
    for start, stop, step in triples:
        trip_count = len(range(start, stop + 1, step))
        normalized_range = sym.LoopRange((start, stop, step)).normalized
        assert normalized_range.step is None, "LoopRange.step should be None in a normalized range"

        evaluate = LokiEvaluationMapper()
        normalized_start = evaluate(normalized_range.start)
        assert normalized_start == 1, "LoopRange.start should be equal to 1 in a normalized range"

        # The normalised upper bound may evaluate to a non-integer; round down
        normalized_stop = floor(evaluate(normalized_range.stop))
        assert normalized_stop == trip_count, \
            "LoopRange.stop should be equal to the total number of iterations of the original LoopRange"


def test_iteration_number():
    """
    Check that :any:`iteration_number` maps each index of a loop range onto the
    matching entry of its normalised range, both with an explicit step and
    without one (``step=None``).
    """
    for start in range(-10, 11):
        for stop in range(start + 1, 50, 4):
            for step in itertools.chain((None,), range(1, stop - start)):
                loop_range = sym.LoopRange((start, stop, step))
                # A missing step behaves like a unit step
                pyrange = range(start, stop + 1, step or 1)
                normalized_range = get_pyrange(loop_range.normalized)
                assert len(normalized_range) == len(pyrange), \
                    "Length of normalized loop range should equal length of python loop range"

                evaluate = LokiEvaluationMapper()
                for index, position in zip(pyrange, normalized_range):
                    assert position == evaluate(iteration_number(sym.IntLiteral(index), loop_range))


def test_iteration_index():
    """
    Check that :any:`iteration_index` maps each entry of a normalised loop
    range back to the matching index of the original loop range.
    """
    for start in range(-10, 11):
        for stop in range(start + 1, 50, 4):
            for step in range(1, stop - start):
                loop_range = sym.LoopRange((start, stop, step))
                pyrange = range(start, stop + 1, step)
                normalized_range = get_pyrange(loop_range.normalized)
                assert len(normalized_range) == len(pyrange), \
                    "Length of normalized loop range should equal length of python loop range"

                evaluate = LokiEvaluationMapper()
                for index, position in zip(pyrange, normalized_range):
                    assert index == evaluate(iteration_index(sym.IntLiteral(position), loop_range))
loki-ecmwf-0.3.6/loki/expression/evaluation.py0000664000175000017500000001632015167130205021621 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import math
import numpy as np
from pymbolic.mapper.evaluator import EvaluationMapper
from loki.expression import symbols as sym
from loki.tools.util import CaseInsensitiveDict, as_tuple

__all__ = ['LokiEvaluationMapper', 'eval_expr']

class LokiEvaluationMapper(EvaluationMapper):
    """
    A mapper for evaluating expressions, based on
    :any:`pymbolic.mapper.evaluator.EvaluationMapper`.

    Literals are reduced to their Python values. The intrinsics
    ``min``/``max``/``modulo``/``abs``/``int``/``real``/``sqrt``/``exp`` are
    evaluated directly. Names found in the evaluation context are replaced by
    their value, or called if the context entry is callable. Sub-expressions
    that cannot be evaluated are, in most cases, returned unchanged instead of
    raising.

    Parameters
    ----------
    strict : bool
        Raise exception for unknown symbols/expressions (default: `False`).
    **kwargs :
        Forwarded to :any:`pymbolic.mapper.evaluator.EvaluationMapper`
        (e.g. ``context``, the mapping of names to values or callables).
    """

    @staticmethod
    def case_insensitive_getattr(obj, attr):
        """
        Case-insensitive version of `getattr`.

        Falls back to a plain ``getattr(obj, attr)``, and thus raises
        :class:`AttributeError`, if no attribute matches case-insensitively.
        """
        for elem in dir(obj):
            if elem.lower() == attr.lower():
                return getattr(obj, elem)
        return getattr(obj, attr)

    def __init__(self, strict=False, **kwargs):
        self.strict = strict
        super().__init__(**kwargs)

    def map_comparison(self, expr):
        import operator # pylint: disable=import-outside-toplevel
        left = self.rec(expr.left)
        right = self.rec(expr.right)
        rel_types = (sym._Literal, float, int)
        if isinstance(left, rel_types) and isinstance(right, rel_types):
            # Reuse the operands evaluated above rather than evaluating them a second time
            compare = getattr(operator, expr.operator_to_name[expr.operator])
            return compare(left, right)
        # At least one side could not be reduced to a value: keep it symbolic
        return sym.Comparison(left=left, operator=expr.operator, right=right)

    def map_logical_and(self, expr):
        children = [self.rec(ch) for ch in expr.children]
        # A single operand evaluating to ``False`` decides the conjunction,
        # even when other operands could not be evaluated
        if any(ch is False for ch in children):
            return False
        if not all(isinstance(ch, bool) for ch in children):
            # Drop operands already reduced to a boolean (all ``True`` at this
            # point; falsy non-boolean values are dropped as well) and keep the
            # remaining ones symbolic
            new_children = [ch for ch in children if not isinstance(ch, bool) and ch]
            return sym.LogicalAnd(as_tuple(new_children))
        return all(children)

    def map_logic_literal(self, expr):
        return expr.value

    def map_float_literal(self, expr):
        return expr.value
    map_int_literal = map_float_literal

    def map_variable(self, expr):
        from loki.expression.parser import FORTRAN_INTRINSIC_PROCEDURES # pylint: disable=import-outside-toplevel,cyclic-import
        _, obj = self._recurse_parent(expr)
        if obj is not None:
            # Derived-type member: look the member up on the evaluated parent object
            try:
                return self.case_insensitive_getattr(obj, expr.name.split('%')[-1])
            except Exception: # pylint: disable=broad-except
                return expr
        # Values supplied in the context take precedence over the intrinsic-name
        # fallback, so that a variable sharing an intrinsic's name still resolves
        if expr.name in self.context:
            return super().map_variable(expr)
        if expr.name.upper() in FORTRAN_INTRINSIC_PROCEDURES:
            return self.map_call(expr)
        if self.strict:
            return super().map_variable(expr)
        return expr

    def _recurse_parent(self, expr):
        """
        Evaluate the immediate parent of ``expr``, if it has one.

        Returns ``(parent, evaluated_parent)`` when ``expr`` has a non-``None``
        ``parent`` attribute, and ``(expr, None)`` otherwise. The parent's own
        ancestors are resolved by the recursive evaluation of the parent.
        """
        parent = getattr(expr, 'parent', None)
        if parent is not None:
            return parent, self.rec(parent)
        return expr, None

    @staticmethod
    def _evaluate_array(arr, dims):
        """
        Evaluate arrays by converting to numpy array (Fortran order) and
        shifting every index by one to account for Fortran's 1-based indexing.
        """
        return np.array(arr, order='F').item(*[dim-1 for dim in dims])

    def map_array(self, expr):
        new_dims = as_tuple(self.rec(dim) for dim in expr.dimensions)
        return self.map_call(expr.clone(dimensions=new_dims), name=expr.name.lower(), parameters=new_dims)

    def map_call(self, expr, name=None, parameters=None):
        """
        Evaluate a call or array access.

        Parameters
        ----------
        expr : :any:`Expression`
            The call (or array) expression
        name : str, optional
            Name to dispatch on; defaults to ``expr.function.name`` (lower-cased)
        parameters : tuple, optional
            Arguments to use; defaults to ``expr.parameters``
        """
        _, obj = self._recurse_parent(expr)
        if obj is not None:
            # Member of an evaluated parent object: call it or index into it
            try:
                _call = self.case_insensitive_getattr(obj, expr.name.split('%')[-1])
                if callable(_call):
                    return _call(*[self.rec(par) for par in expr.dimensions])
                return self._evaluate_array(_call,
                        [self.rec(par) for par in expr.dimensions])
            except Exception: # pylint: disable=broad-except
                if self.strict:
                    raise
                return expr
        call_name = name or expr.function.name.lower()
        # Honour an explicitly passed (possibly empty) argument tuple
        expr_parameters = parameters if parameters is not None else expr.parameters
        if call_name == 'min':
            return min(self.rec(par) for par in expr_parameters)
        if call_name == 'max':
            return max(self.rec(par) for par in expr_parameters)
        if call_name == 'modulo':
            args = [self.rec(par) for par in expr_parameters]
            return args[0]%args[1]
        if call_name == 'abs':
            return abs(float([self.rec(par) for par in expr_parameters][0]))
        if call_name == 'int':
            return int(float([self.rec(par) for par in expr_parameters][0]))
        if call_name == 'real':
            return float([self.rec(par) for par in expr_parameters][0])
        if call_name == 'sqrt':
            return math.sqrt(float([self.rec(par) for par in expr_parameters][0]))
        if call_name == 'exp':
            return math.exp(float([self.rec(par) for par in expr_parameters][0]))
        if call_name in self.context:
            target = self.context[call_name]
            args = [self.rec(par) for par in expr_parameters]
            if not callable(target):
                # Non-callable context entries are indexed like Fortran arrays
                return self._evaluate_array(target, args)
            # Keyword arguments are evaluated just like positional ones
            kwargs = CaseInsensitiveDict({
                key: self.rec(value)
                for key, value in (getattr(expr, 'kw_parameters', None) or {}).items()
            })
            return self.rec(target(*args, **kwargs))
        try:
            return super().map_call(expr)
        except Exception: # pylint: disable=broad-except
            return expr

    def map_inline_call(self, expr):
        _, obj = self._recurse_parent(expr.function)
        if obj is not None:
            # Type-bound procedure: call the matching attribute of the evaluated parent
            try:
                kwargs = {
                    k: self.rec(v)
                    for k, v in expr.kw_parameters.items()}
                return self.case_insensitive_getattr(obj,
                        expr.name.split('%')[-1])(*[self.rec(par) for par in expr.parameters], **kwargs)
            except Exception: # pylint: disable=broad-except
                return expr
        return self.map_call(expr, name=expr.name.lower(),
                parameters=as_tuple(self.rec(param) for param in expr.parameters))


def eval_expr(expr, context=None, strict=False):
    """
    Evaluate an expression with :any:`LokiEvaluationMapper`.

    Parameters
    ----------
    expr : :any:`Expression`
        The expression tree to evaluate
    context : dict, optional
        Symbol context mapping variable/symbol/procedure names to values or
        Python callables used while evaluating; it is wrapped in a
        ``CaseInsensitiveDict`` before use
    strict : bool, optional
        Whether to raise exception for unknown variables/symbols when
        evaluating an expression (default: `False`)

    Returns
    -------
    :any:`Expression`
        The evaluated result: a plain value where evaluation fully succeeded,
        otherwise the (partially) evaluated expression tree
    """
    lookup = CaseInsensitiveDict(context if context else {})
    return LokiEvaluationMapper(context=lookup, strict=strict)(expr)
loki-ecmwf-0.3.6/loki/expression/parser.py0000664000175000017500000004551115167130205020752 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from sys import intern
import re
import pytools.lex
from pymbolic.parser import Parser as ParserBase
from pymbolic.mapper import Mapper
import pymbolic.primitives as pmbl
from pymbolic.parser import (
    _openpar, _closepar, _minus, FinalizedTuple, _PREC_UNARY,
    _PREC_TIMES, _PREC_PLUS, _PREC_CALL, _times, _plus
)
try:
    from fparser.two.Fortran2003 import Intrinsic_Name

    FORTRAN_INTRINSIC_PROCEDURES = Intrinsic_Name.function_names
    """list of intrinsic fortran procedure(s) names"""
except ImportError:
    FORTRAN_INTRINSIC_PROCEDURES = ()

from loki.expression import symbols as sym, operations as sym_ops
from loki.tools.util import CaseInsensitiveDict
from loki.types.scope import Scope

__all__ = ['ExpressionParser', 'parse_expr', 'FORTRAN_INTRINSIC_PROCEDURES']


class PymbolicMapper(Mapper):
    """
    Pymbolic expression to Loki expression mapper.

    Convert pymbolic expressions to Loki expressions. Derived-type member
    accesses are built by passing a ``parent`` keyword argument from
    :meth:`map_lookup` into the mapping of the member expression.
    """
    # pylint: disable=abstract-method,unused-argument

    def map_product(self, expr, *args, **kwargs):
        children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
        if isinstance(expr, sym_ops.ParenthesisedMul):
            return sym_ops.ParenthesisedMul(children)
        return sym.Product(children)

    def map_sum(self, expr, *args, **kwargs):
        children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
        if isinstance(expr, sym_ops.ParenthesisedAdd):
            return sym_ops.ParenthesisedAdd(children)
        return sym.Sum(children)

    def map_power(self, expr, *args, **kwargs):
        base=self.rec(expr.base, *args, **kwargs)
        exponent=self.rec(expr.exponent, *args, **kwargs)
        if isinstance(expr, sym_ops.ParenthesisedPow):
            return sym_ops.ParenthesisedPow(base=base, exponent=exponent)
        return sym.Power(base=base, exponent=exponent)

    def map_quotient(self, expr, *args, **kwargs):
        numerator=self.rec(expr.numerator, *args, **kwargs)
        denominator=self.rec(expr.denominator, *args, **kwargs)
        if isinstance(expr, sym_ops.ParenthesisedDiv):
            return sym_ops.ParenthesisedDiv(numerator=numerator, denominator=denominator)
        return sym.Quotient(numerator=numerator, denominator=denominator)

    def map_comparison(self, expr, *args, **kwargs):
        return sym.Comparison(left=self.rec(expr.left, *args, **kwargs),
                operator=expr.operator,
                right=self.rec(expr.right, *args, **kwargs))

    def map_logical_and(self, expr, *args, **kwargs):
        return sym.LogicalAnd(tuple(self.rec(child, *args, **kwargs) for child in expr.children))

    def map_logical_or(self, expr, *args, **kwargs):
        return sym.LogicalOr(tuple(self.rec(child, *args, **kwargs) for child in expr.children))

    def map_logical_not(self, expr, *args, **kwargs):
        return sym.LogicalNot(self.rec(expr.child, *args, **kwargs))

    def map_constant(self, expr, *args, **kwargs):
        if expr == -1:
            return expr
        if isinstance(expr, (sym.FloatLiteral, sym.IntLiteral, sym.StringLiteral, sym.LogicLiteral)):
            # Represent negative integer literals as ``-1 * |value|``
            if isinstance(expr, sym.IntLiteral) and expr.value < 0:
                return sym.Product((-1, sym.IntLiteral(abs(expr.value))))
            return expr
        if isinstance(expr, bool):
            return sym.LogicLiteral('true' if expr else 'false')
        return sym.Literal(expr)

    map_logic_literal = map_constant
    map_string_literal = map_constant
    map_intrinsic_literal = map_constant

    map_int_literal = map_constant

    map_float_literal = map_int_literal
    map_variable_symbol = map_constant
    map_deferred_type_symbol = map_constant

    def map_meta_symbol(self, expr, *args, **kwargs):
        return sym.Variable(name=str(expr.name))
    map_scalar = map_meta_symbol
    map_array = map_meta_symbol

    def map_slice(self, expr, *args, **kwargs):
        children = tuple(self.rec(child, *args, **kwargs) if child is not None else child for child in expr.children)
        if len(children) == 1 and children[0] is None:
            # this corresponds to ':' (sym.RangeIndex((None, None)))
            children = (None, None)
        return sym.RangeIndex(children)

    map_range = map_slice
    map_range_index = map_slice
    map_loop_range = map_slice

    def map_variable(self, expr, *args, **kwargs):
        parent = kwargs.pop('parent', None)
        return sym.Variable(name=expr.name, parent=parent)

    def map_algebraic_leaf(self, expr, *args, **kwargs):
        if str(expr).isnumeric():
            return self.map_constant(expr)
        if isinstance(expr, pmbl.Call):
            # Remove the parent before recursing into the call's arguments so
            # that it is attached to the called name only, never to the arguments
            parent = kwargs.pop('parent', None)
            name = expr.function.name
            if parent is None:
                # Only an unqualified name can denote a cast or an intrinsic;
                # e.g. ``a%real(i)`` is a member access
                if name.lower() in ('real', 'int'):
                    return sym.Cast(name, [self.rec(param, *args, **kwargs) for param in expr.parameters][0])
                if name.upper() in FORTRAN_INTRINSIC_PROCEDURES:
                    return sym.InlineCall(function=sym.Variable(name=name),
                            parameters=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
            dimensions = tuple(self.rec(param, *args, **kwargs) for param in expr.parameters)
            if not dimensions:
                return sym.InlineCall(function=sym.Variable(name=name, parent=parent),
                        parameters=dimensions)
            return sym.Variable(name=name, parent=parent, dimensions=dimensions)
        try:
            return self.map_variable(expr, *args, **kwargs)
        except Exception as e:
            print(f"Exception: {e}")
            return expr

    def map_call_with_kwargs(self, expr, *args, **kwargs):
        parent = kwargs.pop('parent', None)
        name = sym.Variable(name=expr.function.name, parent=parent)
        parameters = tuple(self.rec(param, *args, **kwargs) for param in expr.parameters)
        kw_parameters = {key: self.rec(value, *args, **kwargs) for key, value\
                in CaseInsensitiveDict(expr.kw_parameters).items()}
        # Casts are only unqualified ``real``/``int`` calls; a missing ``kind``
        # keyword yields a cast without kind instead of a KeyError
        if parent is None and expr.function.name.lower() in ('real', 'int'):
            return sym.Cast(name, parameters, kind=kw_parameters.get('kind'))
        return sym.InlineCall(function=name, parameters=parameters, kw_parameters=kw_parameters)

    def map_tuple(self, expr, *args, **kwargs):
        return tuple(self.rec(elem, *args, **kwargs) for elem in expr)

    def map_list(self, expr, *args, **kwargs):
        return sym.LiteralList([self.rec(elem, *args, **kwargs) for elem in expr])

    def map_remainder(self, expr, *args, **kwargs):
        # this should never happen as '%' is overwritten to represent derived types
        raise NotImplementedError

    def map_lookup(self, expr, *args, **kwargs):
        # construct derived type(s) variables: map the aggregate first, then
        # map the member name with the mapped aggregate as its parent
        parent = kwargs.pop('parent', None)
        parent = self.rec(expr.aggregate, parent=parent)
        return self.rec(expr.name, parent=parent)


class ExpressionParser(ParserBase):
    """
    String Parser based on :any:`pymbolic.parser.Parser` for
    parsing expressions from strings.

    The Loki String Parser utilises and extends pymbolic's parser to incorporate
    Fortran specific syntax and to map pymbolic expressions to Loki expressions, utilising
    the mapper :any:`PymbolicMapper`.

    **Further**, in order to ensure correct ordering of Fortran Statements as documented
    in `'WD 1539-1 J3/23-007r1 (Draft Fortran 2023)' `_,
    pymbolic's parsing logic needed to be slightly adapted.

    Pymbolic references:

    * `GitHub: pymbolic `_
    * `pymbolic/parser.py `_
    * `pymbolic's parser documentation `_

    .. note::
       **Example:**
        Using the expression parser and, optionally, evaluating the parsed expressions

        .. code-block::

            >>> from loki import parse_expr
            >>> # parse numerical expressions
            >>> ir = parse_expr('3 + 2**4')
            >>> ir
            Sum((IntLiteral(3, None), Power(IntLiteral(2, None), IntLiteral(4, None))))
            >>> # or expressions with variables
            >>> ir = parse_expr('a*b')
            >>> ir
            Product((DeferredTypeSymbol('a', None, None, ),\
                    DeferredTypeSymbol('b', None, None, )))
            >>> # and provide a scope e.g., with some routine defining a and b as 'real's
            >>> ir = parse_expr('a*b', scope=routine)
            >>> ir
            Product((Scalar('a', None, None, None), Scalar('b', None, None, None)))
            >>> # further, it is possible to evaluate expressions
            >>> ir = parse_expr('a*b + 1', evaluate=True, context={'a': 2, 'b': 3})
            >>> ir
            IntLiteral(7, None)
            >>> # even with functions implemented in Python
            >>> def add(a, b):
            ...     return a + b
            >>> ir = parse_expr('a + add(a, b)', evaluate=True, context={'a': 2, 'b': 3, 'add': add})
            >>> ir
            IntLiteral(7, None)

    .. automethod:: __call__
    """

    _f_true = intern("f_true")
    _f_false = intern("f_false")
    _f_lessequal = intern('_f_lessequal')
    _f_less = intern('_f_less')
    _f_greaterequal = intern('_f_greaterequal')
    _f_greater = intern('_f_greater')
    _f_equal = intern('_f_equal')
    _f_notequal = intern('_f_notequal')
    _f_and = intern("and")
    _f_or = intern("or")
    _f_not = intern("not")
    _f_float = intern("f_float")
    _f_int = intern("f_int")
    _f_string = intern("f_string")
    _f_openbracket = intern("openbracket")
    _f_closebracket = intern("closebracket")
    _f_derived_type = intern("dot")

    # The Fortran-specific rules are listed before pymbolic's own rules so that
    # they take precedence when both could match at the same position.
    # NOTE(review): the ``_f_float`` pattern ends in ``$``, so (presumably, with
    # ``re.match`` semantics) it only matches a kind-suffixed float that ends the
    # input string; confirm whether e.g. ``1.0_jprb + a`` is meant to be supported.
    lex_table = [
            (_f_true, pytools.lex.RE(r"\.true\.", re.IGNORECASE)),
            (_f_false, pytools.lex.RE(r"\.false\.", re.IGNORECASE)),
            (_f_lessequal, pytools.lex.RE(r"\.le\.", re.IGNORECASE)),
            (_f_less, pytools.lex.RE(r"\.lt\.", re.IGNORECASE)),
            (_f_greaterequal, pytools.lex.RE(r"\.ge\.", re.IGNORECASE)),
            (_f_greater, pytools.lex.RE(r"\.gt\.", re.IGNORECASE)),
            (_f_equal, pytools.lex.RE(r"\.eq\.", re.IGNORECASE)),
            (_f_notequal, ("|", pytools.lex.RE(r"\.ne\.", re.IGNORECASE), pytools.lex.RE(r"/=", re.IGNORECASE))),
            (_f_and, pytools.lex.RE(r"\.and\.", re.IGNORECASE)),
            (_f_or, pytools.lex.RE(r"\.or\.", re.IGNORECASE)),
            (_f_not, pytools.lex.RE(r"\.not\.", re.IGNORECASE)),
            (_f_float, ("|", pytools.lex.RE(r"[0-9]+\.[0-9]*([eEdD][+-]?[0-9]+)?(_([\w$]+|[0-9]+))+$", re.IGNORECASE))),
            # Integer with a kind suffix, e.g. ``1_jpim``, ``1_my_int_kind`` or ``1_8``.
            # The kind part may contain letters, digits, underscores and ``$``
            # (same character class as the kind suffix of ``_f_float``).
            (_f_int, pytools.lex.RE(r"[0-9]+(_[\w$]+)", re.IGNORECASE)),
            (_f_string, ("|", pytools.lex.RE(r'\".*\"', re.IGNORECASE),
                pytools.lex.RE(r"\'.*\'", re.IGNORECASE))),
            (_f_openbracket, pytools.lex.RE(r"\(/")),
            (_f_closebracket, pytools.lex.RE(r"/\)")),
            (_f_derived_type, pytools.lex.RE(r"\%")),
            ] + ParserBase.lex_table
    """
    Extend :any:`pymbolic.parser.Parser.lex_table` to accommodate Fortran-specific syntax/expressions.
    """

    # NOTE: this updates the comparison table stored on ParserBase itself
    # (class-level mutation at import time), mapping the Fortran comparison
    # tokens onto pymbolic's comparison operators.
    ParserBase._COMP_TABLE.update({
         _f_lessequal: "<=",
         _f_less: "<",
         _f_greaterequal: ">=",
         _f_greater: ">",
         _f_equal: "==",
         _f_notequal: "!="
         })

    @staticmethod
    def _parenthesise(expr):
        """
        Utility method to parenthesise specific expressions.

        E.g., from :any:`pymbolic.primitives.Sum` to
        :any:`ParenthesisedAdd`.
        """
        if isinstance(expr, pmbl.Sum):
            return sym_ops.ParenthesisedAdd(expr.children)
        if isinstance(expr, pmbl.Product):
            return sym_ops.ParenthesisedMul(expr.children)
        if isinstance(expr, pmbl.Quotient):
            return sym_ops.ParenthesisedDiv(numerator=expr.numerator,
                    denominator=expr.denominator)
        if isinstance(expr, pmbl.Power):
            return sym_ops.ParenthesisedPow(base=expr.base, exponent=expr.exponent)
        return expr

    def parse_prefix(self, pstate):
        """
        Parse prefix constructs: unary minus (as ``Product((-1, ...))``) and
        parenthesised sub-expressions (kept as parenthesised nodes); everything
        else is delegated to pymbolic.
        """
        pstate.expect_not_end()

        if pstate.is_next(_minus):
            pstate.advance()
            left_exp = pmbl.Product((-1, self.parse_expression(pstate, _PREC_UNARY)))
            return left_exp
        if pstate.is_next(_openpar):
            pstate.advance()

            if pstate.is_next(_closepar):
                left_exp = ()
            else:
                # This is parsing expressions separated by commas, so it
                # will return a tuple. Kind of the lazy way out.
                left_exp = self.parse_expression(pstate)
                # NECESSARY to ensure correct ordering!
                left_exp = self._parenthesise(left_exp)
            pstate.expect(_closepar)
            pstate.advance()
            if isinstance(left_exp, tuple):
                # These could just be plain parentheses.

                # Finalization prevents things from being appended
                # to containers after their closing delimiter.
                left_exp = FinalizedTuple(left_exp)
            return left_exp
        return super().parse_prefix(pstate)

    def parse_postfix(self, pstate, min_precedence, left_exp):
        """
        Parse binary/postfix operators following ``left_exp``.

        Derived-type member access (``%``), ``*``, ``+`` and ``-`` are handled here
        so that the resulting trees are left-associative; all other operators are
        delegated to pymbolic.
        """
        did_something = False
        if pstate.is_next(self._f_derived_type) and _PREC_CALL > min_precedence:
            pstate.advance()
            right_exp = self.parse_expression(pstate, _PREC_PLUS)
            left_exp = pmbl.Lookup(left_exp, right_exp)
            did_something = True
        elif pstate.is_next(_times) and _PREC_TIMES > min_precedence:
            pstate.advance()
            right_exp = self.parse_expression(pstate, _PREC_PLUS)
            # NECESSARY to ensure correct ordering!
            # pylint: disable=unidiomatic-typecheck
            if type(right_exp) is pmbl.Quotient:
                left_exp = pmbl.Quotient(numerator=pmbl.Product((left_exp, right_exp.numerator)),
                        denominator=right_exp.denominator)
            # pylint: disable=unidiomatic-typecheck
            elif type(right_exp) is pmbl.Product:
                # Re-associate so the left operand binds to the first factor,
                # keeping every remaining factor (not only the second one)
                left_exp = pmbl.Product((sym.Product((left_exp, right_exp.children[0])),
                                         *right_exp.children[1:]))
            else:
                left_exp = pmbl.Product((left_exp, right_exp))
            did_something = True
        elif pstate.is_next(_plus) and _PREC_PLUS > min_precedence:
            pstate.advance()
            right_exp = self.parse_expression(pstate, _PREC_PLUS)
            left_exp = pmbl.Sum((left_exp, right_exp))
            did_something = True
        elif pstate.is_next(_minus) and _PREC_PLUS > min_precedence:
            pstate.advance()
            right_exp = self.parse_expression(pstate, _PREC_PLUS)
            right_exp = pmbl.Product((-1, right_exp))
            left_exp = pmbl.Sum((left_exp, right_exp))
            did_something = True
        else:
            return super().parse_postfix(pstate, min_precedence, left_exp)
        return left_exp, did_something

    def parse_terminal(self, pstate):
        """
        Parse Fortran-specific terminals (kind-suffixed floats/ints, strings and
        logical literals) before falling back to pymbolic's terminals.
        """
        if pstate.is_next(self._f_float):
            return self.parse_f_float(pstate.next_str_and_advance())
        if pstate.is_next(self._f_int):
            return self.parse_f_int(pstate.next_str_and_advance())
        if pstate.is_next(self._f_string):
            return self.parse_f_string(pstate.next_str_and_advance())
        if pstate.is_next(self._f_true):
            assert pstate.next_str_and_advance().lower() == ".true."
            return sym.LogicLiteral('.TRUE.')
        if pstate.is_next(self._f_false):
            assert pstate.next_str_and_advance().lower() == ".false."
            return sym.LogicLiteral('.FALSE.')
        return super().parse_terminal(pstate)

    def __call__(self, expr_str, scope=None, evaluate=False, strict=False, context=None):
        """
        Call Loki String Parser to convert expression(s) represented in a string to Loki expression(s)/IR.

        Parameters
        ----------
        expr_str : str
            The expression as a string
        scope : :any:`Scope`
            The scope to which symbol names inside the expression belong
        evaluate : bool, optional
            Whether to evaluate the expression or not (default: `False`)
        strict : bool, optional
            Whether to raise exception for unknown variables/symbols when
            evaluating an expression (default: `False`)
        context : dict, optional
            Symbol context, defining variables/symbols/procedures to help/support
            evaluating an expression

        Returns
        -------
        :any:`Expression`
            The expression tree corresponding to the expression
        """
        from loki.ir import AttachScopes  # pylint: disable=import-outside-toplevel,cyclic-import
        from loki.expression.evaluation import eval_expr # pylint: disable=import-outside-toplevel,cyclic-import
        result = super().__call__(expr_str)
        ir = PymbolicMapper()(result)
        if evaluate:
            ir = eval_expr(ir, context=context, strict=strict)
        return AttachScopes().visit(ir, scope=scope or Scope())

    def parse_float(self, s):
        """
        Parse float literals.

        Do not cast to float via 'float()' in order to keep the original
        notation, e.g., do not convert 1E-3 to 0.003. Fortran ``d``/``D``
        exponent markers are normalised to ``e``.
        """
        return sym.FloatLiteral(value=s.replace("d", "e").replace("D", "e"))

    def parse_f_float(self, s):
        """
        Parse "Fortran-style" float literals.

        E.g., ``3.1415_my_real_kind``.
        """
        stripped = s.split('_', 1)
        literal = self.parse_float(stripped[0])
        if len(stripped) == 2:
            # Re-use the normalised value string rather than nesting a
            # FloatLiteral inside another FloatLiteral
            return sym.FloatLiteral(value=literal.value, kind=sym.Variable(name=stripped[1].lower()))
        return literal

    def parse_f_int(self, s):
        """
        Parse "Fortran-style" int literals.

        E.g., ``1_my_int_kind``. If no (non-empty) kind suffix is present,
        a plain integer literal without kind is returned.
        """
        stripped = s.split('_', 1)
        value = int(stripped[0])
        if len(stripped) == 2 and stripped[1]:
            return sym.IntLiteral(value=value, kind=sym.Variable(name=stripped[1].lower()))
        return sym.IntLiteral(value=value)

    def parse_f_string(self, s):
        """
        Parse string literals.
        """
        return sym.StringLiteral(s)


# Shared module-level parser instance; call it like a function, e.g. ``parse_expr('a + 1')``.
parse_expr = ExpressionParser()
"""
An instance of :any:`ExpressionParser` that allows parsing expression strings into a Loki expression tree.
See :any:`ExpressionParser.__call__` for a description of the available arguments.
"""
loki-ecmwf-0.3.6/loki/expression/symbolic.py0000664000175000017500000005777115167130205021312 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import defaultdict
import enum
from functools import reduce
from math import gcd, floor
import operator as _op
import numpy as np
import pymbolic.primitives as pmbl

from loki.expression.mappers import LokiIdentityMapper
from loki.expression.evaluation import LokiEvaluationMapper
import loki.expression.symbols as sym
from loki.tools import as_tuple

# Public names exported by this module (``from ... import *``)
__all__ = [
    'is_constant', 'symbolic_op', 'simplify', 'accumulate_polynomial_terms',
    'Simplification', 'SimplifyMapper', 'is_dimension_constant', 'ceil_division',
    'get_pyrange', 'iteration_number', 'iteration_index'
]


def is_minus_prefix(expr):
    """
    Check whether ``expr`` negates a nested expression.

    This is the case exactly when ``expr`` is a non-empty product whose
    leading factor is ``-1``, i.e. ``expr == Product((-1, ...))``.
    """
    if not isinstance(expr, sym.Product):
        return False
    if not expr.children:
        return False
    leading = expr.children[0]
    return pmbl.is_zero(leading + 1)


def strip_minus_prefix(expr):
    """
    Remove the leading ``-1`` factor from a minus-prefixed product.

    A single remaining factor is returned as-is; several remaining factors
    are returned as a new product.

    Raises a `ValueError` if ``expr`` does not carry a minus prefix.
    """
    if is_minus_prefix(expr):
        remainder = expr.children[1:]
        if len(remainder) == 1:
            return remainder[0]
        return sym.Product(as_tuple(remainder))
    raise ValueError('Given expression does not have a minus prefix.')


def is_constant(expr):
    """
    Check whether ``expr`` reduces to a constant value.

    Any number of leading minus prefixes is peeled off before the check.
    """
    while is_minus_prefix(expr):
        expr = strip_minus_prefix(expr)
    return pmbl.is_constant(expr)


def is_dimension_constant(d):
    """
    Establish if a given dimension symbol is a compile-time constant.

    Integer literals are constant. A range is constant if its upper bound is
    constant, and so is its lower bound when one is given. A scalar counts
    as constant if its initial value is an integer literal.
    """
    if isinstance(d, sym.IntLiteral):
        return True

    if isinstance(d, sym.RangeIndex):
        bounds = (d.lower, d.upper) if d.lower else (d.upper,)
        return all(is_dimension_constant(bound) for bound in bounds)

    return isinstance(d, sym.Scalar) and isinstance(d.initial, sym.IntLiteral)


def symbolic_op(expr1, op, expr2):
    """
    Evaluate ``op(expr1, expr2)`` (e.g. ``expr1 < expr2`` for ``op = operator.lt``)
    and return the result.

    `op` can be any binary operation such as the rich comparison operators
    from the `operator` library.

    While calling this function is largely equivalent to applying the operator
    directly, it is to be understood as a convenience layer that applies,
    depending on the operator, a number of symbolically neutral manipulations.
    Currently, this only applies to comparison operators (such as `eq`, `ne`,
    `lt`, `le`, `gt`, `ge`). Since expression nodes do not imply an order,
    such comparisons would fail even if a symbolic meaning can be derived.

    For that reason, these operations are reformulated as the difference
    between the two expressions and compared against `0`. For example:
    ```
    # n < n + 1
    Scalar('n') < Sum((Scalar('n'), IntLiteral(1)))
    ```
    raises a `TypeError` but
    ```
    # n < n + 1
    symbolic_op(Scalar('n'), operator.lt, Sum((Scalar('n'), IntLiteral(1))))
    ```
    returns `True`.

    This is done by transforming this expression into
    ```
    # n - (n + 1) < 0
    Sum((Scalar('n'), Product(-1, Sum((Scalar('n'), IntLiteral(1)))))) < 0
    ```
    and then calling `simplify` on the left hand side to obtain
    ```
    # -1 < 0
    Product(-1, IntLiteral(1)) < 0
    ```
    The minus prefix is then stripped and the comparison mirrored accordingly
    (``-x < 0`` is equivalent to ``x > 0``), which yields the result.
    """
    # Comparison operators and their mirror image under negation of the
    # left-hand side: ``-x OP 0``  <=>  ``x MIRROR[OP] 0``
    mirrored = {
        _op.eq: _op.eq, _op.ne: _op.ne,
        _op.lt: _op.gt, _op.le: _op.ge,
        _op.gt: _op.lt, _op.ge: _op.le,
    }
    if op in mirrored:
        expr1, expr2 = simplify(expr1 - expr2), 0
        if is_minus_prefix(expr1):
            # Strip minus prefix to possibly yield constant expression. Mirroring
            # the operator (rather than negating the result) stays exact when the
            # stripped expression equals zero.
            return symbolic_op(strip_minus_prefix(expr1), mirrored[op], expr2)
    return op(expr1, expr2)


def distribute_product(expr):
    """
    Flatten (nested) products into a sum of products.

    This converts for example `a * (b + c) * (d + e)` to
    `a * b * d + a * c * d + a * b * e + a * c * e`.

    Nested quotients contribute their numerator to the distribution and their
    denominator to a common denominator, so the result may be a
    :any:`Quotient` of the distributed sum. Non-product inputs are returned
    unchanged.
    """
    def _retval(numerator, denominator):
        # Divide by the collected denominators (if any); several are multiplied together
        if not denominator:
            return numerator
        if len(denominator) == 1:
            return sym.Quotient(numerator, denominator[0])
        return sym.Quotient(numerator, sym.Product(as_tuple(denominator)))

    if not isinstance(expr, sym.Product):
        return expr

    queue = list(expr.children)
    denominator = []
    # Each entry is the list of factors of one partial product; distributing over
    # a sum multiplies the number of entries by the number of summands
    done = [[]]

    while queue:
        item = queue.pop(0)

        # Literal ones do not contribute to a product
        if isinstance(item, sym.IntLiteral) and item.value == 1:
            continue

        if isinstance(item, sym.Product):
            # Prepend children to maintain order of operands
            queue = list(item.children) + queue
        elif isinstance(item, sym.Quotient):
            # Enqueue the numerator and save the denominator for later
            queue = [item.numerator] + queue
            denominator += [item.denominator]
        elif isinstance(item, sym.Sum):
            # This is the distribution part
            old_done, done = done, []
            for child in item.children:
                done += [l + [child] for l in old_done]
        else:
            # Some other factor that we simply carry over
            done = [l + [item] for l in done]

    # Only reachable if a sum without children was distributed
    if not done:
        return _retval(sym.IntLiteral(1), denominator)

    # Form the new products, eliminating multiple `-1` in the process
    children = []
    for components in done:
        is_neg = False
        if -1 in components:
            # An odd number of -1 factors leaves a single minus prefix
            is_neg = sum(1 for v in components if v == -1) % 2 == 1
            components = [v for v in components if v != -1]

        if not components:
            components = sym.IntLiteral(1)
        elif len(components) == 1:
            components = components[0]
        else:
            components = sym.Product(as_tuple(components))

        if is_neg:
            components = sym.Product((-1, components))
        children.append(components)

    if len(children) == 1:
        return _retval(children[0], denominator)
    return _retval(sym.Sum(as_tuple(children)), denominator)


def distribute_quotient(expr):
    """
    Flatten (nested) quotients into a sum of quotients.

    This converts for example `(a/b + c) / d` to `a / (b*d) + c / d`.

    A minus prefix on the numerator or the denominator is pulled out in front
    of the result. A numerator whose summands are all literal zeros yields
    ``IntLiteral(0)`` (since ``0 / d == 0``). Non-quotient inputs are returned
    unchanged.
    """
    if not isinstance(expr, sym.Quotient):
        return expr

    if is_minus_prefix(expr.numerator):
        q = sym.Quotient(strip_minus_prefix(expr.numerator), expr.denominator)
        return sym.Product((-1, distribute_quotient(q)))

    if is_minus_prefix(expr.denominator):
        q = sym.Quotient(expr.numerator, strip_minus_prefix(expr.denominator))
        return sym.Product((-1, distribute_quotient(q)))

    queue = [expr.numerator]
    done = []

    while queue:
        item = queue.pop(0)

        # Zero summands vanish from the distributed sum
        if isinstance(item, sym.IntLiteral) and item.value == 0:
            continue

        if isinstance(item, sym.Sum):
            # Prepend children to maintain order of operands
            queue = list(item.children) + queue
        elif isinstance(item, sym.Quotient):
            # (n / d1) / d2  ->  n / (d1 * d2), distributed further
            done += [distribute_quotient(sym.Quotient(item.numerator, item.denominator * expr.denominator))]
        else:
            # Convert to a quotient
            done += [sym.Quotient(item, expr.denominator)]

    if not done:
        # Every summand of the numerator was zero, hence the quotient is zero
        return sym.IntLiteral(0)
    if len(done) == 1:
        return done[0]
    return sym.Sum(as_tuple(done))


def flatten_expr(expr):
    """
    Flatten an expression by flattening any sub-sums and distributing products and quotients.

    This converts for example `a + (b - (c + d) * e)` to `a + b - c * e - d * e`.

    This is an (enhanced) re-implementation of the original `flattened_sum` routine from
    Pymbolic to account for the Loki-specific expression nodes and expand the flattening to
    distributing products. Zero summands are dropped; an empty result becomes
    ``IntLiteral(0)``.
    """
    # Work stack holding pending summands in reverse order, so that popping from
    # the end processes them front-to-back and nested summands keep their order
    stack = list(reversed(as_tuple(expr)))
    terms = []

    while stack:
        item = stack.pop()

        if pmbl.is_zero(item):
            continue

        if isinstance(item, sym.Product):
            item = distribute_product(item)

        if isinstance(item, sym.Quotient):
            item = distribute_quotient(item)

        if isinstance(item, sym.Sum):
            stack.extend(reversed(item.children))
        else:
            terms.append(item)

    if not terms:
        return sym.IntLiteral(0)
    if len(terms) == 1:
        return terms[0]
    return sym.Sum(as_tuple(terms))


def sum_int_literals(expr):
    """
    Sum up the values of all `IntLiteral` in the sum and return the reduced sum.

    Minus-prefixed integer literals are counted negatively. The accumulated
    constant (if non-zero) is placed first, followed by the remaining summands
    in their original order. An empty sum reduces to ``IntLiteral(0)``.
    Non-sum inputs are returned unchanged.
    """
    def _process(child):
        # Returns (integer contribution, remaining non-literal summand or None)
        if isinstance(child, sym.IntLiteral):
            return child.value, None
        if is_minus_prefix(child):
            value, stripped_child = _process(strip_minus_prefix(child))
            if value != 0:
                return -value, stripped_child
        return 0, child

    if not isinstance(expr, sym.Sum):
        return expr

    value = 0
    remaining_components = []
    for child in expr.children:
        child_value, remainder = _process(child)
        value += child_value
        if remainder is not None:
            remaining_components.append(remainder)

    if value != 0:
        remaining_components = [sym.IntLiteral(value)] + remaining_components

    if not remaining_components:
        return sym.IntLiteral(0)
    if len(remaining_components) == 1:
        return remaining_components[0]
    return sym.Sum(as_tuple(remaining_components))


def separate_coefficients(expr):
    """
    Helper routine that separates components of a product into constant coefficients
    and remaining factors.

    Integer factors (Python/numpy ints and :any:`IntLiteral`) are multiplied
    into the coefficient, and minus prefixes flip its sign. Non-product inputs
    are returned as ``(1, [expr])`` (or ``(value, [])`` for an integer
    literal), and an empty product yields ``(1, [])``.

    :param sym.Product expr: the product comprising constant and non-constant sub-expressions.
    :returns: the constant coefficient and remaining non-constant sub-expressions.
    :rtype: (int, list)
    """
    def _process(child):
        if isinstance(child, (int, np.integer)):
            return child, None
        if isinstance(child, sym.IntLiteral):
            return child.value, None
        if is_minus_prefix(child):
            # We recurse here as products that are only there to change the sign
            # should not introduce a layer in the expression tree. Strip the whole
            # prefix so that no factor beyond the first one is lost.
            value, component = _process(strip_minus_prefix(child))
            return -value, component
        return 1, child

    if isinstance(expr, sym.IntLiteral):
        return expr.value, []
    if not isinstance(expr, sym.Product):
        return 1, [expr]

    if is_minus_prefix(expr):
        value, remaining_components = separate_coefficients(strip_minus_prefix(expr))
        return -value, remaining_components

    value = 1
    remaining_components = []
    for child in expr.children:
        child_value, component = _process(child)
        value = value * child_value
        if component is not None:
            remaining_components.append(component)
    return value, remaining_components


def mul_int_literals(expr):
    """
    Multiply all `IntLiteral` in the given `Product` and return the reduced expression.

    The combined coefficient becomes the leading factor (omitted when its
    magnitude is one); a negative coefficient is expressed as a minus prefix.
    A zero coefficient collapses the product to ``IntLiteral(0)``.
    Non-product inputs are returned unchanged.
    """
    if not isinstance(expr, sym.Product):
        return expr

    coefficient, factors = separate_coefficients(expr)
    if coefficient == 0:
        return sym.IntLiteral(0)

    magnitude = abs(coefficient)
    if magnitude != 1:
        factors = [sym.IntLiteral(magnitude)] + factors

    if len(factors) == 0:
        result = sym.IntLiteral(1)
    elif len(factors) == 1:
        result = factors[0]
    else:
        result = sym.Product(as_tuple(factors))

    return sym.Product((-1, result)) if coefficient < 0 else result


def div_int_literals(expr):
    """
    Reduce fractions where the denominator is a `IntLiteral`.

    The integer coefficient of the numerator and the denominator are divided
    by their greatest common divisor. Integer (floor) division is used, which
    is exact because the gcd divides both values, so the resulting literals
    stay integral. Minus prefixes on numerator or denominator are pulled out
    in front. Non-quotient inputs are returned unchanged.
    """
    if not isinstance(expr, sym.Quotient):
        return expr

    if is_minus_prefix(expr.numerator):
        q = sym.Quotient(strip_minus_prefix(expr.numerator), expr.denominator)
        return sym.Product((-1, div_int_literals(q)))

    if is_minus_prefix(expr.denominator):
        q = sym.Quotient(expr.numerator, strip_minus_prefix(expr.denominator))
        return sym.Product((-1, div_int_literals(q)))

    if not isinstance(expr.denominator, sym.IntLiteral):
        return expr

    if isinstance(expr.numerator, sym.IntLiteral):
        div = gcd(expr.numerator.value, expr.denominator.value)
        numerator = sym.IntLiteral(expr.numerator.value // div)
        denominator = sym.IntLiteral(expr.denominator.value // div)

    elif isinstance(expr.numerator, sym.Product):
        value, remaining_components = separate_coefficients(expr.numerator)
        div = gcd(value, expr.denominator.value)
        numerator = mul_int_literals(sym.Product((sym.IntLiteral(value // div), *remaining_components)))
        denominator = sym.IntLiteral(expr.denominator.value // div)

    else:
        numerator, denominator = expr.numerator, expr.denominator

    if denominator == 1:
        return numerator
    return sym.Quotient(numerator, denominator)


def accumulate_polynomial_terms(expr):
    """
    Collect all occurrences of each base and determine its constant coefficient.

    Each summand of ``expr`` (the children of a :py:class:`sym.Sum`, otherwise
    ``as_tuple(expr)``) is split into an integer coefficient and a "base" made
    of the remaining factors. Coefficients of identical bases are added up.
    This works for any non-constant sub-expression as base, so it is not
    restricted to polynomials.

    :param expr: a :py:class:`sym.Sum` or a single expression.
    :returns: mapping of base (a tuple of factors sorted by their string
        representation, or ``1`` for the constant part) to its coefficient.
    :rtype: dict
    """
    terms = expr.children if isinstance(expr, sym.Sum) else as_tuple(expr)

    coefficients = {}

    def _add(base, amount):
        coefficients[base] = coefficients.get(base, 0) + amount

    for term in terms:
        if isinstance(term, sym.Product):
            factor, rest = separate_coefficients(term)
            if factor == 0:
                continue
            if rest:
                # Sort factors by their string form so equal bases share a key
                _add(as_tuple(sorted(rest, key=str)), factor)
            else:
                _add(1, factor)
        elif isinstance(term, (int, np.integer)):
            _add(1, term)
        elif isinstance(term, sym.IntLiteral):
            _add(1, term.value)
        else:
            _add(as_tuple(term), 1)

    return coefficients


def collect_coefficients(expr):
    """
    Simplify a polynomial-type expression by combining all occurrences of a
    non-constant sub-expression into a single summand.

    The constant part (if non-zero) comes first, followed by one summand per
    base with a non-zero coefficient. An empty result becomes ``IntLiteral(0)``.

    :param expr: expression to simplify, e.g. a :py:class:`sym.Sum`.
    :returns: the reduced expression.
    """
    def _coefficient_factors(value):
        # Leading factors encoding the coefficient: nothing for 1, a bare -1 for -1
        if value == 1:
            return []
        if value == -1:
            return [-1]
        magnitude = [sym.IntLiteral(abs(value))]
        return [-1] + magnitude if value < 0 else magnitude

    terms = accumulate_polynomial_terms(expr)

    # Treat the constant part separately to make sure this is flat
    constant = terms.pop(1, 0)
    if constant > 0:
        components = [sym.IntLiteral(constant)]
    elif constant < 0:
        components = [sym.Product((-1, sym.IntLiteral(abs(constant))))]
    else:
        components = []

    for base, factor in terms.items():
        if factor == 0:
            continue
        if factor == 1 and len(base) == 1:
            components.append(base[0])
            continue
        components.append(sym.Product(as_tuple(_coefficient_factors(factor) + list(base))))

    if len(components) == 0:
        return sym.IntLiteral(0)
    if len(components) == 1:
        return components[0]
    return sym.Sum(as_tuple(components))


class Simplification(enum.Flag):
    """
    Selection of simplification techniques that can be applied to expressions.

    Techniques are bit flags and can be combined with bitwise operations, e.g.:
    ```
    Flatten | IntegerArithmetic
    ALL & ~Flatten
    ```

    Attributes:
        Flatten             Flatten sub-sums and distribute products.
        IntegerArithmetic   Perform arithmetic on integer literals (addition and multiplication).
        CollectCoefficients Combine summands as far as possible.
        LogicEvaluation     Resolve logically fully determinate expressions, like ``1 == 1`` or ``1 == 6``
        ALL                 All of the above.
    """
    # Explicit powers of two: the same values enum.auto() assigns for a Flag
    Flatten = 1
    IntegerArithmetic = 2
    CollectCoefficients = 4
    LogicEvaluation = 8

    # pylint: disable-next=unsupported-binary-operation
    ALL = Flatten | IntegerArithmetic | CollectCoefficients | LogicEvaluation


class SimplifyMapper(LokiIdentityMapper):
    """
    A mapper that attempts to symbolically simplify an expression.

    It applies all enabled simplifications from `Simplification` to an expression.
    Arithmetic nodes are simplified repeatedly (by re-mapping the result) until
    a pass no longer changes the expression.

    Parameters
    ----------
    enabled_simplifications : :any:`Simplification`, optional
        The simplification techniques to apply (default: ``Simplification.ALL``).
    """
    # pylint: disable=abstract-method

    def __init__(self, enabled_simplifications=Simplification.ALL):
        super().__init__()

        self.enabled_simplifications = enabled_simplifications

    def map_sum(self, expr, *args, **kwargs):
        """
        Simplify the children, then apply flattening, integer summation and
        coefficient collection as enabled. If anything changed, the result is
        mapped again; otherwise the original ``expr`` is returned.
        """
        new_expr = sym.Sum(as_tuple([self.rec(child, *args, **kwargs) for child in expr.children]))

        if self.enabled_simplifications & Simplification.Flatten:
            new_expr = flatten_expr(new_expr)

        if self.enabled_simplifications & Simplification.IntegerArithmetic:
            new_expr = sum_int_literals(new_expr)

        if self.enabled_simplifications & Simplification.CollectCoefficients:
            new_expr = collect_coefficients(new_expr)

        # Iterate until a fixed point is reached
        if new_expr != expr:
            return self.rec(new_expr, *args, **kwargs)
        return expr

    def map_product(self, expr, *args, **kwargs):
        """
        Simplify the children, then apply flattening and integer multiplication
        as enabled, re-mapping the result until it no longer changes.
        """
        new_expr = sym.Product(as_tuple([self.rec(child, *args, **kwargs) for child in expr.children]))

        if self.enabled_simplifications & Simplification.Flatten:
            new_expr = flatten_expr(new_expr)

        if self.enabled_simplifications & Simplification.IntegerArithmetic:
            new_expr = mul_int_literals(new_expr)

        # Iterate until a fixed point is reached
        if new_expr != expr:
            return self.rec(new_expr, *args, **kwargs)
        return expr

    def map_quotient(self, expr, *args, **kwargs):
        """
        Simplify numerator and denominator, then apply flattening and integer
        fraction reduction as enabled, re-mapping the result until it no longer
        changes.
        """
        numerator = self.rec(expr.numerator, *args, **kwargs)
        denominator = self.rec(expr.denominator, *args, **kwargs)
        new_expr = sym.Quotient(numerator, denominator)

        if self.enabled_simplifications & Simplification.Flatten:
            new_expr = flatten_expr(new_expr)

        if self.enabled_simplifications & Simplification.IntegerArithmetic:
            new_expr = div_int_literals(new_expr)

        # Iterate until a fixed point is reached
        if new_expr != expr:
            return self.rec(new_expr, *args, **kwargs)
        return expr

    # Parenthesised nodes are simplified like their plain counterparts
    map_parenthesised_add = map_sum
    map_parenthesised_mul = map_product
    map_parenthesised_div = map_quotient

    def map_comparison(self, expr, *args, **kwargs):
        """
        Simplify both operands; with ``LogicEvaluation`` enabled, a comparison
        of two constant operands is folded into a :any:`LogicLiteral`.
        """
        def get_constant_value(expr):
            # NOTE(review): relies on a ``.value`` attribute, so presumably only
            # literal nodes reach this point (plain Python numbers have none);
            # also confirm float literals store numeric (not string) values,
            # otherwise the comparison below would be lexicographic.
            if is_minus_prefix(expr):
                return -1 * strip_minus_prefix(expr).value
            return expr.value

        left = self.rec(expr.left, *args, **kwargs)
        right = self.rec(expr.right, *args, **kwargs)

        # Comparison operator strings -> Python operator functions
        op_map = {'==': _op.eq, '>': _op.gt, '>=': _op.ge, '!=': _op.ne,
                '<': _op.lt, '<=': _op.le}
        if self.enabled_simplifications & Simplification.LogicEvaluation:
            if is_constant(left) and is_constant(right):
                left = get_constant_value(left)
                right = get_constant_value(right)
                if op_map[expr.operator](left, right):
                    return sym.LogicLiteral('True')
                return sym.LogicLiteral('False')

        return sym.Comparison(operator=expr.operator, left=left, right=right)

    def map_logical_and(self, expr, *args, **kwargs):
        """
        Simplify the operands; with ``LogicEvaluation`` enabled, return False if
        any operand compares equal to ``'False'`` and drop operands equal to
        ``'True'``. If no operands remain, the result is True.
        """
        children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
        if self.enabled_simplifications & Simplification.LogicEvaluation:
            if any(c == 'False' for c in children):
                return sym.LogicLiteral('False')
            if any(c == 'True' for c in children):
                # Trim all literals and return .true. if all were .true.
                children = tuple(c for c in children if not c == 'True')

        return sym.LogicalAnd(children) if len(children) > 0 else sym.LogicLiteral('True')

    def map_logical_or(self, expr, *args, **kwargs):
        """
        Simplify the operands; with ``LogicEvaluation`` enabled, return True if
        any operand compares equal to ``'True'`` and drop operands equal to
        ``'False'``. If no operands remain, the result is False.
        """
        children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
        if self.enabled_simplifications & Simplification.LogicEvaluation:
            if any(c == 'True' for c in children):
                return sym.LogicLiteral('True')
            if any(c == 'False' for c in children):
                # Trim all literals and return .false. if all were .false.
                children = tuple(c for c in children if not c == 'False')

        return sym.LogicalOr(children) if len(children) > 0 else sym.LogicLiteral('False')


def simplify(expr, enabled_simplifications=Simplification.ALL):
    """
    Simplify the given expression by applying selected simplifications.

    Parameters
    ----------
    expr : :any:`Expression`
        The expression tree to simplify.
    enabled_simplifications : :any:`Simplification`, optional
        Bit flags selecting which simplifications the mapper may apply
        (default: all of them).
    """
    mapper = SimplifyMapper(enabled_simplifications=enabled_simplifications)
    return mapper(expr)


def ceil_division(iexpr1: pmbl.Expression, iexpr2: pmbl.Expression) -> pmbl.Expression:
    """
    Returns ceiled division expression of two integer expressions iexpr1/iexpr2.

    The expression is built as ``(iexpr1 - 1) / iexpr2 + 1`` and then
    simplified with integer arithmetic only.

    NOTE(review): with truncating integer division (as in Fortran) this
    equals the ceiling only for a positive numerator and denominator; e.g.
    ``iexpr1 = 0`` with ``iexpr2 > 1`` yields 1. Confirm that callers never
    pass non-positive numerators.
    """
    numerator = sym.Sum(children=(iexpr1, sym.IntLiteral(-1)))
    quotient = sym.Quotient(numerator=numerator, denominator=iexpr2)
    ceiled = sym.Sum(children=(quotient, sym.IntLiteral(1)))
    return simplify(ceiled, enabled_simplifications=Simplification.IntegerArithmetic)


def get_pyrange(loop_range: sym.LoopRange):
    """
    Returns a python range corresponding to a LoopRange of IntLiterals.

    Fortran ``DO`` bounds are inclusive, while Python's :class:`range` excludes
    its end. The exclusive end is therefore placed one step *past* ``stop`` in
    the direction of iteration: ``stop + 1`` for ascending loops and
    ``stop - 1`` for descending (negative-step) loops. For example,
    ``DO i=10,1,-1`` maps to ``range(10, 0, -1)``.

    Parameters
    ----------
    loop_range : :any:`LoopRange`
        Loop bounds whose start, stop and (optional) step can be evaluated
        to numbers by :any:`LokiEvaluationMapper`.

    Returns
    -------
    range
        The iteration indices visited by the equivalent Fortran loop.
    """
    from math import ceil  # pylint: disable=import-outside-toplevel

    evaluate = LokiEvaluationMapper()
    start = evaluate(loop_range.start)
    stop = evaluate(loop_range.stop)
    if loop_range.step is None:
        return range(start, floor(stop) + 1)

    step = evaluate(loop_range.step)
    if step < 0:
        # Counting down: the last visited index is ceil(stop), so end one below it
        return range(start, ceil(stop) - 1, step)
    return range(start, floor(stop) + 1, step)



def iteration_number(iter_idx, loop_range: sym.LoopRange) -> pmbl.Expression:
    """
    Returns the normalized iteration number of the iteration variable

    Given the loop iteration index for an iteration in a loop defined by the
    :any:`LoopRange` this method returns the normalized (1-based) iteration
    number given by
    iter_num = (iter_idx - start + step)/step = (iter_idx-start)/step + 1

    If the loop range has no step, the division is omitted (unit stride).

    Parameters
    ----------
    iter_idx : :any:`Variable`, :any:`Expression`, or :any:`IntLiteral`
        corresponding to a valid iteration index for the parameter `loop_range`
    loop_range: :any:`LoopRange`
        The loop bounds the iteration index belongs to

    Returns
    -------
    :any:`Expression`
        The iteration number, simplified with integer arithmetic
    """
    offset = sym.Sum((iter_idx, -loop_range.start))
    if loop_range.step is not None:
        offset = sym.Quotient(offset, loop_range.step)
    return simplify(
        sym.Sum((offset, sym.IntLiteral(1))),
        enabled_simplifications=Simplification.IntegerArithmetic
    )


def iteration_index(iter_num, loop_range: sym.LoopRange) -> pmbl.Expression:
    """
    Returns the iteration index of the loop based on the iteration number

    Given the normalized iteration number for an iteration in a loop defined by the
    :any:`LoopRange` this method returns the iteration index given by
    iter_idx = (iter_num-1)*step+start

    This is the inverse of :func:`iteration_number`. If the loop range has no
    step, the multiplication is omitted (unit stride).

    Parameters
    ----------
    iter_num : :any:`Variable`, :any:`Expression`, or :any:`sym.IntLiteral`
        corresponding to a valid iteration number for the parameter `loop_range`
    loop_range: :any:`LoopRange`
        The loop bounds the iteration number refers to

    Returns
    -------
    :any:`Expression`
        The iteration index, simplified with integer arithmetic
    """
    step = loop_range.step
    if step is None:
        # Unit stride: iter_num - 1 + start
        index = sym.Sum((iter_num, sym.IntLiteral(-1), loop_range.start))
    else:
        scaled = sym.Product((sym.Sum((iter_num, sym.IntLiteral(-1))), step))
        index = sym.Sum((scaled, loop_range.start))
    return simplify(index, enabled_simplifications=Simplification.IntegerArithmetic)
loki-ecmwf-0.3.6/loki/expression/symbols.py0000664000175000017500000011753215167130205021151 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# pylint: disable=too-many-lines

"""
Expression tree node classes for
:ref:`internal_representation:Expression tree`.
"""

from itertools import chain
import weakref
from sys import intern
import pymbolic.primitives as pmbl

from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import BasicType, DerivedType, ProcedureType, SymbolAttributes, Scope
from loki.config import config

from loki.expression.literals import (
    _Literal, FloatLiteral, IntLiteral, LogicLiteral, StringLiteral,
    IntrinsicLiteral, Literal, LiteralList,
)
from loki.expression.mixins import StrCompareMixin
from loki.expression.operations import (
    Sum, Product, Quotient, Power, Comparison, LogicalAnd,
    LogicalOr, LogicalNot, StringConcat, Cast, Reference, Dereference
)


# Public API of this module. Literal and operation node classes are imported
# from ``loki.expression.literals`` and ``loki.expression.operations`` above
# and re-exported here so that all expression nodes can be imported from one place.
__all__ = [
    # Typed leaf nodes
    'TypedSymbol', 'DeferredTypeSymbol', 'VariableSymbol', 'ProcedureSymbol', 'DerivedTypeSymbol',
    'MetaSymbol', 'Scalar', 'Array', 'Variable',
    # Internal nodes
    'InlineCall', 'InlineDo', 'Range', 'LoopRange', 'RangeIndex',
    'ArraySubscript', 'StringSubscript',
    # Literals (imported but exposed here for convenience)
    '_Literal', 'FloatLiteral', 'IntLiteral', 'LogicLiteral',
    'StringLiteral', 'IntrinsicLiteral', 'Literal', 'LiteralList',
    # Operations (imported but exposed here for convenience)
    'Sum', 'Product', 'Quotient', 'Power', 'Comparison', 'LogicalAnd', 'LogicalOr',
    'LogicalNot', 'StringConcat', 'Cast', 'Reference', 'Dereference',
]


class TypedSymbol:
    """
    Base class for symbols that carry type information.

    :class:`TypedSymbol` can be associated with a specific :any:`Scope` in
    which it is declared. In that case, all type information is cached in that
    scope's :any:`SymbolTable`. Creating :class:`TypedSymbol` without attaching
    it to a scope stores the type information locally.

    .. note::
        Providing :attr:`scope` and :attr:`type` overwrites the corresponding
        entry in the scope's symbol table. To not modify the type information
        omit :attr:`type` or use ``type=None``.

    Objects should always be created via the factory class :any:`Variable`.

    This class is used as a mixin together with a pymbolic node class in the
    concrete symbol classes below. Constructor arguments it does not consume
    are forwarded to the next class in the MRO.

    Parameters
    ----------
    name : str
        The identifier of that symbol (e.g., variable name). Any ``parent%``
        qualifier is stripped; the fully qualified name is rebuilt from
        :attr:`parent`.
    scope : :any:`Scope`
        The scope in which that symbol is declared. Only a weak reference to
        it is stored.
    type : optional
        The type of that symbol. If it is omitted, the type already recorded
        for this name in :attr:`scope` is used. If there is none, a
        :any:`BasicType.DEFERRED` entry is stored in the scope. Without a
        scope, the type stays ``None``.
    parent : :any:`Scalar` or :any:`Array`, optional
        The derived type variable this variable belongs to.
    case_sensitive : bool, optional
        Mark the name of this symbol as case-sensitive (default: the
        ``case-sensitive`` entry of the Loki ``config``)
    *args : optional
        Any other positional arguments for other parent classes
    **kwargs : optional
        Any other keyword arguments for other parent classes
    """

    # Order must match the tuple returned by __getinitargs__
    init_arg_names = ('name', 'scope', 'parent', 'type', 'case_sensitive', )

    def __init__(self, *args, **kwargs):
        # 'name' is read but deliberately not popped, so it is also handed
        # on to the next constructor in the MRO via super().__init__ below
        self.name = kwargs['name']
        # Parent and scope must be set before the type: the type getter and
        # setter resolve the (parent-qualified) name in the scope
        self.parent = kwargs.pop('parent', None)
        self.scope = kwargs.pop('scope', None)
        self.case_sensitive = kwargs.pop('case_sensitive', config['case-sensitive'])

        # Use provided type or try to determine from scope
        # (_type must exist before the getter is consulted for unscoped symbols)
        self._type = None
        self.type = kwargs.pop('type', None) or self.type

        super().__init__(*args, **kwargs)

    @property
    def name(self):
        """
        The fully qualified name, e.g. ``a%b`` for a member ``b`` of parent ``a``
        """
        if self.parent:
            return f'{self.parent.name}%{self._name}'
        return self._name

    @name.setter
    def name(self, name):
        # Only the basename is stored; the qualifier is derived from the parent
        self._name = name.split('%')[-1]

    def __getinitargs__(self):
        """
        Fixed tuple of initialisation arguments, corresponding to
        ``init_arg_names`` above.

        Note that this defines the pickling behaviour of pymbolic
        symbol objects. We do not recurse here, since we own the
        "name" attribute, which pymbolic will otherwise replicate.

        The scope is always given as ``None`` (only a weak reference to it is
        held), and the type is the locally stored one, which is ``None`` for
        symbols whose type lives in a scope's symbol table.
        """
        return (self.name, None, self._parent, self._type, self.case_sensitive, )

    @property
    def scope(self):
        """
        The object corresponding to the symbol's scope.

        This is ``None`` if no scope was given or if the scope object no
        longer exists (it is only weakly referenced).
        """
        if self._scope is None:
            return None
        return self._scope()

    @scope.setter
    def scope(self, scope):
        assert scope is None or isinstance(scope, Scope)
        # A weak reference does not keep the scope object alive
        self._scope = None if scope is None else weakref.ref(scope)

    def _lookup_type(self, scope):
        """
        Helper method to look-up type information in any :data:`scope`

        Note that this is useful when trying to discover type information
        without putting the variable in :data:`scope` first. Combined with
        the recursive lookup of type information via the parent, this allows
        e.g. to distinguish between procedure calls and array subscripts for
        ambiguous derived type components.

        Returns
        -------
        :any:`SymbolAttributes` or `None`
            A non-deferred entry from :data:`scope` if available, otherwise
            the type declared in the parent's derived type definition, and
            otherwise whatever :data:`scope` holds (possibly deferred or `None`).
        """
        _type = scope.symbol_attrs.lookup(self.name)
        if _type and _type.dtype is not BasicType.DEFERRED:
            # We have a clean entry in the symbol table which is not deferred
            return _type

        # Try a look-up via parent
        if self.parent:
            tdef_var = self.parent.variable_map.get(self.basename)
            if not tdef_var and self.parent.scope is not scope:
                # If the parent isn't delivering straight away (may happen e.g. for nested derived types)
                # we'll try discovering its parent's type via the provided scope
                parent = self._lookup_parent(scope)
                if parent:
                    tdef_var = parent.variable_map.get(self.basename)
            if tdef_var:
                return tdef_var.type

        return _type

    def _lookup_parent(self, scope):
        """
        Helper method to look-up parent variable using provided :data:`scope`

        Returns `None` if any link in the chain of nested derived type
        members cannot be resolved.
        """
        # Start at the root, i.e. the declared derived type object
        parent_name = self.name_parts[0]
        parent_type = scope.symbol_attrs.lookup(parent_name)
        # NOTE(review): if the root is not in `scope`, parent_type is None and,
        # via the type setter, this constructor call appears to store a
        # DEFERRED entry for it in `scope` -- confirm this side effect is intended
        parent_var = Variable(name=parent_name, scope=scope, type=parent_type)
        # Walk through nested derived types (if any)...
        for name in self.name_parts[1:-1]:
            if not parent_var:
                # If the look-up fails somewhere we have to bail out
                return None
            parent_var = parent_var.variable_map.get(name)  # pylint: disable=no-member
        # ...until we are at the actual parent
        return parent_var

    @property
    def type(self):
        """
        Internal representation of the declared data type.

        Read from the scope's symbol table (with fall-back via the parent's
        type definition) when attached to a live scope, otherwise the locally
        stored type.
        """
        if self.scope is None:
            return self._type
        return self._lookup_type(self.scope)

    @type.setter
    def type(self, _type):
        """
        Update the stored type information
        """
        # NOTE(review): this tests the weakref itself, whereas the getter tests
        # the dereferenced scope; if the scope has been garbage-collected,
        # self.scope below is None and the assignment fails -- confirm
        if self._scope is None:
            # Store locally if not attached to a scope
            self._type = _type
        elif _type is None:
            # Store deferred type if unknown
            self.scope.symbol_attrs[self.name] = SymbolAttributes(BasicType.DEFERRED)
        elif _type is not self.scope.symbol_attrs.lookup(self.name):
            # Update type if it is not the very object already stored (identity check)
            self.scope.symbol_attrs[self.name] = _type

    @property
    def parent(self):
        """
        Parent variable for derived type members

        Returns
        -------
        :any:`TypedSymbol` or :any:`MetaSymbol` or `NoneType`
            The parent variable or None
        """
        return self._parent

    @parent.setter
    def parent(self, parent):
        assert parent is None or isinstance(parent, (TypedSymbol, MetaSymbol,
            Reference, Dereference))
        self._parent = parent

    @property
    def parents(self):
        """
        Variables nodes for all parents

        Returns
        -------
        tuple
            The list of parent variables, e.g., for a variable ``a%b%c%d`` this
            yields the nodes corresponding to ``(a, a%b, a%b%c)``
        """
        parent = self.parent
        if parent:
            return parent.parents + (parent,)
        return ()

    @property
    def variables(self):
        """
        List of member variables in a derived type

        The members of the type definition are re-created as qualified
        members (``self%member``) with this symbol as parent. When attached to
        a scope, this passes each member's declared type to that scope, which
        registers it under the qualified name.

        Returns
        -------
        tuple of :any:`TypedSymbol` or :any:`MetaSymbol` if derived type variable, else `None`
            List of member variables in a derived type (empty if the type
            definition is not available)
        """
        _type = self.type
        if _type and isinstance(_type.dtype, DerivedType):
            if _type.dtype.typedef is BasicType.DEFERRED:
                # Type definition unknown: no members can be listed
                return ()
            return tuple(
                v.clone(name=f'{self.name}%{v.name}', scope=self.scope, type=v.type, parent=self)
                for v in _type.dtype.typedef.variables
            )
        return None

    @property
    def variable_map(self):
        """
        Member variables in a derived type variable as a map

        Returns
        -------
        dict of (str, :any:`TypedSymbol` or :any:`MetaSymbol`)
            Map of member variable basenames to variable objects
        """
        return CaseInsensitiveDict((v.basename, v) for v in self.variables or ())

    @property
    def basename(self):
        """
        The symbol name without the qualifier from the parent.
        """
        return self._name

    @property
    def name_parts(self):
        """
        All name parts with parent qualifiers separated, e.g. ``['a', 'b', 'c']``
        for ``a%b%c``
        """
        if self.parent:
            return self.parent.name_parts + [self.basename]
        return [self.basename]

    def clone(self, **kwargs):
        """
        Replicate the object with the provided overrides.

        Attributes not given in :data:`kwargs` are copied from this symbol.
        If no type is given, an existing entry for the name in the target
        scope takes precedence over this symbol's current type.
        """
        # Add existing meta-info to the clone arguments, only if we have them.
        if 'name' not in kwargs and self.name:
            kwargs['name'] = self.name
        if 'scope' not in kwargs and self.scope:
            kwargs['scope'] = self.scope
        if 'type' not in kwargs:
            # If no type is given, check new scope
            if 'scope' in kwargs and kwargs['scope'] and kwargs['name'] in kwargs['scope'].symbol_attrs:
                kwargs['type'] = kwargs['scope'].symbol_attrs[kwargs['name']]
            else:
                kwargs['type'] = self.type
        if 'parent' not in kwargs and self.parent:
            kwargs['parent'] = self.parent
        if 'case_sensitive' not in kwargs and self.case_sensitive:
            kwargs['case_sensitive'] = self.case_sensitive

        return Variable(**kwargs)

    def rescope(self, scope):
        """
        Replicate the object with a new scope

        This is a bespoke variant of :meth:`clone` for rescoping
        symbols. The difference lies in the handling of the
        type information, making sure not to overwrite any existing
        symbol table entry in the provided scope.
        """
        if self.type:
            # Prefer type information already known in the target scope
            # (including look-up via the parent's type definition)
            existing_type = self._lookup_type(scope)
            if existing_type:
                return self.clone(scope=scope, type=existing_type)
        return self.clone(scope=scope)

    def get_derived_type_member(self, name_str):
        """
        Resolve type-bound variables of arbitrary nested depth.

        Parameters
        ----------
        name_str : str
            Member path relative to this symbol, e.g. ``'b%c'`` to obtain
            ``self%b%c``.

        Notes
        -----
        If the type definition is available, the first member name must
        exist in it (checked via ``assert``/dictionary look-up).
        NOTE(review): for a non-derived, non-deferred dtype the ``.typedef``
        access presumably raises AttributeError -- confirm callers only use
        this on derived type variables.
        """
        name_parts = name_str.split('%', maxsplit=1)
        if self.type.dtype is not BasicType.DEFERRED and self.type.dtype.typedef is not BasicType.DEFERRED:
            assert self.type.dtype.typedef.variable_map[name_parts[0]]
        declared_var = Variable(name=f'{self.name}%{name_parts[0]}', scope=self.scope, parent=self)
        if len(name_parts) > 1:
            # Recurse into the remaining path relative to the newly created member
            return declared_var.get_derived_type_member(name_parts[1])  # pylint:disable=no-member
        return declared_var


class DeferredTypeSymbol(StrCompareMixin, TypedSymbol, pmbl.Variable):  # pylint: disable=too-many-ancestors
    """
    Internal representation of symbols with deferred type

    This is used, for example, in the symbol list of :any:`Import` if a symbol's
    definition is not available.

    Note that symbols with deferred type are assumed to be variables, which
    implies they are included in the result from visitors such as
    :any:`FindVariables`.

    Parameters
    ----------
    name : str
        The name of the symbol
    scope : :any:`Scope`
        The scope in which the symbol is declared
    type : :any:`SymbolAttributes`, optional
        Must have ``dtype`` :any:`BasicType.DEFERRED`; a fresh deferred
        type is used when omitted or `None`.
    """

    mapper_method = intern('map_deferred_type_symbol')

    def __init__(self, name, scope=None, **kwargs):
        declared_type = kwargs.get('type')
        if declared_type is None:
            # No type given: fall back to a fresh DEFERRED placeholder
            declared_type = SymbolAttributes(BasicType.DEFERRED)
            kwargs['type'] = declared_type
        assert declared_type.dtype is BasicType.DEFERRED
        super().__init__(name=name, scope=scope, **kwargs)


class VariableSymbol(StrCompareMixin, TypedSymbol, pmbl.Variable):  # pylint: disable=too-many-ancestors
    """
    Expression node to represent a variable symbol

    Note that this node should not be used directly to represent variables
    but instead meta nodes :any:`Scalar` or :any:`Array` (via their factory
    :any:`Variable`) should be used.

    The purpose of this is to align Loki's "convenience layer" for expressions
    with Pymbolic's expression tree structure. Loki makes variable use
    (especially for arrays) with or without properties (such as subscript
    dimensions) directly accessible from a single object, whereas Pymbolic
    represents array subscripts as an operation applied to a variable.

    Furthermore, it adds type information via :any:`TypedSymbol`.

    Parameters
    ----------
    name : str
        The name of the variable.
    scope : :any:`Scope`, optional
        The scope in which the variable is declared.
    type : :any:`SymbolAttributes`, optional
        The type of that symbol. Defaults to :any:`SymbolAttributes` with
        :any:`BasicType.DEFERRED`.
    """

    mapper_method = intern('map_variable_symbol')

    @property
    def initial(self):
        """
        Initial value of the variable in declaration, stored on :attr:`type`.
        """
        attrs = self.type
        return attrs.initial

    @initial.setter
    def initial(self, value):
        # Modifies the type object in place (shared with the scope's table, if any)
        attrs = self.type
        attrs.initial = value


class _FunctionSymbol(pmbl.FunctionSymbol):
    """
    Adapter class for :any:`pymbolic.primitives.FunctionSymbol` that intercepts
    constructor arguments

    This is needed since the original symbol does not like having a :data:`name`
    parameter handed down in the constructor.
    """

    def __init__(self, *_args, **_kwargs):
        # All arguments are intentionally discarded; the pymbolic base is
        # always initialised without any
        super().__init__()


class ProcedureSymbol(StrCompareMixin, TypedSymbol, _FunctionSymbol):  # pylint: disable=too-many-ancestors
    """
    Internal representation of a symbol that represents a callable
    subroutine or function

    Parameters
    ----------
    name : str
        The name of the symbol.
    scope : :any:`Scope`
        The scope in which the symbol is declared.
    type : optional
        The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
        If given, its ``dtype`` must be a :any:`ProcedureType`, or a
        :any:`DerivedType` whose name equals :data:`name` (presumably a
        structure constructor call).
    """

    mapper_method = intern('map_procedure_symbol')

    def __init__(self, name, scope=None, type=None, **kwargs):
        # pylint: disable=redefined-builtin
        assert (
            type is None
            or isinstance(type.dtype, ProcedureType)
            or (isinstance(type.dtype, DerivedType) and type.dtype.name.lower() == name.lower())
        )
        super().__init__(name=name, scope=scope, type=type, **kwargs)


class DerivedTypeSymbol(StrCompareMixin, TypedSymbol, _FunctionSymbol):
    """
    Internal representation of a symbol that represents a named
    derived type.

    This is used to represent the derived type symbolically in
    :any:`Import` statements and when defining derived types.

    Parameters
    ----------
    name : str
        The name of the symbol.
    scope : :any:`Scope`
        The scope in which the symbol is declared.
    type : optional
        The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
        If given, it must be a :any:`DerivedType` of the same
        (case-insensitive) name.
    """

    mapper_method = intern('map_derived_type_symbol')

    def __init__(self, name, scope=None, type=None, **kwargs):
        # pylint: disable=redefined-builtin
        if type is not None:
            assert isinstance(type.dtype, DerivedType)
            assert name.lower() == type.dtype.name.lower()
        super().__init__(name=name, scope=scope, type=type, **kwargs)


def _forward_to_symbol(attribute, doc=None):
    """
    Create a read-only property that returns ``self.symbol.<attribute>``

    The target is resolved through the ``symbol`` property on every access,
    so subclasses that override ``symbol`` (e.g. :any:`Array`) are honoured.
    """
    def _get(self):
        return getattr(self.symbol, attribute)
    return property(_get, doc=doc)


class MetaSymbol(StrCompareMixin, pmbl.AlgebraicLeaf):
    """
    Base class for meta symbols to encapsulate a symbol node with optional
    enclosing operations in a unifying interface

    The motivation for this class is that Loki strives to make variables
    and their use accessible via uniform interfaces :any:`Scalar` or
    :any:`Array`. Pymbolic's representation of array subscripts or access
    to members of a derived type are represented as operations on a symbol,
    thus resulting in a inside-out view that has the symbol innermost.

    To make it more convenient to find symbols and apply transformations on
    them, Loki wraps these compositions of expression tree nodes into meta
    nodes that store these compositions and provide direct access to properties
    of the contained nodes from a single object.

    In the simplest case, an instance of a :any:`TypedSymbol` subclass is
    stored as :attr:`symbol` and accessible via this property. Typical
    properties of this symbol (such as :attr:`name`, :attr:`type`, etc.) are
    directly accessible as properties that are redirected to the actual symbol.

    For arrays, not just the :any:`TypedSymbol` subclass but also an enclosing
    :any:`ArraySubscript` may be stored inside the meta symbol, providing
    additionally access to the subscript dimensions. The properties are then
    dynamically redirected to the relevant expression tree node.
    """

    mapper_method = intern('map_meta_symbol')

    def __init__(self, symbol, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._symbol = symbol

    # Pickling: the wrapped node constitutes the complete state
    def __getstate__(self):
        return self._symbol

    def __setstate__(self, state):
        self._symbol = state

    @property
    def symbol(self):
        """
        The underlying :any:`TypedSymbol` node encapsulated by this meta node
        """
        return self._symbol

    # Plain read-only attributes redirected to the wrapped symbol
    name = _forward_to_symbol(
        'name',
        'The fully qualifying symbol name\n\n'
        'For derived type members this yields parent and basename'
    )
    basename = _forward_to_symbol(
        'basename',
        'For derived type members this yields the declared member name without '
        'the parent\'s name'
    )
    name_parts = _forward_to_symbol(
        'name_parts',
        'All name parts with parent qualifiers separated'
    )
    parent = _forward_to_symbol(
        'parent',
        'For derived type members this yields the declared parent symbol to '
        'which it belongs'
    )
    parents = _forward_to_symbol(
        'parents',
        'Yield all parent symbols for derived type members'
    )
    scope = _forward_to_symbol(
        'scope',
        'The scope in which the symbol was declared\n\n'
        'Note: for imported symbols this refers to the scope into which it is '
        'imported, _not_ where it was declared.'
    )
    variables = _forward_to_symbol(
        'variables',
        'List of member variables in a derived type\n\n'
        'Returns\n'
        '-------\n'
        'tuple of :any:`TypedSymbol` or :any:`MetaSymbol` if derived type variable, else `None`\n'
        '    List of member variables in a derived type'
    )
    variable_map = _forward_to_symbol(
        'variable_map',
        'Member variables in a derived type variable as a map\n\n'
        'Returns\n'
        '-------\n'
        'dict of (str, :any:`TypedSymbol` or :any:`MetaSymbol`)\n'
        '    Map of member variable basenames to variable objects'
    )
    init_arg_names = _forward_to_symbol('init_arg_names')
    case_sensitive = _forward_to_symbol(
        'case_sensitive',
        'Property to indicate that the name of this symbol is case-sensitive.'
    )

    @property
    def type(self):
        """
        The :any:`SymbolAttributes` declared for this symbol

        This includes data type as well as additional properties, such as
        ``INTENT``, ``KIND`` etc.
        """
        return self.symbol.type

    @type.setter
    def type(self, _type):
        """
        Update the :any:`SymbolAttributes` declared for this symbol
        """
        self.symbol.type = _type

    @property
    def initial(self):
        """
        Initial value of the variable in a declaration, if given
        """
        return self.type.initial

    def __getinitargs__(self):
        target = self.symbol
        return target.__getinitargs__()

    def _lookup_type(self, scope):
        """
        Helper method to look-up type information in any :data:`scope`
        """
        target = self.symbol
        return target._lookup_type(scope)  # pylint: disable=protected-access

    def clone(self, **kwargs):
        """
        Replicate the object with the provided overrides.
        """
        target = self.symbol
        return target.clone(**kwargs)

    def rescope(self, scope):
        """
        Replicate the object with a new scope

        This is a bespoke variant of :meth:`clone` for rescoping
        symbols. The difference lies in the handling of the
        type information, making sure not to overwrite any existing
        symbol table entry in the provided scope.
        """
        target = self.symbol
        return target.rescope(scope)

    def get_derived_type_member(self, name_str):
        """
        Resolve type-bound variables of arbitrary nested depth.
        """
        target = self.symbol
        return target.get_derived_type_member(name_str)


class Scalar(MetaSymbol):  # pylint: disable=too-many-ancestors
    """
    Expression node for scalar variables.

    See :any:`MetaSymbol` for a description of meta symbols.

    Parameters
    ----------
    name : str
        The name of the variable.
    scope : :any:`Scope`
        The scope in which the variable is declared.
    type : optional
        The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
    """

    mapper_method = intern('map_scalar')

    def __init__(self, name, scope=None, type=None, **kwargs):
        # Stop complaints about `type` in this function
        # pylint: disable=redefined-builtin
        # A scalar is a meta node wrapping a bare VariableSymbol
        super().__init__(symbol=VariableSymbol(name=name, scope=scope, type=type, **kwargs))


class Array(MetaSymbol):
    """
    Expression node for array variables.

    Similar to :any:`Scalar` with the notable difference that it has
    a shape (stored in :data:`type`) and can have associated
    :data:`dimensions` (i.e., the array subscript for indexing/slicing
    when accessing entries).

    See :any:`MetaSymbol` for a description of meta symbols.

    Parameters
    ----------
    name : str
        The name of the variable.
    scope : :any:`Scope`
        The scope in which the variable is declared.
    type : optional
        The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
    dimensions : :any:`ArraySubscript`, optional
        The array subscript expression.
    """

    mapper_method = intern('map_array')

    def __init__(self, name, scope=None, type=None, dimensions=None, **kwargs):
        # Stop complaints about `type` in this function
        # pylint: disable=redefined-builtin
        variable = VariableSymbol(name=name, scope=scope, type=type, **kwargs)
        # Only wrap in a subscript node when indices/slices are actually given
        node = ArraySubscript(variable, dimensions) if dimensions else variable
        super().__init__(symbol=node)

    @property
    def name_parts(self):
        return self.symbol.name_parts

    @property
    def symbol(self):
        node = self._symbol
        return node.aggregate if isinstance(node, ArraySubscript) else node

    @property
    def dimensions(self):
        """
        Symbolic representation of the dimensions or indices.
        """
        node = self._symbol
        return node.index_tuple if isinstance(node, ArraySubscript) else ()

    @property
    def shape(self):
        """
        Original allocated shape of the variable as a tuple of dimensions.
        """
        return self.type.shape

    def __getinitargs__(self):
        return (*super().__getinitargs__(), self.dimensions)

    @property
    def init_arg_names(self):
        return (*super().init_arg_names, 'dimensions')

    def clone(self, **kwargs):
        """
        Replicate the :class:`Array` variable with the provided overrides.

        Note, if :data:`dimensions` is set to ``None`` and :data:`type` updated
        to have no shape, this will create a :any:`Scalar` variable.
        """
        # Carry over existing subscripts unless explicitly overridden
        if self.dimensions:
            kwargs.setdefault('dimensions', self.dimensions)
        return super().clone(**kwargs)

    def rescope(self, scope):
        """
        Replicate the object with a new scope

        This is a bespoke variant of :meth:`clone` for rescoping
        symbols. The difference lies in the handling of the
        type information, making sure not to overwrite any existing
        symbol table entry in the provided scope.
        """
        # NOTE(review): unlike TypedSymbol.rescope this uses a plain symbol
        # table look-up rather than the parent-aware _lookup_type -- confirm
        # whether derived type members should get the same treatment here
        existing_type = scope.symbol_attrs.lookup(self.name) if self.type else None
        if existing_type:
            return self.clone(scope=scope, type=existing_type, dimensions=self.dimensions)
        return self.clone(scope=scope, dimensions=self.dimensions)


class Variable:
    """
    Factory class for :any:`TypedSymbol` or :any:`MetaSymbol` classes

    This is a convenience constructor to provide a uniform interface for
    instantiating different symbol types. It checks the symbol's type
    (either the provided :data:`type` or via a lookup in :data:`scope`)
    and :data:`dimensions` and dispatches the relevant class constructor.

    The tier algorithm is as follows:

    1. `type.dtype` is :any:`ProcedureType`: Instantiate a
       :any:`ProcedureSymbol`;
    2. `type.dtype` is a :any:`DerivedType` whose name equals the symbol's
       name (compared case-insensitively), i.e., the symbol names the derived
       type itself (as found in USE import statements): Instantiate a
       :any:`DerivedTypeSymbol`;
    3. :data:`dimensions` is not `None` or `type.shape` is set (truthy):
       Instantiate an :any:`Array`;
    4. `type.dtype` is not :any:`BasicType.DEFERRED`:
       Instantiate a :any:`Scalar`;
    5. None of the above, including the case that no type information is
       available at all: Instantiate a :any:`DeferredTypeSymbol`

    All objects created by this factory implement :class:`TypedSymbol`. A
    :class:`TypedSymbol` object can be associated with a specific :any:`Scope` in
    which it is declared. In that case, all type information is cached in that
    scope's :any:`SymbolTable`. Creating :class:`TypedSymbol` without attaching
    it to a scope stores the type information locally.

    .. note::

        Providing :attr:`scope` and :attr:`type` overwrites the corresponding
        entry in the scope's symbol table. To not modify the type information
        omit :attr:`type` or use ``type=None``.

    Note that all :class:`TypedSymbol` and :class:`MetaSymbol` classes are
    intentionally quasi-immutable:
    Changing any of their attributes, including attaching them to a scope and
    modifying their type, should always be done via the :meth:`clone` method:

    .. code-block::

        var = Variable(name='foo')
        var = var.clone(scope=scope, type=SymbolAttributes(BasicType.INTEGER))
        var = var.clone(type=var.type.clone(dtype=BasicType.REAL))

    Attaching a symbol to a new scope without updating any stored type information
    (but still inserting type information if it doesn't exist, yet), can be done
    via the dedicated :meth:`rescope` method. This is essentially a :meth:`clone`
    invocation but without the type update:

    .. code-block::

        var = Variable(name='foo', type=SymbolAttributes(BasicType.INTEGER), scope=scope)
        unscoped_var = Variable(name='foo', type=SymbolAttributes(BasicType.REAL))
        scoped_var = unscoped_var.rescope(scope)  # scoped_var will have INTEGER type

    Parameters
    ----------
    name : str
        The name of the variable.
    scope : :any:`Scope`
        The scope in which the variable is declared.
    type : optional
        The type of that symbol. If it is not provided (or `None`) and
        :data:`scope` is given, the type is looked up in the scope via
        :meth:`_get_type_from_scope`. Without any type information a
        :any:`DeferredTypeSymbol` is created.
    parent : :any:`Scalar` or :any:`Array`, optional
        The derived type variable this variable belongs to.
    dimensions : :any:`ArraySubscript`, optional
        The array subscript expression. Passing ``dimensions=None`` is
        treated the same as not passing it.
    """

    def __new__(cls, **kwargs):
        # Dispatch to the matching symbol class; the tiers are checked in the
        # order documented in the class docstring and the first match wins.
        name = kwargs['name']
        scope = kwargs.get('scope')
        _type = kwargs.get('type')

        if scope is not None and _type is None:
            # Determine type information from scope if not provided explicitly
            _type = cls._get_type_from_scope(name, scope, kwargs.get('parent'))
        # The (possibly looked-up) type is always forwarded to the constructor
        kwargs['type'] = _type

        if _type and isinstance(_type.dtype, ProcedureType):
            # This is the name in a function/subroutine call
            return ProcedureSymbol(**kwargs)

        if _type and isinstance(_type.dtype, DerivedType) and name.lower() == _type.dtype.name.lower():
            # This is the name of a derived type, as found in USE import statements
            return DerivedTypeSymbol(**kwargs)

        if 'dimensions' in kwargs and kwargs['dimensions'] is None:
            # Convenience: This way we can construct Scalar variables with `dimensions=None`
            kwargs.pop('dimensions')

        # Explicit dimensions or a (truthy) declared shape make it an array
        if kwargs.get('dimensions') is not None or (_type and _type.shape):
            return Array(**kwargs)
        if _type and _type.dtype is not BasicType.DEFERRED:
            return Scalar(**kwargs)
        return DeferredTypeSymbol(**kwargs)

    @classmethod
    def _get_type_from_scope(cls, name, scope, parent=None):
        """
        Helper method to determine the type of a symbol

        If no entry is found in the scope's symbol table, a lookup via
        the parent is attempted to construct the type for derived type
        members.

        Parameters
        ----------
        name : str
            The symbol's name
        scope : :any:`Scope`
            The scope in which to search for the symbol's type
        parent : :any:`MetaSymbol` or :any:`TypedSymbol`, optional
            The symbol's parent (for derived type members)

        Returns
        -------
        :any:`SymbolAttributes` or `None`
            `None` if the parent chain of a derived type member could not be
            resolved; otherwise the type found via the parent's
            ``variable_map`` or, as a fallback, the scope's entry (which may
            itself be `None`).
        """
        # 1. Try to find symbol in scope
        stored_type = scope.symbol_attrs.lookup(name)

        # 2. For derived type members, we can try to find it via the parent instead
        if '%' in name and (not stored_type or stored_type.dtype is BasicType.DEFERRED):
            name_parts = name.split('%')
            if not parent:
                # Build the parent if not given
                parent_type = scope.symbol_attrs.lookup(name_parts[0])
                parent = Variable(name=name_parts[0], scope=scope, type=parent_type)
                # Walk the intermediate components (e.g. ``b`` in ``a%b%c``)
                # to reach the direct parent; give up if a link is missing
                for pname in name_parts[1:-1]:
                    if not parent:
                        return None
                    parent = parent.variable_map.get(pname)  # pylint: disable=no-member
            if parent:
                # Lookup type in parent's typedef
                tdef_var = parent.variable_map.get(name_parts[-1])
                if tdef_var:
                    return tdef_var.type

        return stored_type


class InlineDo(StrCompareMixin, pmbl.AlgebraicLeaf):
    """
    An inlined do, e.g., an implied-do as used in array constructors
    such as ``(/ (a(i), i=1,n) /)``.

    Parameters
    ----------
    values :
        The expression(s) produced by the implied-do.
    variable :
        The control variable of the implied-do.
    bounds :
        The bounds of the control variable.
    """

    mapper_method = intern('map_inline_do')

    def __init__(self, values, variable, bounds, **kwargs):
        self.values, self.variable, self.bounds = values, variable, bounds
        super().__init__(**kwargs)

    def __getinitargs__(self):
        return self.values, self.variable, self.bounds


class InlineCall(StrCompareMixin, pmbl.CallWithKwargs):
    """
    Internal representation of an in-line function call.

    Parameters
    ----------
    function : :any:`ProcedureSymbol`, :any:`DerivedTypeSymbol`, :any:`DeferredTypeSymbol` or :any:`MetaSymbol`
        The symbol of the called function.
    parameters : tuple, optional
        The positional arguments of the call.
    kw_parameters : dict, optional
        The keyword arguments of the call.
    """

    init_arg_names = ('function', 'parameters', 'kw_parameters')

    def __getinitargs__(self):
        # NOTE(review): if ``as_tuple`` turns a dict into a tuple of its keys,
        # the keyword argument *values* are not part of the init args (and
        # hence not of the hash) - confirm this is intended.
        return (self.function, self.parameters, as_tuple(self.kw_parameters))


    def __init__(self, function, parameters=None, kw_parameters=None, **kwargs):
        # Unfortunately, have to accept MetaSymbol here for the time being as
        # rescoping before injecting statement functions may create InlineCalls
        # with Scalar/Variable function names.
        assert isinstance(function, (
            ProcedureSymbol, DerivedTypeSymbol, DeferredTypeSymbol, MetaSymbol
        ))
        parameters = parameters or ()
        kw_parameters = kw_parameters or {}

        super().__init__(function=function, parameters=parameters,
                         kw_parameters=kw_parameters, **kwargs)

    mapper_method = intern('map_inline_call')

    def __hash__(self):
        # A custom `__hash__` function to protect us from unhashable
        # dicts that `pmbl.CallWithKwargs` uses internally
        return hash(self.__getinitargs__())

    @property
    def name(self):
        """The name of the called function."""
        return self.function.name

    @property
    def procedure_type(self):
        """
        Returns the underpinning procedure type if the type is known,
        ``BasicType.DEFERRED`` otherwise.
        """
        return self.function.type.dtype

    @property
    def arguments(self):
        """
        Alias for :attr:`parameters`
        """
        return self.parameters

    @property
    def kwarguments(self):
        """
        Alias for :attr:`kw_parameters`, as a tuple of ``(name, value)`` pairs
        """
        return as_tuple(self.kw_parameters.items())

    @property
    def routine(self):
        """
        The :any:`Subroutine` object of the called routine

        Shorthand for ``call.function.type.dtype.procedure``

        Returns
        -------
        :any:`Subroutine` or :any:`BasicType.DEFERRED`
            If :attr:`procedure_type` is :any:`BasicType.DEFERRED`, that is
            returned. Otherwise this returns the ``procedure`` attribute of
            the :any:`ProcedureType`, i.e., the corresponding :any:`Subroutine`
            object if the type is linked up to the target routine.
        """
        procedure_type = self.procedure_type
        if procedure_type is BasicType.DEFERRED:
            return BasicType.DEFERRED
        return procedure_type.procedure

    def arg_iter(self):
        """
        Iterator that maps argument definitions in the target :any:`Subroutine`
        to arguments and keyword arguments in the call.

        Returns
        -------
        iterator
            An iterator that traverses the mapping ``(arg name, call arg)`` for
            all positional and then keyword arguments.

        Raises
        ------
        AssertionError
            If the target routine is not known (:attr:`routine` is deferred).
        """
        routine = self.routine
        assert routine is not BasicType.DEFERRED
        r_args = CaseInsensitiveDict((arg.name, arg) for arg in routine.arguments)
        args = zip(routine.arguments, self.arguments)
        kwargs = ((r_args[kw], arg) for kw, arg in as_tuple(self.kwarguments))
        return chain(args, kwargs)

    @property
    def arg_map(self):
        """
        A full map of all qualified argument matches from arguments
        and keyword arguments.

        Returns
        -------
        dict
            A dictionary mapping ``arg name: call arg`` for
            all positional and then keyword arguments.
        """
        return dict(self.arg_iter())

    def clone(self, **kwargs):
        """
        Replicate the object with the provided overrides.

        Only ``function``, ``parameters`` and ``kw_parameters`` can be overridden.
        """
        function = kwargs.get('function', self.function)
        parameters = kwargs.get('parameters', self.parameters)
        kw_parameters = kwargs.get('kw_parameters', self.kw_parameters)
        return InlineCall(function, parameters, kw_parameters)

    def _sort_kwarguments(self):
        """
        Helper routine to sort the kwarguments/kw_parameters according to the order of the
        arguments (``self.routine.arguments``).

        Keywords are matched case-insensitively; the returned pairs use the
        routine's argument names. Keywords that match no routine argument are
        not included.
        """
        routine = self.routine
        assert routine is not BasicType.DEFERRED
        kwargs = CaseInsensitiveDict(self.kwarguments)
        r_arg_names = [arg.name for arg in routine.arguments if arg.name in kwargs]
        new_kwarguments = tuple((arg_name, kwargs[arg_name]) for arg_name in r_arg_names)
        return new_kwarguments

    def is_kwargs_order_correct(self):
        """
        Check whether kwarguments/kw_parameters are correctly ordered
        with respect to the arguments (``self.routine.arguments``).

        Keyword names are compared case-insensitively, since Fortran
        identifiers are case-insensitive and :meth:`_sort_kwarguments`
        reports the routine's spelling of each name.
        """
        def _normalised(kwarguments):
            return tuple((kw.lower(), arg) for kw, arg in kwarguments)

        return _normalised(self.kwarguments) == _normalised(self._sort_kwarguments())

    def clone_with_sorted_kwargs(self):
        """
        Sort and update the kwarguments/kw_parameters according to the order of the
        arguments (``self.routine.arguments``) and return the
        converted clone/copy of the inline call.
        """
        new_kwarguments = self._sort_kwarguments()
        return self.clone(kw_parameters=new_kwarguments)

    def clone_with_kwargs_as_args(self):
        """
        Convert all kwarguments/kw_parameters to arguments and
        return the converted clone/copy of the inline call.
        """
        new_kwarguments = self._sort_kwarguments()
        new_args = tuple(arg[1] for arg in new_kwarguments)
        return self.clone(parameters=self.arguments + new_args, kw_parameters=())


class Range(StrCompareMixin, pmbl.Slice):
    """
    Internal representation of a loop or index range.

    Parameters
    ----------
    children : tuple or list
        The ``(start, stop)`` or ``(start, stop, step)`` components of the
        range. A missing step is stored as ``None``.
    """

    def __init__(self, children, **kwargs):
        # Work on a fresh tuple: this accepts lists as well and ensures that
        # ``+=`` below never extends a caller-provided list in place.
        children = tuple(children)
        assert len(children) in (2, 3)
        if len(children) == 2:
            children += (None,)
        super().__init__(children, **kwargs)

    mapper_method = intern('map_range')

    def __hash__(self):
        """ Need custom hashing function if we specialise :meth:`__eq__` """
        return hash(super().__str__().lower().replace(' ', ''))

    def __eq__(self, other):
        """ Specialization to capture ``a(1:n) == a(n)`` """
        # NOTE(review): a range ``1:n`` compares equal to ``n`` but hashes
        # differently, so the two may not be interchangeable as dict keys.
        if self.children[0] == 1 and self.children[2] is None:
            return self.children[1] == other or super().__eq__(other)
        return super().__eq__(other)

    @property
    def lower(self):
        """Alias for :attr:`start`."""
        return self.start

    @property
    def upper(self):
        """Alias for :attr:`stop`."""
        return self.stop


class RangeIndex(Range):
    """
    Internal representation of a subscript range.

    Hashing and the ``a(1:n) == a(n)`` equality specialisation are inherited
    unchanged from :class:`Range`; only the mapper dispatch differs.
    """

    mapper_method = intern('map_range_index')


class LoopRange(Range):
    """
    Internal representation of a loop range.
    """

    mapper_method = intern('map_loop_range')

    @property
    def num_iterations(self) -> pmbl.Expression:
        """
        Returns total number of iterations of a loop.

        Given a loop, this returns an expression that computes the total number of
        iterations of the loop, i.e.
        ``(start, stop, step) -> (stop - start) / step + 1``.

        Without an explicit step this is ``stop - start + 1``, which is
        simplified to just ``stop`` if ``start`` is the integer literal ``1``.

        NOTE(review): the division is a symbolic :any:`Quotient`, presumably
        evaluated as integer (truncating) division in generated code, and no
        clamping to zero is applied for empty ranges - confirm with callers.
        """
        start = self.start
        stop = self.stop
        step = self.step
        if step is None:
            return stop if isinstance(start, IntLiteral) and start.value == 1 else Sum(
                (stop, Product((-1, start)), IntLiteral(1)))
        return Sum((Quotient(Sum((stop, Product((-1, start)))), step), IntLiteral(1)))

    @property
    def normalized(self):
        """
        Returns the normalized :any:`LoopRange` of a given :any:`LoopRange`.

        Returns the normalized :any:`LoopRange` which corresponds to a loop with
        the same number of iterations but starts at 1 and has stride 1, i.e.
        ``(start, stop, step) -> (1, num_iterations)`` (the step is left
        unset, i.e., implicitly 1).
        """
        return LoopRange((1, self.num_iterations))


class ArraySubscript(StrCompareMixin, pmbl.Subscript):
    """
    Internal representation of an array subscript: an aggregate (the array)
    indexed by one or more index expressions.
    """

    mapper_method = intern('map_array_subscript')


class StringSubscript(StrCompareMixin, pmbl.Subscript):
    """
    Internal representation of a substring subscript operator,
    e.g. ``str(1:3)``.
    """

    mapper_method = intern('map_string_subscript')

    @property
    def symbol(self):
        """The subscripted string expression (the subscript's aggregate)."""
        base = self.aggregate
        return base
loki-ecmwf-0.3.6/loki/expression/operations.py0000664000175000017500000001726415167130205021645 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Sub-classes of Pymbolic's native operations that allow us to inject
Fortran-specific features, such as case-insensitive string comparison
and bracket-aware and sub-expression grouping. Here we also add
additional technical operations, such as cast and references.
"""

from sys import intern
import pymbolic.primitives as pmbl

from loki.tools import as_tuple

from loki.expression.literals import StringLiteral
from loki.expression.mixins import loki_make_stringifier, StrCompareMixin


# Public operation node classes of this module
__all__ = [
    'Sum',
    'Product',
    'Quotient',
    'Power',
    'Comparison',
    'LogicalAnd',
    'LogicalOr',
    'LogicalNot',
    'StringConcat',
    'Cast',
    'Reference',
    'Dereference',
]


class Sum(StrCompareMixin, pmbl.Sum):
    """Sum of expressions (``a + b + ...``)."""


class Product(StrCompareMixin, pmbl.Product):
    """Product of expressions (``a * b * ...``)."""


class Quotient(StrCompareMixin, pmbl.Quotient):
    """Quotient of a numerator and a denominator (``a / b``)."""


class Power(StrCompareMixin, pmbl.Power):
    """Exponentiation of a base by an exponent (``a ** b``)."""


class Comparison(StrCompareMixin, pmbl.Comparison):
    """Comparison of two expressions (e.g. ``a < b``, ``a == b``)."""


class LogicalAnd(StrCompareMixin, pmbl.LogicalAnd):
    """Conjunction ('and') in a logical expression."""


class LogicalOr(StrCompareMixin, pmbl.LogicalOr):
    """Disjunction ('or') in a logical expression."""


class LogicalNot(StrCompareMixin, pmbl.LogicalNot):
    """Negation ('not') in a logical expression."""


class ParenthesisedAdd(Sum):
    """
    A :class:`Sum` that is always rendered inside explicit parentheses,
    both when pretty-printing and when generating code.
    """

    make_stringifier = loki_make_stringifier
    mapper_method = intern("map_parenthesised_add")


class ParenthesisedMul(Product):
    """
    A :class:`Product` that is always rendered inside explicit parentheses,
    both when pretty-printing and when generating code.
    """

    make_stringifier = loki_make_stringifier
    mapper_method = intern("map_parenthesised_mul")


class ParenthesisedDiv(Quotient):
    """
    A :class:`Quotient` that is always rendered inside explicit parentheses,
    both when pretty-printing and when generating code.
    """

    make_stringifier = loki_make_stringifier
    mapper_method = intern("map_parenthesised_div")


class ParenthesisedPow(Power):
    """
    A :class:`Power` that is always rendered inside explicit parentheses,
    both when pretty-printing and when generating code.
    """

    make_stringifier = loki_make_stringifier
    mapper_method = intern("map_parenthesised_pow")


class StringConcat(pmbl._MultiChildExpression):
    """
    Implements string concatenation in a way similar to :class:`Sum`.

    Adding another :class:`StringConcat`, a :class:`StringLiteral` or a
    :class:`pymbolic.primitives.Variable` (from either side) yields a new
    :class:`StringConcat`. Adding a falsy operand returns this object as is.
    Any other operand is rejected with ``NotImplemented``.
    """

    mapper_method = intern("map_string_concat")

    def __add__(self, other):
        if isinstance(other, (StringConcat, StringLiteral, pmbl.Variable)):
            return StringConcat((self, other))
        return NotImplemented if other else self

    def __radd__(self, other):
        if isinstance(other, (StringConcat, StringLiteral, pmbl.Variable)):
            return StringConcat((other, self))
        return NotImplemented if other else self

    def __bool__(self):
        # A single-element concatenation is as truthy as its only element
        children = self.children
        return bool(children[0]) if len(children) == 1 else True

    __nonzero__ = __bool__


class Cast(StrCompareMixin, pmbl.Call):
    """
    Internal representation of a data type cast.

    The cast is modelled as a call of a function named :data:`name` with the
    cast expression(s) as parameters.

    Parameters
    ----------
    name : str
        Name of the cast function.
    expression :
        The expression(s) to cast; stored as a tuple in :attr:`parameters`.
    kind : :any:`pymbolic.primitives.Expression`, optional
        Kind specifier for the cast.
    """

    init_arg_names = ('name', 'expression', 'kind')
    mapper_method = intern('map_cast')

    def __init__(self, name, expression, kind=None, **kwargs):
        assert kind is None or isinstance(kind, pmbl.Expression)
        self.kind = kind
        function = pmbl.make_variable(name)
        super().__init__(function, as_tuple(expression), **kwargs)

    def __getinitargs__(self):
        return self.name, self.expression, self.kind

    @property
    def name(self):
        """Name of the cast function."""
        return self.function.name

    @property
    def expression(self):
        """The cast expression(s), i.e. the call's parameters."""
        return self.parameters


class Reference(StrCompareMixin, pmbl.Expression):
    """
    Internal representation of a Reference.

    .. warning:: Experimental! The compound ``Reference(Variable(...))``
        forwards ``name``, ``type``, ``scope`` and ``initial`` to the wrapped
        expression, so it behaves much like a symbol itself and is easier
        to process in mappers.

    **C/C++ only**, no corresponding concept in Fortran.
    Referencing means taking the address of an existing variable
    (e.g. to set a pointer variable).
    """

    init_arg_names = ('expression',)
    mapper_method = intern('map_c_reference')

    def __init__(self, expression):
        assert isinstance(expression, pmbl.Expression)
        self.expression = expression

    def __getinitargs__(self):
        return (self.expression, )

    @property
    def name(self):
        """Name of the referenced expression (symbol-like access for mappers)."""
        return self.expression.name

    @property
    def type(self):
        """Type of the referenced expression (symbol-like access for mappers)."""
        return self.expression.type

    @property
    def scope(self):
        """Scope of the referenced expression (symbol-like access for mappers)."""
        return self.expression.scope

    @property
    def initial(self):
        """Initial value of the referenced expression (symbol-like access for mappers)."""
        return self.expression.initial


class Dereference(StrCompareMixin, pmbl.Expression):
    """
    Internal representation of a Dereference.

    .. warning:: Experimental! The compound ``Dereference(Variable(...))``
        forwards ``name``, ``type``, ``scope`` and ``initial`` to the wrapped
        expression, so it behaves much like a symbol itself and is easier
        to process in mappers.

    **C/C++ only**, no corresponding concept in Fortran.
    Dereferencing (a pointer) means retrieving the value stored at the
    memory address the pointer points to.
    """

    init_arg_names = ('expression', )
    mapper_method = intern('map_c_dereference')

    def __init__(self, expression):
        assert isinstance(expression, pmbl.Expression)
        self.expression = expression

    def __getinitargs__(self):
        return (self.expression, )

    @property
    def name(self):
        """Name of the dereferenced expression (symbol-like access for mappers)."""
        return self.expression.name

    @property
    def type(self):
        """Type of the dereferenced expression (symbol-like access for mappers)."""
        return self.expression.type

    @property
    def scope(self):
        """Scope of the dereferenced expression (symbol-like access for mappers)."""
        return self.expression.scope

    @property
    def initial(self):
        """Initial value of the dereferenced expression (symbol-like access for mappers)."""
        return self.expression.initial
loki-ecmwf-0.3.6/loki/expression/literals.py0000664000175000017500000002314315167130205021272 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Literal symbols representing numbers, strings and logicals.
"""

from sys import intern
import pymbolic.primitives as pmbl
from pymbolic.mapper.evaluator import UnknownVariableError

from loki.types import BasicType
from loki.expression.mixins import StrCompareMixin


# Public literal node classes and the Literal factory of this module
__all__ = [
    '_Literal',
    'FloatLiteral',
    'IntLiteral',
    'LogicLiteral',
    'StringLiteral',
    'IntrinsicLiteral',
    'Literal',
    'LiteralList',
]


class _Literal(pmbl.Leaf):
    """
    Common base class for all literal nodes.

    :any:`pymbolic.primitives.Leaf` does not provide a usable
    :meth:`__getinitargs__`; this class supplies an empty default that
    concrete literal classes override with their own constructor arguments.
    """

    def __getinitargs__(self):
        return tuple()


class FloatLiteral(StrCompareMixin, _Literal):
    """
    A floating point constant in an expression.

    Note that its :data:`value` is stored as a string to avoid any
    representation issues that could stem from converting it to a
    Python floating point number.

    It can have a specific type associated, which backends can use to cast
    or annotate the constant to make sure the specified type is used.

    Ordering comparisons (``<``, ``<=``, ``>``, ``>=``) are numeric when both
    operands can be converted to ``float`` and are otherwise delegated to
    the base class.

    Parameters
    ----------
    value : str
        The value of that literal.
    kind : optional
        The kind information for that literal.
    """

    def __init__(self, value, **kwargs):
        # We store float literals as strings to make sure no information gets
        # lost in the conversion
        self.value = str(value)
        self.kind = kwargs.pop('kind', None)
        super().__init__(**kwargs)

    def __hash__(self):
        return hash((self.value, self.kind))

    def __eq__(self, other):
        if isinstance(other, FloatLiteral):
            return self.value == other.value and self.kind == other.kind

        try:
            return float(self.value) == float(other)
        except (TypeError, ValueError, UnknownVariableError):
            return False

    def _compare(self, other, op, fallback):
        """
        Numerically compare this literal with :data:`other` using :data:`op`.

        Both sides are converted to ``float`` (using ``other.value`` for another
        :class:`FloatLiteral`). If a conversion fails - with the same exception
        types that :meth:`__eq__` tolerates - the comparison is delegated to
        :data:`fallback` instead of propagating the conversion error.
        """
        try:
            lhs = float(self.value)
            rhs = float(other.value) if isinstance(other, FloatLiteral) else float(other)
        except (TypeError, ValueError, UnknownVariableError):
            return fallback(other)
        return op(lhs, rhs)

    def __lt__(self, other):
        return self._compare(other, lambda a, b: a < b, super().__lt__)

    def __le__(self, other):
        return self._compare(other, lambda a, b: a <= b, super().__le__)

    def __gt__(self, other):
        return self._compare(other, lambda a, b: a > b, super().__gt__)

    def __ge__(self, other):
        return self._compare(other, lambda a, b: a >= b, super().__ge__)

    init_arg_names = ('value', 'kind')

    def __getinitargs__(self):
        return (self.value, self.kind)

    mapper_method = intern('map_float_literal')


class IntLiteral(StrCompareMixin, _Literal):
    """
    An integer constant in an expression.

    It can have a specific type associated, which backends can use to cast
    or annotate the constant to make sure the specified type is used.

    Equality considers both value and kind against another
    :class:`IntLiteral`, and only the value against Python numbers or
    anything convertible via ``int()``. Ordering comparisons are defined
    against :class:`IntLiteral` and ``int`` and otherwise delegated to the
    base class.

    Parameters
    ----------
    value : int
        The value of that literal.
    kind : optional
        The kind information for that literal.
    """

    init_arg_names = ('value', 'kind')
    mapper_method = intern('map_int_literal')

    def __init__(self, value, **kwargs):
        self.value = int(value)
        self.kind = kwargs.pop('kind', None)
        super().__init__(**kwargs)

    def __hash__(self):
        return hash((self.value, self.kind))

    def __eq__(self, other):
        if isinstance(other, IntLiteral):
            return self.value == other.value and self.kind == other.kind
        if isinstance(other, (int, float, complex)):
            return self.value == other
        try:
            other_value = int(other)
        except (TypeError, ValueError):
            return False
        return self.value == other_value

    def _compare(self, other, op, fallback):
        """
        Apply the ordering :data:`op` against an :class:`IntLiteral` or
        ``int``; any other operand is handed to :data:`fallback`.
        """
        if isinstance(other, IntLiteral):
            return op(self.value, other.value)
        if isinstance(other, int):
            return op(self.value, other)
        return fallback(other)

    def __lt__(self, other):
        return self._compare(other, lambda a, b: a < b, super().__lt__)

    def __le__(self, other):
        return self._compare(other, lambda a, b: a <= b, super().__le__)

    def __gt__(self, other):
        return self._compare(other, lambda a, b: a > b, super().__gt__)

    def __ge__(self, other):
        return self._compare(other, lambda a, b: a >= b, super().__ge__)

    def __getinitargs__(self):
        return (self.value, self.kind)

    def __int__(self):
        return self.value

    def __bool__(self):
        return bool(self.value)


# Register IntLiteral as a constant class in Pymbolic.
# NOTE(review): presumably this lets pymbolic's constant checks treat
# IntLiteral instances like its built-in numeric constants - confirm
# against the pymbolic documentation of ``register_constant_class``.
pmbl.register_constant_class(IntLiteral)


class LogicLiteral(StrCompareMixin, _Literal):
    """
    A boolean constant in an expression.

    Parameters
    ----------
    value : bool or str
        The value of that literal. It is true if its lower-cased string
        form is ``'true'`` or ``'.true.'``, and false otherwise.
    """

    init_arg_names = ('value', )
    mapper_method = intern('map_logic_literal')

    def __init__(self, value, **kwargs):
        text = str(value).lower()
        self.value = text in ('.true.', 'true')
        super().__init__(**kwargs)

    def __getinitargs__(self):
        return (self.value, )

    def __bool__(self):
        return self.value


class StringLiteral(StrCompareMixin, _Literal):
    """
    A string constant in an expression.

    Parameters
    ----------
    value : str
        The value of that literal. One pair of matching enclosing quotes
        (``'...'`` or ``"..."``) is removed; the empty string and a lone
        quote character are kept as they are.
    """

    def __init__(self, value, **kwargs):
        # Remove quotation marks - but only if there actually is a pair of
        # them: this avoids an IndexError for the empty string and stops a
        # single quote character from being stripped down to ''.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in '"\'':
            value = value[1:-1]

        self.value = value

        super().__init__(**kwargs)

    def __hash__(self):
        return hash(self.value)

    def __eq__(self, other):
        if isinstance(other, StringLiteral):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return False

    init_arg_names = ('value', )

    def __getinitargs__(self):
        return (self.value, )

    mapper_method = intern('map_string_literal')


class IntrinsicLiteral(StrCompareMixin, _Literal):
    """
    Fallback literal for anything without a dedicated literal class.

    The value is kept as a string and returned unaltered. This is currently
    used for complex and BOZ constants, and to retain array constructor
    expressions that have a type spec or an implied-do.

    Parameters
    ----------
    value : str
        The value of that literal.
    """

    init_arg_names = ('value', )
    mapper_method = intern('map_intrinsic_literal')

    def __init__(self, value, **kwargs):
        self.value = value
        super().__init__(**kwargs)

    def __getinitargs__(self):
        return (self.value, )


class Literal:
    """
    Factory class to instantiate the best-matching literal node.

    This always returns a :class:`IntLiteral`, :class:`FloatLiteral`,
    :class:`StringLiteral`, :class:`LogicLiteral` or, as a fallback,
    :class:`IntrinsicLiteral`, selected by using any provided :data:`type`
    information or inspecting the Python data type of :data:`value`.

    Parameters
    ----------
    value :
        The value of that literal.
    type : optional
        The :any:`BasicType` selecting the literal class. If not provided,
        it is inferred from the Python type of :data:`value`.
    kind : optional
        The kind information for that literal.
    """

    @staticmethod
    def _from_literal(value, **kwargs):
        """
        Instantiate the literal class for the given or inferred type.

        Raises
        ------
        KeyError
            If there is no dedicated literal class for the (inferred) type.
        """

        cls_map = {BasicType.INTEGER: IntLiteral, BasicType.REAL: FloatLiteral,
                   BasicType.LOGICAL: LogicLiteral, BasicType.CHARACTER: StringLiteral}

        _type = kwargs.pop('type', None)
        if _type is None:
            # ``bool`` is a subclass of ``int`` and therefore must be checked
            # first, otherwise True/False would end up as IntLiteral(1/0).
            if isinstance(value, bool):
                _type = BasicType.LOGICAL
            elif isinstance(value, int):
                _type = BasicType.INTEGER
            elif isinstance(value, float):
                _type = BasicType.REAL
            elif isinstance(value, str):
                if value.lower() in ('.true.', 'true', '.false.', 'false'):
                    _type = BasicType.LOGICAL
                else:
                    _type = BasicType.CHARACTER

        return cls_map[_type](value, **kwargs)

    def __new__(cls, value, **kwargs):
        try:
            obj = cls._from_literal(value, **kwargs)
        except KeyError:
            # ``type`` is only consumed by the factory itself and is not a
            # constructor argument of IntrinsicLiteral
            fallback_kwargs = {key: val for key, val in kwargs.items() if key != 'type'}
            obj = IntrinsicLiteral(value, **fallback_kwargs)

        # And attach our own meta-data
        if hasattr(obj, 'kind'):
            obj.kind = kwargs.get('kind', None)
        return obj


class LiteralList(StrCompareMixin, pmbl.AlgebraicLeaf):
    """
    A list of constant literals, e.g., as used in Array Initialization Lists.
    """

    init_arg_names = ('values', 'dtype')

    def __init__(self, values, dtype=None, **kwargs):
        self.elements = values
        self.dtype = dtype
        super().__init__(**kwargs)

    mapper_method = intern('map_literal_list')

    def __getinitargs__(self):
        return (self.elements, self.dtype)
loki-ecmwf-0.3.6/loki/expression/mappers.py0000664000175000017500000010446615167130205021132 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Mappers for traversing and transforming the
:ref:`internal_representation:Expression tree`.
"""

import re
from itertools import zip_longest
import pymbolic.primitives as pmbl
from pymbolic.mapper import Mapper, WalkMapper, CombineMapper, IdentityMapper
from pymbolic.mapper.stringifier import (
    StringifyMapper, PREC_NONE, PREC_SUM, PREC_CALL, PREC_PRODUCT
)
try:
    from fparser.two.Fortran2003 import Intrinsic_Name
    _intrinsic_fortran_names = Intrinsic_Name.function_names
except ImportError:
    _intrinsic_fortran_names = ()

from loki.logging import debug
from loki.tools import as_tuple, flatten
from loki.types import SymbolAttributes, BasicType


__all__ = ['LokiStringifyMapper', 'ExpressionRetriever', 'ExpressionDimensionsMapper',
           'ExpressionCallbackMapper', 'SubstituteExpressionsMapper',
           'LokiIdentityMapper', 'AttachScopesMapper', 'DetachScopesMapper']



class LokiStringifyMapper(StringifyMapper):
    """
    A class derived from the default :class:`StringifyMapper` that adds mappings for nodes of the
    expression tree that we added ourselves.

    This is the default pretty printer for nodes in the expression tree.
    """
    # pylint: disable=unused-argument,abstract-method

    _regex_string_literal = re.compile(r"((? wrapped_in_call(my_var)``.

       When there is a need to recursively apply the mapping, the mapping needs to
       be applied to itself first. A potential use-case is renaming of variables,
       which may appear as the name of an array subscript as well as in the ``dimensions``
       attribute of the same expression: ``SOME_ARR(SOME_ARR > SOME_VAL)``.
       The mapping can be applied to itself using the utility function
       :any:`recursive_expression_map_update`.

    Parameters
    ----------
    expr_map : dict
        Expression mapping to apply to the expression tree.
    """
    # pylint: disable=abstract-method

    def __init__(self, expr_map):
        super().__init__()

        self.expr_map = expr_map
        for expr in self.expr_map.keys():
            setattr(self, expr.mapper_method, self.map_from_expr_map)

    def map_from_expr_map(self, expr, *args, **kwargs):
        """
        Replace an expr with its substitution, if found in the :attr:`expr_map`,
        otherwise continue tree traversal
        """
        if expr in self.expr_map:
            return self._rebuild(self.expr_map[expr])
        map_fn = getattr(super(), expr.mapper_method)
        return map_fn(expr, *args, **kwargs)


class AttachScopesMapper(LokiIdentityMapper):
    """
    A Pymbolic expression mapper (i.e., a visitor for the expression tree)
    that determines the scope of :any:`TypedSymbol` nodes and updates its
    :attr:`scope` pointer accordingly.

    Parameters
    ----------
    fail : bool, optional
        If `True`, the mapper raises :any:`RuntimeError` if the scope for a
        symbol can not be found.
    """

    def __init__(self, fail=False):
        super().__init__()
        self.fail = fail

    def _update_symbol_scope(self, expr, scope):
        """
        Find the scope of :data:`expr` and, if it is different,
        attach the new scope and return the symbol
        """
        symbol_scope = scope.get_symbol_scope(expr.name)
        if symbol_scope is None and '%' in expr.name:
            symbol_scope = scope.get_symbol_scope(expr.name_parts[0])
        if symbol_scope is not None:
            if symbol_scope is not expr.scope:
                expr = expr.rescope(symbol_scope)
        elif self.fail:
            raise RuntimeError(f'AttachScopesMapper: {expr!s} was not found in any scope')
        elif expr not in _intrinsic_fortran_names:
            debug('AttachScopesMapper: %s was not found in any scopes', str(expr))
        return expr

    def map_variable_symbol(self, expr, *args, **kwargs):
        """
        Handler for :class:`VariableSymbol`

        This updates the symbol's scope via :meth:`_update_symbol_scope`
        and then calls the parent class handler routine

        Note: this may be a different handler as attaching the scope and therefore
        type may change a symbol's type, e.g. from :class:`DeferredTypeSymbol` to :class:`Scalar`
        """
        new_expr = self._update_symbol_scope(expr, kwargs['scope'])
        if new_expr.scope and new_expr.scope is not kwargs['scope']:
            # We call the parent handler to take care of properties like initial value, kind etc.,
            # all of which should be declared at or above the scope of the expression
            kwargs['scope'] = new_expr.scope
        map_fn = getattr(super(), new_expr.mapper_method)

        # If we cannot resolve scope or type of an expression, we mark it as deferred
        if not new_expr.scope and not new_expr.type:
            new_expr.type = SymbolAttributes(dtype=BasicType.DEFERRED)

        return map_fn(new_expr, *args, **kwargs)

    map_deferred_type_symbol = map_variable_symbol

    def map_procedure_symbol(self, expr, *args, **kwargs):
        if expr.type and expr.type.is_intrinsic:
            # Always rescope intrinsics to the closest scope
            return expr.clone(scope=kwargs['scope'])
        return self.map_variable_symbol(expr, *args, **kwargs)


class DetachScopesMapper(LokiIdentityMapper):
    """
    A Pymbolic expression mapper (i.e., a visitor for the expression tree)
    that rebuilds an expression unchanged but with the scope for every
    :any:`TypedSymbol` detached.

    This will ensure that type information is stored locally on the object
    itself, which is useful when storing information for inter-procedural
    analysis passes.
    """

    def map_variable_symbol(self, expr, *args, **kwargs):
        new_expr = super().map_variable_symbol(expr, *args, **kwargs)
        new_expr = new_expr.clone(scope=None)
        return new_expr

    map_deferred_type_symbol = map_variable_symbol
    map_procedure_symbol = map_variable_symbol
loki-ecmwf-0.3.6/loki/analyse/0000775000175000017500000000000015167130205016333 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/analyse/__init__.py0000664000175000017500000000077515167130205020455 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Advanced analysis utilities, such as dataflow analysis functionalities.
"""

from loki.analyse.analyse_dataflow import *  # noqa
loki-ecmwf-0.3.6/loki/analyse/tests/0000775000175000017500000000000015167130205017475 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/analyse/tests/test_util_polyhedron.py0000664000175000017500000002437615167130205024342 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest
import numpy as np

from loki.analyse.util_polyhedron import Polyhedron
from loki.expression import symbols as sym, parse_expr
from loki.ir import Loop, FindNodes
from loki.sourcefile import Sourcefile
from loki.types import Scope

@pytest.fixture(scope="module", name="here")
def fixture_here():
    return Path(__file__).parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    return here.parent.parent/'tests'



@pytest.mark.parametrize(
    "variables, lbounds, ubounds, A, b, variable_names",
    [
        # do i=0,5: do j=i,7: ...
        (
            ["i", "j"],
            ["0", "i"],
            ["5", "7"],
            [[-1, 0], [1, 0], [1, -1], [0, 1]],
            [0, 5, 0, 7],
            ["i", "j"],
        ),
        # do i=1,n: do j=0,2*i+1: do k=a,b: ...
        (
            ["i", "j", "k"],
            ["1", "0", "a"],
            ["n", "2*i+1", "b"],
            [
                [-1, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, -1],
                [0, -1, 0, 0, 0, 0],
                [-2, 1, 0, 0, 0, 0],
                [0, 0, -1, 1, 0, 0],
                [0, 0, 1, 0, -1, 0],
            ],
            [-1, 0, 0, 1, 0, 0],
            ["i", "j", "k", "a", "b", "n"],
        ),
        # do jk=1,klev: ...
        (["jk"], ["1"], ["klev"], [[-1, 0], [1, -1]], [-1, 0], ["jk", "klev"]),
        # do JK=1,klev-1: ...
        (["JK"], ["1"], ["klev - 1"], [[-1, 0], [1, -1]], [-1, -1], ["jk", "klev"]),
        # do jk=ncldtop,klev: ...
        (
            ["jk"],
            ["ncldtop"],
            ["klev"],
            [[-1, 0, 1], [1, -1, 0]],
            [0, 0],
            ["jk", "klev", "ncldtop"],
        ),
        # do jk=1,KLEV+1: ...
        (["jk"], ["1"], ["KLEV+1"], [[-1, 0], [1, -1]], [-1, 1], ["jk", "klev"]),
    ],
)
def test_polyhedron_from_loop_ranges(variables, lbounds, ubounds, A, b, variable_names):
    """
    Test converting loop ranges to polyedron representation of iteration space.
    """
    scope = Scope()
    loop_variables = [parse_expr(expr, scope) for expr in variables]
    loop_lbounds = [parse_expr(expr, scope) for expr in lbounds]
    loop_ubounds = [parse_expr(expr, scope) for expr in ubounds]
    loop_ranges = [sym.LoopRange((l, u)) for l, u in zip(loop_lbounds, loop_ubounds)]
    p = Polyhedron.from_loop_ranges(loop_variables, loop_ranges)
    assert np.all(p.A == np.array(A, dtype=np.dtype(int)))
    assert np.all(p.b == np.array(b, dtype=np.dtype(int)))
    assert p.variables == variable_names


def test_polyhedron_from_loop_ranges_failures():
    """
    Test known limitation of the conversion from loop ranges to polyhedron.
    """
    # m*n is non-affine and thus can't be represented
    scope = Scope()
    loop_variable = parse_expr("i", scope)
    lower_bound = parse_expr("1", scope)
    upper_bound = parse_expr("m * n", scope)
    loop_range = sym.LoopRange((lower_bound, upper_bound))
    with pytest.raises(ValueError):
        _ = Polyhedron.from_loop_ranges([loop_variable], [loop_range])

    # no functionality to flatten exponentials, yet
    upper_bound = parse_expr("5**2", scope)
    loop_range = sym.LoopRange((lower_bound, upper_bound))
    with pytest.raises(ValueError):
        _ = Polyhedron.from_loop_ranges([loop_variable], [loop_range])


@pytest.mark.parametrize(
    "A, b, variable_names, lower_bounds, upper_bounds",
    [
        # do i=1,n: ...
        ([[-1, 0], [1, -1]], [-1, 0], ["i", "n"], [["1"], ["i"]], [["n"], []]),
        # do i=1,10: ...
        ([[-1], [1]], [-1, 10], ["i"], [["1"]], [["10"]]),
        # do i=0,5: do j=i,7: ...
        (
            [[-1, 0], [1, 0], [1, -1], [0, 1]],
            [0, 5, 0, 7],
            ["i", "j"],
            [["0"], ["i"]],
            [["5", "j"], ["7"]],
        ),
        # do i=1,n: do j=0,2*i+1: do k=a,b: ...
        (
            [
                [-1, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, -1],
                [0, -1, 0, 0, 0, 0],
                [-2, 1, 0, 0, 0, 0],
                [0, 0, -1, 1, 0, 0],
                [0, 0, 1, 0, -1, 0],
            ],
            [-1, 0, 0, 1, 0, 0],
            ["i", "j", "k", "a", "b", "n"],  # variable names
            [["1", "-1 / 2 + j / 2"], ["0"], ["a"], [], ["k"], ["i"]],  # lower bounds
            [["n"], ["1 + 2*i"], ["b"], ["k"], [], []],
        ),  # upper bounds
    ],
)
def test_polyhedron_bounds(A, b, variable_names, lower_bounds, upper_bounds):
    """
    Test the production of lower and upper bounds.
    """
    scope = Scope()
    variables = [parse_expr(v, scope) for v in variable_names]
    p = Polyhedron(A, b, variables)
    for var, ref_bounds in zip(variables, lower_bounds):
        lbounds = p.lower_bounds(var)
        assert len(lbounds) == len(ref_bounds)
        assert all(str(b1) == b2 for b1, b2 in zip(lbounds, ref_bounds))
    for var, ref_bounds in zip(variables, upper_bounds):
        ubounds = p.upper_bounds(var)
        assert len(ubounds) == len(ref_bounds)
        assert all(str(b1) == b2 for b1, b2 in zip(ubounds, ref_bounds))


@pytest.mark.parametrize(
    "polyhedron,is_empty,will_fail",
    [
        # totaly empty polyhedron
        (Polyhedron.from_nested_loops([]), True, False),
        # full matrix --> non trivial problem
        (Polyhedron([[1]], [1]), None, True),
        # empty matrix, full and fullfiled b --> non empty polyhedron
        (Polyhedron([[]], [1]), False, False),
        # empty matrix, full b but not fullfiled b --> empty polyhedron
        (Polyhedron([[]], [-1]), True, False),
    ],
)
def test_check_empty_polyhedron(polyhedron, is_empty, will_fail):
    if will_fail:
        with pytest.raises(RuntimeError):
            _ = polyhedron.is_empty()
    else:
        assert polyhedron.is_empty() == is_empty


def simple_loop_extractor(start_node):
    """Find all loops in the AST and structure them depending on their nesting level"""
    start_loops = FindNodes(Loop, greedy=True).visit(start_node)
    return [FindNodes(Loop).visit(node) for node in start_loops]


def assert_equal_polyhedron(poly_A, poly_B):
    assert poly_A.variables == poly_B.variables
    assert (poly_A.A == poly_B.A).all()
    assert (poly_A.b == poly_B.b).all()


@pytest.mark.parametrize(
    "filename, loop_extractor, polyhedrons_per_subroutine",
    [
        (
            "sources/data_dependency_detection/loop_carried_dependencies.f90",
            simple_loop_extractor,
            {
                "SimpleDependency": [
                    Polyhedron(
                        [[-1, 0], [1, -1]], [-1, 0], [sym.Scalar("i"), sym.Scalar("n")]
                    ),
                ],
                "NestedDependency": [
                    Polyhedron(
                        [[-1, 0, 0], [1, 0, -1], [0, -1, 0], [-1, 1, 0]],
                        [-2, 0, -1, -1],
                        [sym.Scalar("i"), sym.Scalar("j"), sym.Scalar("n")],
                    ),
                ],
                "ConditionalDependency": [
                    Polyhedron(
                        [[-1, 0], [1, -1]],
                        [-2, 0],
                        [sym.Scalar("i"), sym.Scalar("n")],
                    ),
                ],
                "NoDependency": [
                    Polyhedron(
                        [[-1], [1]],
                        [-1, 10],
                        [sym.Scalar("i")],
                    ),
                    Polyhedron(
                        [[-1], [1]],
                        [-1, 5],
                        [sym.Scalar("i")],
                    ),
                ],
            },
        ),
        (
            "sources/data_dependency_detection/various_loops.f90",
            simple_loop_extractor,
            {
                "single_loop": [
                    Polyhedron(
                        [[-1, 0], [1, -1]],
                        [-1, 0],
                        [sym.Scalar("i"), sym.Scalar("n")],
                    ),
                ],
                "single_loop_split_access": [
                    Polyhedron(
                        [[-1, 0], [1, -1]],
                        [-1, 0],
                        [sym.Scalar("i"), sym.Scalar("nhalf")],
                    ),
                ],
                "single_loop_arithmetic_operations_for_access": [
                    Polyhedron(
                        [[-1, 0], [1, -1]],
                        [-1, 0],
                        [sym.Scalar("i"), sym.Scalar("n")],
                    ),
                ],
                "nested_loop_single_dimensions_access": [
                    Polyhedron(
                        [[-1, 0, 0], [1, 0, -1], [0, -1, 0], [0, 1, -1]],
                        [-1, 0, -1, 0],
                        [sym.Scalar("i"), sym.Scalar("j"), sym.Scalar("nhalf")],
                    ),
                ],
                "nested_loop_partially_used": [
                    Polyhedron(
                        [[-1, 0, 0], [1, 0, -1], [0, -1, 0], [0, 1, -1]],
                        [-1, 0, -1, 0],
                        [sym.Scalar("i"), sym.Scalar("j"), sym.Scalar("nfourth")],
                    ),
                ],
                "partially_used_array": [
                    Polyhedron(
                        [[-1, 0], [1, -1]],
                        [-2, 0],
                        [sym.Scalar("i"), sym.Scalar("nhalf")],
                    ),
                ],
            },
        ),
    ],
)
def test_polyhedron_construction_from_nested_loops(
    testdir, filename, loop_extractor, polyhedrons_per_subroutine
):
    source = Sourcefile.from_file(testdir / filename)

    for subroutine in source.all_subroutines:
        expected_polyhedrons = polyhedrons_per_subroutine[subroutine.name]

        list_of_loops = loop_extractor(subroutine.body)

        polyhedrons = [Polyhedron.from_nested_loops(loops) for loops in list_of_loops]

        for polyhedron, expected_polyhedron in zip(polyhedrons, expected_polyhedrons):
            assert_equal_polyhedron(polyhedron, expected_polyhedron)
loki-ecmwf-0.3.6/loki/analyse/tests/__init__.py0000664000175000017500000000057015167130205021610 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/analyse/tests/test_analyse_dataflow.py0000664000175000017500000005132115167130205024425 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Module, Sourcefile, Subroutine
from loki.analyse import (
    dataflow_analysis_attached, read_after_write_vars, loop_carried_dependencies
)
from loki.analyse.analyse_dataflow import DataflowAnalysisAttacher, DataflowAnalysisDetacher
from loki.backend import fgen
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_live_symbols(frontend):
    fcode = """
subroutine analyse_live_symbols(v1, v2, v3)
  integer, intent(in) :: v1
  integer, intent(inout) :: v2
  integer, intent(out) :: v3
  integer :: i, j, n=10, tmp, a

  do i=1,n
    do j=1,n
      tmp = j + 1
    end do
    a = v2 + tmp
  end do

  v3 = v1 + v2
  v2 = a
end subroutine analyse_live_symbols
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    ref_fgen = fgen(routine)

    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 4

    with pytest.raises(RuntimeError):
        for assignment in assignments:
            _ = assignment.live_symbols

    ref_live_symbols = {
        'tmp': {'i', 'j', 'n', 'v1', 'v2'},
        'a': {'i', 'tmp', 'n', 'v1', 'v2'},
        'v3': {'tmp', 'a', 'n', 'v1', 'v2'},
        'v2': {'tmp', 'a', 'n', 'v1', 'v2', 'v3'}
    }

    with dataflow_analysis_attached(routine):
        assert routine.body

        for assignment in assignments:
            live_symbols = {str(s).lower() for s in assignment.live_symbols}
            assert live_symbols == ref_live_symbols[str(assignment.lhs).lower()]

    assert routine.body
    assert fgen(routine) == ref_fgen

    with pytest.raises(RuntimeError):
        for assignment in assignments:
            _ = assignment.live_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_defines_uses_symbols(frontend):
    fcode = """
subroutine analyse_defines_uses_symbols(a, j, m, n)
  integer, intent(out) :: a, j
  integer, intent(in) :: m, n
  integer :: i
  j = n
  a = 1
  do i=m-1,n
    if (i > a) then
      a = a + 1
      if (i < n) exit
    end if
    j = j - 1
  end do
end subroutine analyse_defines_uses_symbols
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    ref_fgen = fgen(routine)

    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conditionals) == 2
    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1

    with pytest.raises(RuntimeError):
        for cond in conditionals:
            _ = cond.defines_symbols
        for cond in conditionals:
            _ = cond.uses_symbols

    with dataflow_analysis_attached(routine):
        assert fgen(routine) == ref_fgen
        assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2
        assert len(FindNodes(ir.Loop).visit(routine.body)) == 1

        assert {str(s) for s in routine.body.uses_symbols} == {'m', 'n'}
        assert {str(s) for s in loops[0].uses_symbols} == {'m', 'n', 'a', 'j'}
        assert {str(s) for s in conditionals[0].uses_symbols} == {'i', 'a', 'n'}
        assert {str(s) for s in conditionals[1].uses_symbols} == {'i', 'n'}
        assert not conditionals[1].body[0].uses_symbols

        assert {str(s) for s in routine.body.defines_symbols} == {'j', 'a'}
        assert {str(s) for s in loops[0].defines_symbols} == {'j', 'a'}
        assert {str(s) for s in conditionals[0].defines_symbols} == {'a'}
        assert not conditionals[1].defines_symbols
        assert not conditionals[1].body[0].defines_symbols

    assert fgen(routine) == ref_fgen

    with pytest.raises(RuntimeError):
        for cond in conditionals:
            _ = cond.defines_symbols
        for cond in conditionals:
            _ = cond.uses_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_read_after_write_vars(frontend):
    fcode = """
subroutine analyse_read_after_write_vars
  integer :: a, b, c, d, e, f

  a = 1
!$loki A
  b = 2
!$loki B
  c = a + 1
!$loki C
  d = b + 1
!$loki D
  e = c + d
!$loki E
  e = 3
  f = e
end subroutine analyse_read_after_write_vars
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    variable_map = routine.variable_map

    vars_at_inspection_node = {
        'A': {variable_map['a']},
        'B': {variable_map['a'], variable_map['b']},
        'C': {variable_map['b'], variable_map['c']},
        'D': {variable_map['c'], variable_map['d']},
        'E': set(),
    }

    pragmas = FindNodes(ir.Pragma).visit(routine.body)
    assert len(pragmas) == 5

    with dataflow_analysis_attached(routine):
        for pragma in pragmas:
            assert read_after_write_vars(routine.body, pragma) == vars_at_inspection_node[pragma.content]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('include_literal_kinds', [True, False])
def test_read_after_write_vars_conditionals(frontend, include_literal_kinds):
    fcode = """
subroutine analyse_read_after_write_vars_conditionals(a, b, c, d, e, f)
  use iso_fortran_env, only : int32
  integer, intent(in) :: a
  integer, intent(out) :: b, c, d, e, f

  b = 1
  d = 0
!$loki A
  if (a < 3_int32) then
    d = b
!$loki B
  endif
!$loki C
  c = 2 + d
!$loki D
  if (a < 5) then
    e = a
  else
    e = c
  endif
!$loki E
  f = e
end subroutine analyse_read_after_write_vars_conditionals
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    variable_map = routine.variable_map

    vars_at_inspection_node = {
        'A': {variable_map['b'], variable_map['d']},
        'B': {variable_map['d']},
        'C': {variable_map['d']},
        'D': {variable_map['c']},
        'E': {variable_map['e']},
    }

    pragmas = FindNodes(ir.Pragma).visit(routine.body)
    assert len(pragmas) == len(vars_at_inspection_node)

    # We skip the context manager here to test the "include_literal_kinds" option
    DataflowAnalysisAttacher(include_literal_kinds=include_literal_kinds).visit(routine.body)

    if include_literal_kinds:
        assert 'int32' in routine.body.uses_symbols
    else:
        assert not 'int32' in routine.body.uses_symbols
    for pragma in pragmas:
        assert read_after_write_vars(routine.body, pragma) == vars_at_inspection_node[pragma.content]

    DataflowAnalysisDetacher().visit(routine.body)


@pytest.mark.parametrize('frontend', available_frontends())
def test_loop_carried_dependencies(frontend):
    fcode = """
subroutine analyse_loop_carried_dependencies(a, b, c)
  integer, intent(inout) :: a, b, c
  integer :: i, tmp

  do i = 1,a
    b = b + i
    tmp = c
    c = 5 + tmp
  end do
end subroutine analyse_loop_carried_dependencies
    """.strip()


    routine = Subroutine.from_source(fcode, frontend=frontend)
    variable_map = routine.variable_map

    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1

    with dataflow_analysis_attached(routine):
        assert loop_carried_dependencies(loops[0]) == {variable_map['b'], variable_map['c']}

@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_interface(frontend):
    fcode = """
subroutine random_call(v_out,v_in,v_inout)
implicit none

  real,intent(in)  :: v_in
  real,intent(out)  :: v_out
  real,intent(inout)  :: v_inout


end subroutine random_call

subroutine test(v_out,v_in,v_inout)
implicit none
interface
  subroutine random_call(v_out,v_in,v_inout)
     real,intent(in)  :: v_in
     real,intent(out)  :: v_out
     real,intent(inout)  :: v_inout
  end subroutine random_call
end interface

real,intent(in   )  :: v_in
real,intent(out  )  :: v_out
real,intent(inout)  :: v_inout

end subroutine test
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend)
    routine = source['test']

    with dataflow_analysis_attached(routine):
        assert len(routine.body.defines_symbols) == 0
        assert len(routine.body.uses_symbols) == 0
        assert len(routine.spec.uses_symbols) == 0
        assert len(routine.spec.defines_symbols) == 1
        assert isinstance(list(routine.spec.defines_symbols)[0], sym.ProcedureSymbol)
        assert 'random_call' in routine.spec.defines_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_imports(frontend, tmp_path):
    fcode_module = """
module some_mod
implicit none
real :: my_global
contains
subroutine random_call(v_out,v_in,v_inout)

  real,intent(in)  :: v_in
  real,intent(out)  :: v_out
  real,intent(inout)  :: v_inout


end subroutine random_call
end module some_mod
""".strip()

    fcode = """
subroutine test()
use some_mod, only: my_global, random_call
implicit none

end subroutine test
""".strip()

    module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
    routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module, xmods=[tmp_path])

    with dataflow_analysis_attached(routine):
        assert len(routine.spec.defines_symbols) == 1
        assert 'random_call' in routine.spec.defines_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_enriched_call(frontend):
    fcode = """
subroutine random_call(v_out,v_in,v_inout)
implicit none

  real,intent(in)  :: v_in
  real,intent(out)  :: v_out
  real,intent(inout)  :: v_inout


end subroutine random_call

subroutine test(v_out,v_in,v_inout)
implicit none

  real,intent(in   )  :: v_in
  real,intent(out  )  :: v_out
  real,intent(inout)  :: v_inout

  call random_call(v_out,v_in,v_inout)

end subroutine test
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend)
    routine = source['test']
    routine.enrich(source.all_subroutines)
    call = FindNodes(ir.CallStatement).visit(routine.body)[0]

    with dataflow_analysis_attached(routine):
        assert all(i in call.defines_symbols for i in ('v_out', 'v_inout'))
        assert all(i in call.uses_symbols for i in ('v_in', 'v_inout'))


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_unenriched_call(frontend):
    fcode = """
subroutine test(v_out,v_in,v_inout)
implicit none

  real,intent(in   )  :: v_in
  real,intent(out  )  :: v_out
  real,intent(inout)  :: v_inout

  call random_call(v_out,v_in,var=v_inout)

end subroutine test
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend)
    routine = source['test']
    call = FindNodes(ir.CallStatement).visit(routine.body)[0]

    with dataflow_analysis_attached(routine):
        assert all(i in call.defines_symbols for i in ('v_out', 'v_inout', 'v_in'))
        assert all(i in call.uses_symbols for i in ('v_in', 'v_inout', 'v_in'))


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_allocate_statement(frontend):
    fcode = """
subroutine test(n,m)
implicit none

  integer,intent(in   ) :: n
  integer,intent(inout) :: m
  real,allocatable :: a(:,:)

  allocate(a(n,m))


  deallocate(a)

end subroutine test
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    with dataflow_analysis_attached(routine):
        assert all(i not in routine.body.defines_symbols for i in ['m', 'n'])
        assert all(i in routine.body.uses_symbols for i in ['m', 'n'])
        assert 'a' in routine.body.defines_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_import_kind(frontend):
    fcode = """
subroutine test(n,m)
use iso_fortran_env, only: real64
implicit none

  integer,intent(in   ) :: n
  integer,intent(inout) :: m
  real(kind=real64),allocatable :: a(:,:)

  a = 0._real64

end subroutine test
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    with dataflow_analysis_attached(routine):
        assert 'real64' in routine.body.uses_symbols
        assert 'real64' in routine.spec.uses_symbols
        assert 'real64' not in routine.body.defines_symbols
        assert 'a' in routine.body.defines_symbols
        assert 'a' not in routine.body.uses_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_query_memory_attributes(frontend):
    """
    Test that checks whether variables used only in function calls that
    query memory attributes appear in uses_symbols.
    """

    fcode = """
subroutine test(a)
implicit none

  real,intent(out) :: a(:,:)
  real             :: b(10)
  integer          :: bsize, i

  if(size(a) > 0) a(:,:) = 0.
  bsize = size(b)

  do i=1,size(b)
    print *, i
  enddo

end subroutine test
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    with dataflow_analysis_attached(routine):
        assert not 'a' in routine.body.uses_symbols
        assert 'a' in routine.body.defines_symbols
        assert not 'b' in routine.body.uses_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_call_args_array_slicing(frontend):
    """
    Index expressions inside call arguments (``v(n)``, ``v(b(1))``) must be
    reported as uses of the respective call, never as definitions.
    """
    fcode = """
subroutine random_call(v)
implicit none

  integer,intent(out) :: v

  v = 1

end subroutine random_call

subroutine test(v,n,b)
implicit none

  integer,intent(out) :: v(:)
  integer,intent( in) :: n
  integer,intent( in) :: b(n)

  call random_call(v(n))
  call random_call(v(b(1)))

end subroutine test
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend)
    routine = source['test']

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    routine.enrich(source.all_subroutines)

    with dataflow_analysis_attached(routine):
        assert 'n' in calls[0].uses_symbols
        assert 'n' not in calls[0].defines_symbols
        assert 'b' in calls[1].uses_symbols
        # ``b`` appears only in the second call, so that is the one to check
        assert 'b' not in calls[1].defines_symbols


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_multiconditional(frontend):
    """
    A SELECT CASE construct uses its selector, its case values and whatever its
    branches read, and defines the union of what its branches write.
    """
    fcode = """
subroutine test(ia,ib,ic)
integer, intent(in) :: ia,ib,ic
integer             :: a,b

multicond: select case (ic)
case (10) multicond
  a = 0
case (ia) multicond
  b = 0
case default multicond
  b = ib
end select multicond
end subroutine test
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    mcond = FindNodes(ir.MultiConditional).visit(routine.body)[0]
    with dataflow_analysis_attached(routine):
        # structure: two explicit cases with one statement each, plus default
        assert len(mcond.bodies) == 2
        assert len(mcond.else_body) == 1
        assert all(len(branch) == 1 for branch in mcond.bodies)

        assert len(mcond.uses_symbols) == 3
        assert len(mcond.defines_symbols) == 2
        for name in ('ic', 'ia', 'ib'):
            assert name in mcond.uses_symbols
        for name in ('a', 'b'):
            assert name in mcond.defines_symbols

        # nothing is defined before the construct, so only the dummies are live
        for assign in FindNodes(ir.Assignment).visit(routine.body):
            assert assign.live_symbols == {'ia', 'ib', 'ic'}


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OMNI, 'OMNI fails to read without full module')]))
def test_analyse_typeconditional(frontend):
    """
    Dataflow annotations of a SELECT TYPE construct whose guarded branches
    contain ASSOCIATE blocks over components of the selector.
    """
    fcode = """
subroutine test(arg)
use type_mod, only: base_type, some_type, other_type
class(base_type), intent(in) :: arg
integer             :: a, b, c

typecond: select type(arg)
  class is(some_type)
    associate (aa => arg%s)
      a = aa
    end associate
  type is(other_type)
    associate (bb => arg%t)
      b = bb
    end associate
  class default
    c = 0
end select typecond
end subroutine test
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    tcond = FindNodes(ir.TypeConditional).visit(routine.body)[0]
    with dataflow_analysis_attached(routine):
        # two type guards with one node each, plus the default branch
        assert len(tcond.bodies) == 2
        assert len(tcond.else_body) == 1
        assert all(len(branch) == 1 for branch in tcond.bodies)

        # associate names are mapped back onto the selector's components
        assert tcond.uses_symbols == {'arg%t', 'arg%s', 'arg'}
        assert tcond.defines_symbols == {'a', 'b', 'c'}
        assert tcond.live_symbols == {'arg'}

        for assign in FindNodes(ir.Assignment).visit(routine.body):
            assert assign.live_symbols == {'arg'}


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('include_literal_kinds', [True, False])
def test_analyse_maskedstatement(frontend, include_literal_kinds):
    """
    WHERE constructs: the mask expression is used, the assigned array is
    defined, and literal kind specifiers are counted only on request.
    """
    fcode = """
subroutine masked_statements(n, mask, vec1, vec2)
  use iso_fortran_env, only : int32
  integer, intent(in) :: n
  integer, intent(in), dimension(n) :: mask
  real, intent(out), dimension(n) :: vec1,vec2

  where (mask(:) < -5_int32)
    vec1(:) = -5.0
    vec1(:) = vec1(:) -5.0
  elsewhere (mask(:) > 5_int32)
    vec1(:) =  5.0
  elsewhere
    vec1(:) = 0.0
  endwhere

end subroutine masked_statements
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    mask = FindNodes(ir.MaskedStatement).visit(routine.body)[0]
    bodies_before = len(mask.bodies)

    # Attach/detach by hand (no context manager) so the option can be passed
    DataflowAnalysisAttacher(include_literal_kinds=include_literal_kinds).visit(routine.body)

    expected_uses = 2 if include_literal_kinds else 1
    assert len(mask.uses_symbols) == expected_uses
    # ``int32`` is present exactly when literal kinds are included
    assert ('int32' in mask.uses_symbols) == include_literal_kinds
    assert len(mask.defines_symbols) == 1
    assert 'mask' in mask.uses_symbols
    assert 'vec1' in mask.defines_symbols

    DataflowAnalysisDetacher().visit(routine.body)

    # detaching must leave the branch structure intact
    assert len(mask.bodies) == bodies_before


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_whileloop(frontend):
    """
    A DO WHILE loop uses the variables of its condition; the annotations must
    not depend on whether analysis is attached to the routine or to the
    enclosing conditional.
    """
    fcode = """
subroutine while_loop(flag)
   implicit none

   logical, intent(in) :: flag
   integer :: ij
   real :: a(10)

   if(flag)then
      ij = 0
      do while(ij .lt. 10)
          ij = ij + 1
          a(ij) = 0.
      enddo
   endif

end subroutine while_loop
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    loop = FindNodes(ir.WhileLoop).visit(routine.body)[0]
    cond = FindNodes(ir.Conditional).visit(routine.body)[0]

    def _check_loop_annotations():
        assert len(loop.uses_symbols) == 1
        assert len(loop.defines_symbols) == 2
        assert 'ij' in loop.uses_symbols
        assert all(v in loop.defines_symbols for v in ('ij', 'a'))

    with dataflow_analysis_attached(routine):
        assert len(cond.uses_symbols) == 1
        assert 'flag' in cond.uses_symbols
        _check_loop_annotations()

    # attaching to the conditional alone must produce the same loop annotations
    with dataflow_analysis_attached(cond):
        _check_loop_annotations()


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_associate(frontend):
    """
    Associate nodes report symbols in terms of the enclosing scope, while the
    statements inside them keep using the associate names.
    """

    fcode = """
subroutine associate_test(a, b, c, in_var)
   implicit none

   real, intent(in) :: in_var
   real, intent(inout) :: a, b, c

   associate(d=>a, e=>b, f=>c)
     e = in_var
     f = in_var
     associate(d0=>d)
       d0 = in_var
     end associate
   end associate

end subroutine associate_test
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    outer, inner = FindNodes(ir.Associate).visit(routine.body)[:2]
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    with dataflow_analysis_attached(routine):
        # associates translate names back to the variables they alias
        assert outer.uses_symbols == {'in_var'}
        assert outer.defines_symbols == {'a', 'b', 'c'}

        assert inner.uses_symbols == {'in_var'}
        assert inner.defines_symbols == {'d'}

        # assignments are annotated with the associated names themselves
        for idx in range(3):
            assert assigns[idx].uses_symbols == {'in_var'}

        for idx, target in enumerate(('e', 'f', 'd0')):
            assert assigns[idx].defines_symbols == {target}


@pytest.mark.parametrize('frontend', available_frontends())
def test_analyse_derived_types(frontend, tmp_path):
    """
    Nested derived-type accesses are reported with their full component path
    and without any array dimensions.
    """

    fcode = r"""
module my_mod
   implicit none

   type :: my_sub_type
      real, allocatable :: c(:)
   end type

   type :: my_type
      type(my_sub_type), allocatable :: b(:)
   end type

contains

subroutine kernel(a, d)
   type(my_type), intent(inout) :: a
   type(my_type), intent(in) :: d
   integer :: i

   do i=1,10
     A%B(i)%C(:) = D%B(i)%C(:)
   enddo

end subroutine

end module
"""

    kernel = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['kernel']

    with dataflow_analysis_attached(kernel):
        body = kernel.body
        assert body.defines_symbols == {'a%b%c'}
        assert body.uses_symbols == {'d%b%c'}
loki-ecmwf-0.3.6/loki/analyse/tests/test_util_linear_algebra.py0000664000175000017500000001435415167130205025101 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from math import gcd as math_gcd
import pytest
import numpy as np

try:
    _ = math_gcd(4,3,2)
except TypeError:  # Python 3.8: math.gcd accepts exactly two arguments
    from functools import reduce

    def gcd(*args):
        """Variadic gcd obtained by folding the two-argument ``math.gcd``."""
        return reduce(math_gcd, args)
else:
    # Python >= 3.9: math.gcd is already variadic
    gcd = math_gcd

from loki.analyse.util_linear_algebra import (
    back_substitution,
    generate_row_echelon_form,
    is_independent_system,
    yield_one_d_systems,
)


@pytest.mark.parametrize(
    "upper_triangular_square_matrix, right_hand_side, expected, divison_operation",
    [
        (
            [[2, 1, -1], [0, 0.5, 0.5], [0, 0, -1]],
            [[8], [1], [1]],
            [[2], [3], [-1]],
            lambda x, y: x / y,
        ),
        (
            [[2, 0], [0, 1]],
            [[10], [11]],
            [[5], [11]],
            lambda x, y: x // y,
        ),
    ],
)
def test_backsubstitution(
    upper_triangular_square_matrix, right_hand_side, expected, divison_operation
):
    """Back substitution solves R x = y for both true and floor division."""
    R = np.array(upper_triangular_square_matrix)
    y = np.array(right_hand_side)
    solution = back_substitution(R, y, divison_operation)
    assert np.allclose(solution, np.array(expected))


@pytest.mark.parametrize(
    "matrix, result",
    [
        ([[2, 0, 1], [0, 2, 0]], [[1, 0, 0.5], [0, 1, 0]]),
        ([[1, -2, 1, 0], [3, 2, 1, 5]], [[1, -2, 1, 0], [0, 1, -0.25, 0.625]]),
        ([[1, -1, -10]], [[1, -1, -10]]),
        ([[0, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
        ([[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [0, 1, 2], [0, 0, 0]]),
        ([[0, 1, 0], [0, 0, 1], [0, 0, 0]], [[0, 1, 0], [0, 0, 1], [0, 0, 0]]),
        (
            [[2, 4, 6, 8], [1, 2, 3, 4], [3, 6, 9, 12]],
            [[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]],
        ),
        ([[0, 0, 0], [1, 0, 2]], [[1, 0, 2], [0, 0, 0]]),
    ],
)
def test_generate_row_echelon_form(matrix, result):
    """REF with the default (true) division on floating-point matrices."""
    expected = np.array(result, dtype=float)
    computed = generate_row_echelon_form(np.array(matrix, dtype=float))
    assert np.allclose(computed, expected)


@pytest.mark.parametrize(
    "matrix, result",
    [
        ([[]], [[]]),
        ([[2, 0, 1], [0, 2, 0]], [[1, 0, 0], [0, 1, 0]]),
        ([[1, -2, 1, 0], [3, 2, 1, 5]], [[1, -2, 1, 0], [0, 1, -1, 0]]),
        ([[1, -1, -10]], [[1, -1, -10]]),
        ([[0, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
        ([[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [0, 1, 2], [0, 0, 0]]),
        ([[0, 1, 0], [0, 0, 1], [0, 0, 0]], [[0, 1, 0], [0, 0, 1], [0, 0, 0]]),
        (
            [[2, 4, 6, 8], [1, 2, 3, 4], [3, 6, 9, 12]],
            [[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]],
        ),
    ],
)
def test_enforce_integer_arithmetics_for_row_echelon_form(matrix, result):
    """REF computed with floor division as the division operator."""
    def floor_div(x, y):
        return x // y

    expected = np.array(result, dtype=float)
    computed = generate_row_echelon_form(
        np.array(matrix, dtype=float), division_operator=floor_div
    )
    assert np.allclose(computed, expected)


def _raise_assertion_error(A):
    raise ValueError()


def _require_gcd_condition(A):
    """
    Reject the leading row of augmented matrix ``A`` if its linear Diophantine
    equation has no integer solution, i.e. if the gcd of the coefficients does
    not divide the right-hand side.
    """
    coefficients = A[0, :-1]
    rhs = A[0, -1]
    if rhs % gcd(*coefficients) != 0:
        raise ValueError()


@pytest.mark.parametrize(
    "matrix, condition, result",
    [
        ([[1, 2, 3], [4, 5, 6]], _raise_assertion_error, None),
        (
            [[2, 0, 0, -2, -20], [0, 2, -2, 0, -22]],
            _require_gcd_condition,
            [[1, 0, 0, -1, -10], [0, 1, -1, 0, -11]],
        ),
        ([[2, 0, 0, -2, -20], [0, 2, -2, 0, -21]], _require_gcd_condition, None),
    ],
)
def test_require_conditions(matrix, condition, result):
    """A failing conditional check aborts REF; a passing one leaves it unaffected."""
    matrix = np.array(matrix)

    if result is not None:
        expected = np.array(result)
        assert np.allclose(
            generate_row_echelon_form(matrix, conditional_check=condition), expected
        )
        return

    with pytest.raises(ValueError):
        generate_row_echelon_form(matrix, conditional_check=condition)


@pytest.mark.parametrize(
    "matrix, expected_result",
    [
        (np.array([[1, 0], [0, 1], [0, 0]]), True),
        (np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), True),
        (np.array([[1, 0, 1], [0, 1, 0], [0, 0, 0]]), False),
        (np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]), True),
        (np.array([[1, 0, 0], [0, 0, 1], [0, 1, 0]]), True),
    ],
)
def test_is_independent_system(matrix, expected_result):
    """A system is independent iff no row couples two variables."""
    outcome = is_independent_system(matrix)
    assert outcome == expected_result


@pytest.mark.parametrize(
    "matrix, rhs, list_of_lhs_column, list_of_rhs_column",
    [
        (
            np.array([[1, 0], [0, 1], [0, 0]]),
            np.array([[1], [2], [0]]),
            [np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[1]])],
            [np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[2]])],
        ),
        (
            np.array([[1, 0], [0, 1], [0, 0]]),
            np.array([[1], [2], [1]]),
            [np.array([[0]]), np.array([[0]]), np.array([[1]]), np.array([[1]])],
            [np.array([[1]]), np.array([[1]]), np.array([[1]]), np.array([[2]])],
        ),
        (
            np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
            np.array([[1], [2], [3]]),
            [np.array([[0], [0], [0]])] * 3,
            [np.array([[1], [2], [3]])] * 3,
        ),
        (  # will even split non independent systems, call is_independent_system before
            np.array([[2, 1], [1, 3]]),
            np.array([[3], [4]]),
            [np.array([[2], [1]]), np.array([[1], [3]])],
            [np.array([[3], [4]]), np.array([[3], [4]])],
        ),
    ],
)
def test_yield_one_d_systems(matrix, rhs, list_of_lhs_column, list_of_rhs_column):
    """Every yielded (A, b) pair matches the expected split, in order."""
    systems = list(yield_one_d_systems(matrix, rhs))
    assert len(systems) == len(list_of_lhs_column) == len(list_of_rhs_column)
    for (lhs, rhs_part), lhs_expected, rhs_expected in zip(
        systems, list_of_lhs_column, list_of_rhs_column
    ):
        assert np.array_equal(lhs, lhs_expected)
        assert np.array_equal(rhs_part, rhs_expected)
loki-ecmwf-0.3.6/loki/analyse/util_linear_algebra.py0000664000175000017500000001571615167130205022703 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import numpy as np

__all__ = [
    "back_substitution",
    "generate_row_echelon_form",
    "is_independent_system",
    "yield_one_d_systems",
]


def is_independent_system(matrix):
    """
    Check if a linear system of equations can be split into independent one-dimensional problems.

    Parameters
    ----------
    matrix : numpy.ndarray
        A rectangular matrix of coefficients; each row is one equation/inequality.

    Returns
    -------
    bool
        True if every row contains at most one non-zero coefficient, False otherwise.

    Notes
    -----
    For a system ``matrix [operator] right_hand_side``, a row with a single non-zero
    coefficient constrains only one variable. A row with no non-zero coefficient
    constrains none. If every row is of one of these two kinds, the system
    decomposes into independent one-dimensional problems, one per variable.
    """
    nonzeros_per_row = np.count_nonzero(matrix != 0, axis=1)
    return np.all(nonzeros_per_row <= 1)


def yield_one_d_systems(matrix, right_hand_side):
    """
    Split a linear system of equations (<=, >=, or ==) into independent one-dimensional problems.

    Parameters
    ----------
    matrix : numpy.ndarray
        A rectangular matrix of coefficients; rows are equations, columns are variables.
    right_hand_side : numpy.ndarray
        The right-hand side, one entry (row) per row of ``matrix``.

    Yields
    ------
    tuple[numpy.ndarray, numpy.ndarray]
        A column vector of coefficients and the matching right-hand-side rows.

    Notes
    -----
    Independence is NOT checked; call `is_independent_system` first if unsure.

    The rows that are entirely zero are handled first, and only if such rows
    exist. For every column of ``matrix``, the function yields that column
    restricted to the zero rows (all zeros) together with the zero rows'
    right-hand side.

    The remaining rows are handled next, and only if any remain. For every
    column of ``matrix``, the function yields that column's non-zero
    coefficients together with the right-hand-side entries of the
    corresponding rows.

    Example
    -------

    .. code-block:: python

        for A, b in yield_one_d_systems(matrix, right_hand_side):
            # Solve the one-dimensional problem A * x = b
            solution = solve_one_d_system(A, b)
    """
    zero_rows = np.all(matrix == 0, axis=1)

    # Equations without any coefficient: emitted once per column of the matrix
    if right_hand_side[zero_rows].size != 0:
        for column in matrix[zero_rows].T:
            yield column.reshape((-1, 1)), right_hand_side[zero_rows]

    remaining_matrix = matrix[~zero_rows]
    remaining_rhs = right_hand_side[~zero_rows]
    if remaining_rhs.size == 0:
        return

    # One problem per variable: its non-zero coefficients and the matching rhs rows
    for column in remaining_matrix.T:
        nonzero = column != 0
        yield column[nonzero].reshape((-1, 1)), remaining_rhs[nonzero]


def back_substitution(
    upper_triangular_square_matrix,
    right_hand_side,
    divison_operation=lambda x, y: x / y,
):
    """
    Solve a linear system of equations using back substitution for an upper triangular square matrix.

    Parameters
    ----------
    upper_triangular_square_matrix : numpy.ndarray
        An upper triangular square matrix (R).

    right_hand_side : numpy.ndarray
        A vector (y) on the right-hand side of the equation Rx = y.

    divison_operation : function, optional
        A custom division operation ``f(numerator, denominator)``. Default is
        standard (true) division (/). The parameter keeps its historical
        spelling ``divison_operation`` for backward compatibility.

    Returns
    -------
    numpy.ndarray
        The solution vector (x) to the system of equations Rx = y. It has the
        shape of ``right_hand_side``. Its dtype is the common dtype of R, y and
        the first division result. For example, true division of integer
        inputs yields floats, and floor division of integer inputs yields
        integers.

    Notes
    -----
    The function performs back substitution to find the solution vector x for the equation Rx = y,
    where R is an upper triangular square matrix and y is a vector. The divison_operation
    function is used for every division (e.g., floor division for integer arithmetic).

    The bottom-right (last diagonal) element of R is asserted to be nonzero. The
    other diagonal elements are not checked.
    """
    R = np.asarray(upper_triangular_square_matrix)
    y = np.asarray(right_hand_side)

    assert R[-1, -1] != 0

    last = divison_operation(y[-1], R[-1, -1])
    # Allocate x with a dtype that can hold the division results: a plain
    # ``np.zeros_like(y)`` would silently truncate e.g. true-division results
    # to integers whenever the right-hand side is an integer array.
    x = np.zeros(y.shape, dtype=np.result_type(y, R, last))
    x[-1] = last

    for i in range(len(y) - 2, -1, -1):
        x[i] = divison_operation((y[i] - np.dot(R[i, i + 1 :], x[i + 1 :])), R[i, i])

    return x


def generate_row_echelon_form(
    A, conditional_check=lambda A: None, division_operator=lambda x, y: x / y
):
    """
    Calculate the Row Echelon Form (REF) of a matrix.

    Parameters
    ----------
    A : numpy.ndarray
        The input matrix for which the REF is to be calculated. It is not
        modified; the elimination runs on a copy.
    conditional_check : function, optional
        A custom function called with the current (sub)matrix after its pivot
        row has been moved to the top and before that row is normalised. It
        may raise to abort the computation.
    division_operator : function, optional
        A custom division operation function. Default is standard division (/).

    Returns
    -------
    numpy.ndarray
        The REF of the input matrix A, with the dtype of A. For an integer A,
        results of ``division_operator`` are cast back to integers when stored.

    Notes
    -----
    - If the input matrix has no rows or columns, it is already in REF and a copy of it is returned.
    - The function utilizes the specified division operation (default is standard division) for division.

    Reference
    ---------
    https://math.stackexchange.com/a/3073117
    for question:
    https://math.stackexchange.com/questions/3073083/how-to-reduce-matrix-into-row-echelon-form-in-numpy
    """
    # Work on a private copy: the elimination swaps, scales and subtracts rows
    # in place, which previously clobbered the caller's array.
    return _row_echelon_form_inplace(np.array(A), conditional_check, division_operator)


def _row_echelon_form_inplace(A, conditional_check, division_operator):
    """Recursive REF kernel; modifies ``A`` (and views of it) in place and returns the REF."""
    # if matrix A has no columns or rows,
    # it is already in REF, so we return itself
    r, c = A.shape
    if r == 0 or c == 0:
        return A

    # we search for non-zero element in the first column
    for i in range(len(A)):
        if A[i, 0] != 0:
            break
    else:
        # if all elements in the first column is zero,
        # we perform REF on matrix from second column
        B = _row_echelon_form_inplace(A[:, 1:], conditional_check, division_operator)
        # and then add the first zero-column back
        return np.hstack([A[:, :1], B])

    # if non-zero element happens not in the first row,
    # we switch rows
    if i > 0:
        A[[i, 0]] = A[[0, i]]

    # check condition
    conditional_check(A)

    # we divide first row by first element in it
    A[0] = division_operator(A[0], A[0, 0])
    # we subtract all subsequent rows with first row (it has 1 now as first element)
    # multiplied by the corresponding element in the first column
    A[1:] -= A[0] * A[1:, 0:1]

    # we perform REF on matrix from second row, from second column
    B = _row_echelon_form_inplace(A[1:, 1:], conditional_check, division_operator)

    # we add first row and first (zero) column, and return
    return np.vstack([A[:1], np.hstack([A[1:, :1], B])])
loki-ecmwf-0.3.6/loki/analyse/analyse_dataflow.py0000664000175000017500000006777715167130205022252 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of dataflow analysis schema routines.
"""

from contextlib import contextmanager
from loki.expression import Array, ProcedureSymbol
from loki.ir.expr_visitors import FindLiterals
from loki.tools import as_tuple, flatten, OrderedSet
from loki.types import BasicType
from loki.ir import (
    Visitor, Transformer, FindVariables, FindInlineCalls, FindTypedSymbols
)
from loki.subroutine import Subroutine
from loki.tools.util import CaseInsensitiveDict

__all__ = [
    'dataflow_analysis_attached', 'read_after_write_vars',
    'loop_carried_dependencies'
]


def strip_nested_dimensions(expr):
    """
    Return a clone of ``expr`` with ``dimensions`` cleared on it and,
    recursively, on every (truthy) parent in its derived-type chain.
    """
    parent = expr.parent
    stripped_parent = strip_nested_dimensions(parent) if parent else parent
    return expr.clone(dimensions=None, parent=stripped_parent)


class DataflowAnalysisAttacher(Transformer):
    """
    Analyse and attach in-place the definition, use and live status of
    symbols.

    Parameters
    ----------
    include_literal_kinds : bool (default : True)
       Include kind specifiers for literals in dataflow analysis.
    """

    # group of functions that only query memory properties and don't read/write variable value
    _mem_property_queries = ('size', 'lbound', 'ubound', 'present')

    def __init__(self, include_literal_kinds=True, **kwargs):
        # Annotations are written onto the visited nodes themselves (inplace=True)
        # and source information is not invalidated (invalidate_source=False)
        super().__init__(invalidate_source=False, inplace=True, **kwargs)
        # Whether kind specifiers of literals (e.g. ``_JPRB``) count as uses
        self.include_literal_kinds = include_literal_kinds

    # Utility routines

    def _visit_body(self, body, live=None, defines=None, uses=None, **kwargs):
        """
        Visit the (flattened) nodes of ``body`` in order, accumulating the
        symbols they define and use.

        Each node is visited with ``live | defines-so-far`` as its live set. A
        symbol only counts as a use if it was not already defined by an earlier
        node in the body. ``defines`` and ``uses`` are updated in place when
        provided.

        Returns
        -------
        tuple
            ``(visited_nodes, defines, uses)``
        """
        live = OrderedSet() if live is None else live
        defines = OrderedSet() if defines is None else defines
        uses = OrderedSet() if uses is None else uses

        processed = []
        for node in flatten(body):
            new_node = self.visit(node, live_symbols=live|defines, **kwargs)
            processed.append(new_node)
            # uses only count if not satisfied by an earlier definition
            uses |= new_node.uses_symbols.copy() - defines
            defines |= new_node.defines_symbols.copy()
        return as_tuple(processed), defines, uses

    @staticmethod
    def _symbols_from_expr(expr, condition=None):
        """
        Return the set of variables found in ``expr``.

        Dimensions are stripped from each variable (including its parents). Any
        variable that is only a parent of another found variable is removed.
        If ``condition`` is given, only variables for which it returns a truthy
        value are kept.
        """
        stripped = OrderedSet(strip_nested_dimensions(var) for var in FindVariables().visit(expr))
        # drop derived-type parents that merely prefix a member access
        stripped -= OrderedSet(parent for var in stripped for parent in var.parents)
        if condition is None:
            return stripped
        return OrderedSet(var for var in stripped if condition(var))

    @classmethod
    def _symbols_from_lhs_expr(cls, expr):
        """
        Determine symbol use and symbol definition from a left-hand side expression.

        The stripped ``expr`` itself is defined. Every variable appearing in an
        index expression is used, both on ``expr`` and on any of its
        derived-type parents (e.g. ``i`` and ``j`` in ``a%b(i)%c(j)``).

        Parameters
        ----------
        expr : :any:`Scalar` or :any:`Array`
            The left-hand side expression of an assignment.

        Returns
        -------
        (defines, uses) : (OrderedSet, OrderedSet)
            The sets of defined and used symbols (in that order).
        """
        # OrderedSet for consistency with every other handler's annotations
        defines = OrderedSet((strip_nested_dimensions(expr),))

        # Collect the derived-type chain, outermost parent first, so that the
        # indices of parents (previously ignored) are recorded as uses too
        chain = []
        node = expr
        while node is not None:
            chain.append(node)
            node = getattr(node, 'parent', None)
        index_exprs = tuple(
            dim for node in reversed(chain) for dim in (getattr(node, 'dimensions', None) or ())
        )

        uses = cls._symbols_from_expr(index_exprs)
        return defines, uses

    # Abstract node (also called from every node type for integration)

    def visit_Node(self, o, **kwargs):
        """
        Attach the dataflow sets given in ``kwargs`` to node ``o`` and return it.

        ``live_symbols`` is handed down from the enclosing internal-node handler.
        ``defines_symbols`` and ``uses_symbols`` are computed by the node-specific
        handler that delegates here. Any missing set defaults to empty.
        """
        annotations = (
            ('_live_symbols', 'live_symbols'),
            ('_defines_symbols', 'defines_symbols'),
            ('_uses_symbols', 'uses_symbols'),
        )
        for attr_name, kwarg_name in annotations:
            o._update(**{attr_name: kwargs.get(kwarg_name, OrderedSet())})
        return o

    # Internal nodes

    def visit_Interface(self, o, **kwargs):
        # An explicit interface defines the procedure symbols of the
        # subroutines/functions declared in its body
        defines = OrderedSet(
            symbol
            for member in o.body if isinstance(member, Subroutine)
            for symbol in as_tuple(member.procedure_symbol)
        )
        return self.visit_Node(o, defines_symbols=defines, **kwargs)

    def visit_InternalNode(self, o, **kwargs):
        # Generic container: defines whatever its body defines and uses whatever
        # its body reads before defining it
        live_symbols = kwargs.pop('live_symbols', OrderedSet())
        new_body, defines, uses = self._visit_body(o.body, live=live_symbols, **kwargs)
        o._update(body=new_body)
        return self.visit_Node(
            o, live_symbols=live_symbols, defines_symbols=defines, uses_symbols=uses, **kwargs
        )

    def visit_Associate(self, o, **kwargs):
        # Visit the body like any internal node, then express the Associate's own
        # annotations in terms of the enclosing scope
        live = kwargs.pop('live_symbols', OrderedSet())
        body, defines, uses = self._visit_body(o.body, live=live, **kwargs)
        o._update(body=body)

        # associate-name -> associated expression (case-insensitive lookup)
        outer_of = CaseInsensitiveDict({alias.name: target for target, alias in o.associations})

        def _to_outer(symbols):
            return OrderedSet(outer_of[s.name] if s.name in outer_of else s for s in symbols)

        return self.visit_Node(
            o,
            live_symbols=_to_outer(live),
            defines_symbols=_to_outer(defines),
            uses_symbols=_to_outer(uses),
            **kwargs
        )

    def visit_Loop(self, o, **kwargs):
        # The loop bounds are read before the body runs; the induction variable
        # is live inside the body but is not reported outside the loop
        live = kwargs.pop('live_symbols', OrderedSet())

        # arguments of size/lbound/ubound/present only query memory attributes
        queried = as_tuple(flatten(
            FindVariables().visit(call.parameters)
            for call in FindInlineCalls().visit(o.bounds)
            if call.function in self._mem_property_queries
        ))
        bound_uses = OrderedSet(v for v in self._symbols_from_expr(o.bounds) if v not in queried)

        body, defines, uses = self._visit_body(
            o.body, live=live|{o.variable.clone()}, uses=bound_uses, **kwargs
        )
        o._update(body=body)

        for collection in (uses, defines):
            collection.discard(o.variable)
        return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs)

    def visit_WhileLoop(self, o, **kwargs):
        # A while loop uses the variables in its condition. As for the other
        # control-flow handlers, arguments of memory-attribute queries (size,
        # lbound, ubound, present) do not count as uses. When
        # include_literal_kinds is disabled, kind specifiers of literals do not
        # count either.
        live = kwargs.pop('live_symbols', OrderedSet())

        mem_calls = as_tuple(i for i in FindInlineCalls().visit(o.condition) if i.function in self._mem_property_queries)
        query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_calls))
        cset = OrderedSet(v for v in FindVariables().visit(o.condition) if not v in query_args)

        if not self.include_literal_kinds:
            # Filter out any symbols used to qualify literals e.g. 0._JPRB
            literals = FindLiterals().visit(o.condition)
            literal_vars = FindVariables().visit(literals)
            cset -= OrderedSet(literal_vars)

        uses = self._symbols_from_expr(as_tuple(cset))
        body, defines, uses = self._visit_body(o.body, live=live, uses=uses, **kwargs)
        o._update(body=body)
        return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs)

    def visit_Conditional(self, o, **kwargs):
        # Condition variables are used first; both branches are then visited
        # independently (defines of one branch do not mask uses in the other)
        live = kwargs.pop('live_symbols', OrderedSet())

        # variables only passed to memory-attribute queries are not value reads
        queried = as_tuple(flatten(
            FindVariables().visit(call.parameters)
            for call in FindInlineCalls().visit(o.condition)
            if call.function in self._mem_property_queries
        ))
        cond_vars = OrderedSet(v for v in FindVariables().visit(o.condition) if v not in queried)

        if not self.include_literal_kinds:
            # drop kind qualifiers of literals, e.g. the JPRB in 0._JPRB
            cond_vars -= OrderedSet(FindVariables().visit(FindLiterals().visit(o.condition)))

        cond_uses = self._symbols_from_expr(as_tuple(cond_vars))
        then_body, then_defines, uses = self._visit_body(o.body, live=live, uses=cond_uses, **kwargs)
        else_body, else_defines, uses = self._visit_body(o.else_body, live=live, uses=uses, **kwargs)
        o._update(body=then_body, else_body=else_body)
        return self.visit_Node(
            o, live_symbols=live, defines_symbols=then_defines|else_defines, uses_symbols=uses, **kwargs
        )

    def visit_MultiConditional(self, o, **kwargs):
        # The selector and the case values are used; every branch (including
        # the default) is visited independently and their defines are merged
        live = kwargs.pop('live_symbols', OrderedSet())

        def _value_reads(expr):
            # variables of ``expr`` minus those only passed to memory-attribute queries
            calls = [c for c in FindInlineCalls().visit(expr) if c.function in self._mem_property_queries]
            queried = as_tuple(flatten(FindVariables().visit(c.parameters) for c in calls))
            return OrderedSet(v for v in FindVariables().visit(expr) if v not in queried)

        uses = (
            self._symbols_from_expr(as_tuple(_value_reads(o.expr)))
            | self._symbols_from_expr(as_tuple(_value_reads(o.values)))
        )

        new_bodies = []
        defines = OrderedSet()
        for branch in o.bodies:
            visited, branch_defines, uses = self._visit_body(branch, live=live, uses=uses, **kwargs)
            new_bodies.append(as_tuple(visited))
            defines |= branch_defines
        else_body, else_defines, uses = self._visit_body(o.else_body, live=live, uses=uses, **kwargs)
        o._update(bodies=tuple(new_bodies), else_body=else_body)
        return self.visit_Node(
            o, live_symbols=live, defines_symbols=defines | else_defines, uses_symbols=uses, **kwargs
        )

    visit_TypeConditional = visit_MultiConditional

    def visit_MaskedStatement(self, o, **kwargs):
        # Mask conditions are used; the masked bodies are visited in order with
        # ``defines`` threaded through them, then the default body is visited
        # NOTE(review): unlike Conditional/MultiConditional, a symbol defined in
        # an earlier masked body is not reported as used by a later one -- confirm
        # this is intended for WHERE/ELSEWHERE semantics
        live = kwargs.pop('live_symbols', OrderedSet())

        condition_symbols = self._symbols_from_expr(o.conditions)
        if not self.include_literal_kinds:
            # drop kind qualifiers of literals, e.g. the JPRB in 0._JPRB
            kind_vars = FindVariables().visit(as_tuple(FindLiterals().visit(o.conditions)))
            condition_symbols -= OrderedSet(kind_vars)

        defines = OrderedSet()
        uses = OrderedSet(condition_symbols)
        visited_bodies = []
        for branch in o.bodies:
            visited, defines, uses = self._visit_body(branch, live=live, uses=uses, defines=defines, **kwargs)
            visited_bodies.append(visited)

        default, default_defines, uses = self._visit_body(o.default, live=live, uses=uses, **kwargs)
        o._update(bodies=tuple(visited_bodies), default=default)
        return self.visit_Node(
            o, live_symbols=live, defines_symbols=defines|default_defines, uses_symbols=uses, **kwargs
        )

    # Leaf nodes

    def visit_Assignment(self, o, **kwargs):
        """
        Annotate an assignment.

        The symbols reported by ``_symbols_from_lhs_expr`` for the left-hand
        side are defined and used respectively. Right-hand side variables are
        additionally used, excluding arguments of memory property queries and,
        unless ``include_literal_kinds`` is set, kind specifiers of literals.
        """
        queries = [call for call in FindInlineCalls().visit(o.rhs)
                   if call.function in self._mem_property_queries]
        query_args = as_tuple(flatten(FindVariables().visit(call.parameters) for call in queries))
        rhs_reads = OrderedSet(var for var in FindVariables().visit(o.rhs) if var not in query_args)

        if not self.include_literal_kinds:
            # Drop symbols that only qualify literal kinds, e.g. 0._JPRB
            rhs_reads -= OrderedSet(FindVariables().visit(FindLiterals().visit(o.rhs)))

        defines, uses = self._symbols_from_lhs_expr(o.lhs)
        # Right-hand side values are read before the assignment happens
        uses |= self._symbols_from_expr(as_tuple(rhs_reads))
        return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

    def visit_ConditionalAssignment(self, o, **kwargs):
        """
        Annotate a conditional assignment.

        The left-hand side is defined. The condition and both candidate
        right-hand sides are used, on top of whatever ``_symbols_from_lhs_expr``
        reports as used for the left-hand side.
        """
        defines, uses = self._symbols_from_lhs_expr(o.lhs)
        read_exprs = (o.condition, o.rhs, o.else_rhs)
        uses |= self._symbols_from_expr(read_exprs)
        return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

    def visit_CallStatement(self, o, **kwargs):
        """
        Annotate a subroutine call.

        If the callee is known (``o.routine`` is not ``BasicType.DEFERRED``),
        the dummy argument intents decide the classification:

        * actual arguments for ``out``/``inout`` dummies are defined;
        * those for ``in``/``inout`` dummies are used;
        * index expressions of arrays passed to ``out``/``inout`` dummies are
          used rather than defined.

        Without a callee, intents are unknown. Every positional and keyword
        argument is then conservatively treated as both used and defined,
        except array index expressions, which are only used.
        """
        if o.routine is BasicType.DEFERRED:
            kw_values = [val for _, val in o.kwarguments]
            passed_arrays = [v for v in FindVariables().visit(o.arguments) if isinstance(v, Array)]
            passed_arrays += [v for val in kw_values for v in FindVariables().visit(val) if isinstance(v, Array)]
            index_vars = OrderedSet(v for arr in passed_arrays for v in FindVariables().visit(arr.dimensions))

            def _not_index(expr):
                return expr not in index_vars

            defines = self._symbols_from_expr(o.arguments, condition=_not_index)
            for val in kw_values:
                defines |= self._symbols_from_expr(val, condition=_not_index)
            uses = defines.copy() | index_vars
            return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

        intent_vals = [(str(arg.type.intent).lower(), val) for arg, val in o.arg_iter()]
        written = [val for intent, val in intent_vals if intent in ('inout', 'out')]
        read = [val for intent, val in intent_vals if intent in ('inout', 'in')]

        written_arrays = [v for v in FindVariables().visit(written) if isinstance(v, Array)]
        index_symbols = OrderedSet(s for arr in written_arrays for s in self._symbols_from_expr(arr.dimensions))

        defines, uses = OrderedSet(), OrderedSet()
        for val in written:
            defines |= OrderedSet(s for s in self._symbols_from_expr(val) if s not in index_symbols)
        if written:
            # Index expressions of written array arguments are only read
            uses |= index_symbols
        uses |= OrderedSet(s for val in read for s in self._symbols_from_expr(val))
        return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

    def visit_Allocation(self, o, **kwargs):
        """
        Annotate an allocation statement.

        The allocated variables are defined. Variables in the shape
        expressions of allocated arrays, and in the optional
        ``o.data_source`` expression, are used.
        """
        shape_vars = OrderedSet(
            var
            for arr in FindVariables().visit(o.variables) if isinstance(arr, Array)
            for var in FindVariables().visit(arr.dimensions)
        )
        defines = self._symbols_from_expr(o.variables, condition=lambda x: x not in shape_vars)
        uses = self._symbols_from_expr(o.data_source or ()) | shape_vars
        return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

    def visit_Deallocation(self, o, **kwargs):
        """
        Annotate a deallocation (also used for ``Nullify`` via the alias below).

        All symbols in ``o.variables`` are recorded as defined by the statement.
        """
        return self.visit_Node(o, defines_symbols=self._symbols_from_expr(o.variables), **kwargs)

    visit_Nullify = visit_Deallocation

    def visit_Import(self, o, **kwargs):
        """
        Annotate an import statement.

        Imported procedure symbols (without dimensions) count as defined.
        """
        imported = FindTypedSymbols().visit(o.symbols or ())
        procedures = [s.clone(dimensions=None) for s in imported if isinstance(s, ProcedureSymbol)]
        return self.visit_Node(o, defines_symbols=OrderedSet(procedures), **kwargs)

    def visit_VariableDeclaration(self, o, **kwargs):
        """
        Annotate a variable declaration.

        Declared symbols that carry an initial value are defined. Symbols in
        the dimension expressions of declared arrays are used. So is the kind
        of the first declared symbol, if it has one; all symbols of a
        declaration are presumably of the same type.
        """
        symbols = o.symbols
        defines = self._symbols_from_expr(symbols, condition=lambda v: v.type.initial is not None)
        uses = OrderedSet(
            dim_symbol
            for var in symbols if isinstance(var, Array)
            for dim_symbol in self._symbols_from_expr(var.dimensions)
        )
        kind = symbols[0].type.kind
        if kind:
            uses |= {kind}
        return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)


class DataflowAnalysisDetacher(Transformer):
    """
    Remove in-place any dataflow analysis properties.

    Every visited node has its ``_live_symbols``, ``_defines_symbols`` and
    ``_uses_symbols`` attributes reset to ``None``. The transformer runs with
    ``inplace=True`` and ``invalidate_source=False``.
    """

    def __init__(self, **kwargs):
        # In-place, so that existing references to IR nodes remain valid
        super().__init__(inplace=True, invalidate_source=False, **kwargs)

    def visit_Node(self, o, **kwargs):
        cleared = dict.fromkeys(('_live_symbols', '_defines_symbols', '_uses_symbols'))
        o._update(**cleared)
        return super().visit_Node(o, **kwargs)


def attach_dataflow_analysis(module_or_routine):
    """
    Determine and attach to each IR node dataflow analysis metadata.

    This makes for each IR node the following properties available:

    * :attr:`Node.live_symbols`: symbols defined before the node;
    * :attr:`Node.defines_symbols`: symbols (potentially) defined by the
      node, i.e., live in subsequent nodes;
    * :attr:`Node.uses_symbols`: symbols used by the node (that had to be
      defined before).

    Dummy arguments with ``intent(in)`` or ``intent(inout)`` are live on
    entry. Everything defined in the spec is additionally live in the body.

    The IR nodes are updated in-place and thus existing references to IR
    nodes remain valid.
    """
    if hasattr(module_or_routine, 'arguments'):
        # Only dummy arguments declared intent(in) or intent(inout) hold a
        # defined value on entry
        live_symbols = DataflowAnalysisAttacher._symbols_from_expr(
            module_or_routine.arguments,
            condition=lambda a: a.type.intent and a.type.intent.lower() in ('in', 'inout')
        )
    else:
        live_symbols = OrderedSet()

    if hasattr(module_or_routine, 'spec'):
        DataflowAnalysisAttacher().visit(module_or_routine.spec, live_symbols=live_symbols)
        # Rebind instead of updating in place. This set object was handed to
        # the visitor above, which may keep it by reference as the spec
        # nodes' live symbols. An in-place ``|=`` would then silently add the
        # spec's own definitions to them.
        live_symbols = live_symbols | module_or_routine.spec.defines_symbols

    if hasattr(module_or_routine, 'body'):
        DataflowAnalysisAttacher().visit(module_or_routine.body, live_symbols=live_symbols)


def detach_dataflow_analysis(module_or_routine):
    """
    Remove from each IR node the stored dataflow analysis metadata.

    The spec and, if present, the body of ``module_or_routine`` are cleared,
    in that order. Objects without either attribute are left untouched.

    Accessing the relevant attributes afterwards raises :py:class:`RuntimeError`.
    """
    for section_name in ('spec', 'body'):
        if hasattr(module_or_routine, section_name):
            DataflowAnalysisDetacher().visit(getattr(module_or_routine, section_name))


@contextmanager
def dataflow_analysis_attached(module_or_routine):
    r"""
    Create a context in which information about defined, live and used symbols
    is attached to each IR node.

    This makes for each IR node the following properties available:

    * :attr:`Node.live_symbols`: symbols defined before the node;
    * :attr:`Node.defines_symbols`: symbols (potentially) defined by the
      node;
    * :attr:`Node.uses_symbols`: symbols used by the node that had to be
      defined before.

    This is an in-place update of nodes and thus existing references to IR
    nodes remain valid. When leaving the context the information is removed
    from IR nodes, while existing references remain valid.

    The analysis is based on a rather crude regions-based analysis, with the
    hierarchy implied by (nested) :any:`InternalNode` IR nodes used as regions
    in the reducible flow graph (cf. Chapter 9, in particular 9.7 of Aho, Lam,
    Sethi, and Ullman (2007)). Our implementation shares some similarities
    with a full reaching definitions dataflow analysis but is not quite as
    powerful.

    In reaching definitions dataflow analysis (cf. Chapter 9.2.4 of Aho et al.),
    the transfer function of a definition :math:`d` can be expressed as:

    .. math:: f_d(x) = \operatorname{gen}_d \cup (x - \operatorname{kill}_d)

    with the set of definitions generated :math:`\operatorname{gen}_d` and the
    set of definitions killed/invalidated :math:`\operatorname{kill}_d`.

    We, however, do not record definitions explicitly and instead operate on
    consolidated sets of defined symbols, i.e., effectively evaluate the
    chained transfer functions up to the node. This yields a set of active
    definitions at this node. The symbols defined by these definitions are
    in :any:`Node.live_symbols`, and the symbols defined by the node (i.e.,
    symbols defined by definitions in :math:`\operatorname{gen}_d`) are in
    :any:`Node.defines_symbols`.

    The advantage of this approach is that it avoids the need to introduce
    a layer for definitions and dependencies. A downside is that this focus
    on symbols instead of definitions precludes, in particular, the ability
    to take data space into account, which makes it less useful for arrays.

    .. note::
        The context manager operates only on the module or routine itself
        (i.e., its spec and, if applicable, body), not on any contained
        subroutines or functions.

    Parameters
    ----------
    module_or_routine : :any:`Module` or :any:`Subroutine`
        The object for which the IR is to be annotated.
    """
    attach_dataflow_analysis(module_or_routine)
    try:
        yield module_or_routine
    finally:
        # Always strip the annotations, even if the body of the with-block raised
        detach_dataflow_analysis(module_or_routine)


class FindReads(Visitor):
    """
    Look for reads in a specified part of a control flow tree.

    Reads are collected from the ``uses_symbols`` of leaf nodes. They are
    also collected from the control expressions of internal nodes: ``IF``
    conditions, ``SELECT`` selectors, loop bounds and ``DO WHILE``
    conditions. Leaf nodes must therefore carry dataflow analysis
    information, e.g., via :meth:`dataflow_analysis_attached`.

    Parameters
    ----------
    start : (iterable of) :any:`Node`, optional
        Visitor is only active after encountering one of the nodes in
        :data:`start` and until encountering a node in :data:`stop`.
    stop : (iterable of) :any:`Node`, optional
        Visitor is no longer active after encountering one of the nodes in
        :data:`stop` until it encounters again a node in :data:`start`.
    active : bool, optional
        Set the visitor active right from the beginning.
    candidate_set : set of symbols, optional
        If given, only reads for symbols in this set are considered. The set
        is copied on construction, so the object passed in is not modified.
    clear_candidates_on_write : bool, optional
        If enabled, writes of a symbol remove it from the :data:`candidate_set`.
    """

    def __init__(self, start=None, stop=None, active=False,
                 candidate_set=None, clear_candidates_on_write=False, **kwargs):
        super().__init__(**kwargs)
        self.start = OrderedSet(as_tuple(start))
        self.stop = OrderedSet(as_tuple(stop))
        self.active = active
        # Work on a private copy: candidates are removed in place below
        # (``discard``, ``-=``), which would otherwise clobber the caller's set
        self.candidate_set = candidate_set.copy() if candidate_set is not None else None
        self.clear_candidates_on_write = clear_candidates_on_write
        self.reads = OrderedSet()

    @staticmethod
    def _symbols_from_expr(expr):
        """
        Return set of symbols found in an expression.
        """
        return {v.clone(dimensions=None) for v in FindVariables().visit(expr)}

    def _register_reads(self, read_symbols):
        if self.active:
            if self.candidate_set is None:
                self.reads |= read_symbols
            else:
                self.reads |= read_symbols & self.candidate_set

    def _register_writes(self, write_symbols):
        if self.active and self.clear_candidates_on_write and self.candidate_set is not None:
            self.candidate_set -= write_symbols

    def visit(self, o, *args, **kwargs):
        self.active = (self.active and o not in self.stop) or o in self.start
        return super().visit(o, *args, **kwargs)

    def visit_object(self, o, **kwargs):  # pylint: disable=unused-argument
        pass

    def visit_LeafNode(self, o, **kwargs):  # pylint: disable=unused-argument
        self._register_reads(o.uses_symbols)
        self._register_writes(o.defines_symbols)

    def visit_Conditional(self, o, **kwargs):
        self._register_reads(self._symbols_from_expr(o.condition))
        # Visit each branch with the original candidate set and then take the
        # union of both afterwards to include all potential read-after-writes
        candidate_set = self.candidate_set.copy() if self.candidate_set is not None else None
        self.visit(o.body, **kwargs)
        self.candidate_set, candidate_set = candidate_set, self.candidate_set
        self.visit(o.else_body, **kwargs)
        if self.candidate_set is not None:
            self.candidate_set |= candidate_set

    def visit_MultiConditional(self, o, **kwargs):
        self._register_reads(self._symbols_from_expr(o.expr))
        branches = tuple(o.bodies) + (o.else_body,)
        if self.candidate_set is None:
            for branch in branches:
                self.visit(branch, **kwargs)
            return
        # As for visit_Conditional, the branches are mutually exclusive. Each
        # branch therefore starts from the candidate set in front of the
        # construct, and the union of the per-branch results is kept. A write
        # in one CASE block must not hide reads in another.
        initial = self.candidate_set
        merged = None
        for branch in branches:
            self.candidate_set = initial.copy()
            self.visit(branch, **kwargs)
            merged = self.candidate_set if merged is None else merged | self.candidate_set
        self.candidate_set = merged

    def visit_Loop(self, o, **kwargs):
        self._register_reads(self._symbols_from_expr(o.bounds))
        active = self.active
        if self.active and self.candidate_set is not None:
            # remove the loop variable as a variable of interest
            self.candidate_set.discard(o.variable)
        self.visit(o.children, **kwargs)
        if active:
            self.reads.discard(o.variable)

    def visit_WhileLoop(self, o, **kwargs):
        self._register_reads(self._symbols_from_expr(o.condition))
        self.visit(o.children, **kwargs)


class FindWrites(Visitor):
    """
    Look for writes in a specified part of a control flow tree.

    Writes are taken from the ``defines_symbols`` of leaf nodes, which must
    therefore carry dataflow analysis information, e.g., via
    :meth:`dataflow_analysis_attached`. When an active visitor encounters a
    loop, the loop variable is dropped from the writes collected so far and
    from the candidate set.

    Parameters
    ----------
    start : (iterable of) :any:`Node`, optional
        Visitor is only active after encountering one of the nodes in
        :data:`start` and until encountering a node in :data:`stop`.
    stop : (iterable of) :any:`Node`, optional
        Visitor is no longer active after encountering one of the nodes in
        :data:`stop` until it encounters again a node in :data:`start`.
    active : bool, optional
        Set the visitor active right from the beginning.
    candidate_set : set of symbols, optional
        If given, only writes for symbols in this set are considered. The set
        is copied on construction, so the object passed in is not modified.
    """

    def __init__(self, start=None, stop=None, active=False,
                 candidate_set=None, **kwargs):
        super().__init__(**kwargs)
        self.start = OrderedSet(as_tuple(start))
        self.stop = OrderedSet(as_tuple(stop))
        self.active = active
        # Private copy: visit_Loop removes loop variables from it in place
        self.candidate_set = candidate_set.copy() if candidate_set is not None else None
        self.writes = OrderedSet()

    @staticmethod
    def _symbols_from_expr(expr):
        """
        Return set of symbols found in an expression.
        """
        return {v.clone(dimensions=None) for v in FindVariables().visit(expr)}

    def _register_writes(self, write_symbols):
        if self.candidate_set is None:
            self.writes |= write_symbols
        else:
            self.writes |= write_symbols & self.candidate_set

    def visit(self, o, *args, **kwargs):
        self.active = (self.active and o not in self.stop) or o in self.start
        return super().visit(o, *args, **kwargs)

    def visit_object(self, o, **kwargs):  # pylint: disable=unused-argument
        pass

    def visit_LeafNode(self, o, **kwargs):  # pylint: disable=unused-argument
        if self.active:
            self._register_writes(o.defines_symbols)

    def visit_Loop(self, o, **kwargs):
        if self.active:
            # remove the loop variable as a variable of interest
            if self.candidate_set is not None:
                self.candidate_set.discard(o.variable)
            self.writes.discard(o.variable)
        super().visit_Node(o, **kwargs)


def read_after_write_vars(ir, inspection_node):
    """
    Find variables that are read after being written in the given IR.

    This requires prior application of :meth:`dataflow_analysis_attached` to
    the corresponding :any:`Module` or :any:`Subroutine`.

    The result is the set of variables with a data dependency across the
    :data:`inspection_node`.

    See the remarks about implementation and limitations in the description of
    :meth:`dataflow_analysis_attached`. In particular, this does not take into
    account data space and iteration space for arrays.

    Parameters
    ----------
    ir : :any:`Node`
        The root of the control flow (sub-)tree to inspect.
    inspection_node : :any:`Node`
        Only variables with a write before and a read at or after this node
        are considered.

    Returns
    -------
    :any:`set` of :any:`Scalar` or :any:`Array`
        The set of read-after-write variables.
    """
    # Pass 1: all writes encountered before the inspection node is reached
    writer = FindWrites(stop=inspection_node, active=True)
    writer.visit(ir)

    # Pass 2: reads of those symbols from the inspection node onwards; a
    # symbol stops being a candidate once it is overwritten
    reader = FindReads(
        start=inspection_node,
        candidate_set=writer.writes,
        clear_candidates_on_write=True,
    )
    reader.visit(ir)
    return reader.reads


def loop_carried_dependencies(loop):
    """
    Find variables that are potentially loop-carried dependencies.

    A variable qualifies when the loop both uses and (potentially) defines it,
    i.e., it lies in the intersection of the loop's ``uses_symbols`` and
    ``defines_symbols``.

    This requires prior application of :meth:`dataflow_analysis_attached` to
    the corresponding :any:`Module` or :any:`Subroutine`.

    See the remarks about implementation and limitations in the description of
    :meth:`dataflow_analysis_attached`. In particular, this does not take into
    account data space and iteration space for arrays. For cases with a
    linear mapping from iteration to data space and no overlap, this will
    falsely report loop-carried dependencies when there are in fact none.
    However, the risk of false negatives should be low.

    Parameters
    ----------
    loop : :any:`Loop`
        The loop node to inspect.

    Returns
    -------
    :any:`set` of :any:`Scalar` or :any:`Array`
        The set of variables that potentially have a loop-carried dependency.
    """
    used = loop.uses_symbols
    defined = loop.defines_symbols
    return used & defined
loki-ecmwf-0.3.6/loki/analyse/util_polyhedron.py0000664000175000017500000003024415167130205022130 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from typing import List
import numpy as np
from loki.ir import Loop, FindVariables
from loki.expression import (
    symbols as sym, simplify, is_constant,
    accumulate_polynomial_terms
)
from loki.tools import as_tuple

__all__ = ["Polyhedron"]


class Polyhedron:
    """
    Halfspace representation of a (convex) polyhedron.

    A polyhedron `P` in `R^d` is described by a set of inequalities, in matrix form
    ```
    P = { x=[x1,...,xd]^T in R^d | Ax <= b }
    ```
    with n-by-d matrix `A` and n-dimensional right hand side `b`, i.e., one
    row of `A` and one entry of `b` per inequality.

    In loop transformations, polyhedrons are used to represent iteration spaces of
    d-deep loop nests.


    Parameters
    ----------
    A : numpy.array
        The representation matrix A (converted to an integer array).
    b : numpy.array
        The right hand-side vector b (converted to an integer array).
    variables : list, optional
        List of variables representing the dimensions in the polyhedron.

    Attributes
    ----------
    A : numpy.array
        The representation matrix A.
    b : numpy.array
        The right hand-side vector b.
    variables : list, optional, default = None
        The variables associated with the columns of A.
    variable_names : list of str, optional, default = None
        Lower-case names of :attr:`variables`, used for index lookups.
    """

    def __init__(self, A, b, variables=None):
        A = np.array(A, dtype=np.dtype(int))
        b = np.array(b, dtype=np.dtype(int))
        # One row of A per entry of b, i.e., per inequality
        assert A.ndim == 2 and b.ndim == 1
        assert A.shape[0] == b.shape[0]
        self.A = A
        self.b = b

        self.variables = None
        self.variable_names = None
        if variables is not None:
            # One variable per column of A
            assert len(variables) == A.shape[1]
            self.variables = variables
            self.variable_names = [v.name.lower() for v in self.variables]

    def __str__(self):
        str_A = "[" + ", ".join([str(row) for row in self.A]) + "]"
        str_b = f"[{', '.join(map(str, self.b))}]"
        str_variable_names = (
            f"[{', '.join(map(str, self.variable_names))}]"
            if self.variable_names
            else "[]"
        )

        return f"Polyhedron(\n\tA={str_A}, \n\tb={str_b}, \n\tvariables={str_variable_names}\n)"

    def __repr__(self):
        return str(self)

    def _has_satisfiable_constant_restrictions(self):
        """
        Check whether the constant restrictions of the polyhedron are satisfiable.

        This method checks if 0 <= b, assuming that A x = 0.

        Returns:
        bool: True if all constant restrictions are satisfiable, False otherwise.

        """

        return (0 <= self.b).all()

    def is_empty(self):
        """
        Determine whether a polyhedron is empty.

        A polyhedron is considered empty under the following conditions:
        1. It contains no inequalities.
        2. It spans no space, which is a nontrivial problem. The simplest case is when it has an empty
        matrix A and does not satisfy the constant restrictions 0 <= b.

        Notes
        -----
        An empty polyhedron implies that it has no valid solutions or feasible points within its boundaries.
        This function is expected to be called only for polyhedrons with an empty matrix.

        Returns
        -------
        bool
            True if the polyhedron is empty; False if it is not.

        Raises
        ------
        RuntimeError
            If the matrix A is not empty.
        """
        if self.A.size == 0:
            return self.b.size == 0 or not self._has_satisfiable_constant_restrictions()

        raise RuntimeError(
            """
            Checking if a polyhedron with a non-empty matrix spans no space is a nontrivial problem.
            This function is expected to be only called upon polyhedrons with an empty matrix!
            """
        )

    def variable_to_index(self, variable):
        """
        Return the column index of a variable, given as symbol or name.

        Raises
        ------
        RuntimeError
            If the polyhedron has no associated variables.
        """
        if self.variable_names is None:
            raise RuntimeError("No variables list associated with polyhedron.")
        if isinstance(variable, sym.TypedSymbol):
            variable = variable.name.lower()
        assert isinstance(variable, str)
        return self.variable_names.index(variable)

    @staticmethod
    def _to_literal(value):
        """
        Convert an integer to an expression literal; negative values become ``-1 * |value|``.
        """
        if value < 0:
            return sym.Product((-1, sym.IntLiteral(abs(value))))
        return sym.IntLiteral(value)

    def _bounds(self, index_or_variable, ignore_variables, lower):
        """
        Shared implementation of :meth:`lower_bounds` and :meth:`upper_bounds`.

        For every row ``i`` with ``A_ij < 0`` (if ``lower``) or ``A_ij > 0``
        (otherwise), the row ``sum_k A_ik x_k <= b_i`` is solved for ``x_j``.
        This yields the bound ``(b_i - sum_{k != j} A_ik x_k) / A_ij``; dividing
        by a negative coefficient turns the inequality into a lower bound.
        """
        if isinstance(index_or_variable, int):
            j = index_or_variable
        else:
            j = self.variable_to_index(index_or_variable)

        if ignore_variables:
            ignore_variables = [
                i if isinstance(i, int) else self.variable_to_index(i)
                for i in ignore_variables
            ]

        bounds = []
        for i in range(self.A.shape[0]):
            coef = self.A[i, j]
            is_candidate = coef < 0 if lower else coef > 0
            if not is_candidate:
                continue

            if ignore_variables and any(self.A[i, k] != 0 for k in ignore_variables):
                # Skip constraint that depends on any of the ignored variables
                continue

            components = [
                self._to_literal(self.A[i, k]) * self.variables[k]
                for k in range(self.A.shape[1])
                if k != j and self.A[i, k] != 0
            ]
            if not components:
                lhs = sym.IntLiteral(0)
            elif len(components) == 1:
                lhs = components[0]
            else:
                lhs = sym.Sum(as_tuple(components))
            bounds.append(
                simplify(
                    sym.Quotient(
                        self._to_literal(self.b[i]) - lhs,
                        self._to_literal(coef),
                    )
                )
            )
        return bounds

    def lower_bounds(self, index_or_variable, ignore_variables=None):
        """
        Return all lower bounds imposed on a variable.

        The lower bounds for the variable `j` are given by the index set
        over the rows of `A`:

        ``
        L = {i | A_ij < 0, i in {0, ..., n-1}}
        ``

        Parameters
        ----------
        index_or_variable : int or str or sym.Array or sym.Scalar
            The index, name, or expression symbol for which the lower bounds are produced.
        ignore_variables : list or None, optional
            A list of variable names, indices, or symbols for which constraints should be ignored
            if they depend on one of them.

        Returns
        -------
        list
            The bounds for the specified variable.
        """
        return self._bounds(index_or_variable, ignore_variables, lower=True)

    def upper_bounds(self, index_or_variable, ignore_variables=None):
        """
        Return all upper bounds imposed on a variable.

        The upper bounds for the variable `j` are given by the index set
        over the rows of `A`:

        ``
        U = {i | A_ij > 0, i in {0, ..., n-1}}
        ``

        Parameters
        ----------
        index_or_variable : int or str or sym.Array or sym.Scalar
            The index, name, or expression symbol for which the upper bounds are produced.
        ignore_variables : list or None, optional
            A list of variable names, indices, or symbols for which constraints should be ignored
            if they depend on one of them.

        Returns
        -------
        list
            The bounds for the specified variable.
        """
        return self._bounds(index_or_variable, ignore_variables, lower=False)

    @staticmethod
    def generate_entries_for_lower_bound(bound, variables, index):
        """
        Helper routine to generate matrix and right-hand side entries for a given lower bound.

        Note that this routine can only handle affine bounds, which means expressions that are
        constant or can be reduced to a linear polynomial.

        Upper bounds can be derived from this by multiplying the left-hand side and right-hand side
        with -1.

        Parameters
        ----------
        bound : int or str or sym.Array or sym.Scalar
            The expression representing the lower bound.
        variables : list of str
            The list of variable names.
        index : int
            The index of the variable constrained by this bound.

        Returns
        -------
        tuple(np.array, scalar)
            The pair ``(lhs, rhs)`` of the row in the matrix inequality, where `lhs` is the
            1-by-d integer left-hand side and `rhs` is the scalar right-hand side.

        Raises
        ------
        ValueError
            If the bound is of an unsupported type or not affine.
        """
        supported_types = (sym.TypedSymbol, sym.MetaSymbol, sym.Sum, sym.Product)
        if not (is_constant(bound) or isinstance(bound, supported_types)):
            raise ValueError(f"Cannot derive inequality from bound {str(bound)}")
        # x_index >= c + sum(coef * var)  <=>  -x_index + sum(coef * var) <= -c
        summands = accumulate_polynomial_terms(bound)
        b = -summands.pop(1, 0)  # Constant term or 0
        A = np.zeros([1, len(variables)], dtype=np.dtype(int))
        A[0, index] = -1
        for base, coef in summands.items():
            if not len(base) == 1:
                raise ValueError(f"Non-affine bound {str(bound)}")
            A[0, variables.index(base[0].name.lower())] = coef
        return A, b

    @classmethod
    def from_loop_ranges(cls, loop_variables, loop_ranges):
        """
        Create polyhedron from a list of loop ranges and associated variables.

        The loop variables occupy the first columns, followed by any other
        variables appearing in the ranges (sorted by name). Each range
        contributes two rows, one for its lower and one for its upper bound.
        Only ranges with unit step are supported.
        """
        assert len(loop_ranges) == len(loop_variables)

        # Add any variables that are not loop variables to the vector of variables
        variables = list(loop_variables)
        variable_names = [v.name.lower() for v in variables]
        for v in sorted(
            FindVariables().visit(loop_ranges), key=lambda v: v.name.lower()
        ):
            if v.name.lower() not in variable_names:
                variables += [v]
                variable_names += [v.name.lower()]

        n = 2 * len(loop_ranges)
        d = len(variables)
        A = np.zeros([n, d], dtype=np.dtype(int))
        b = np.zeros([n], dtype=np.dtype(int))

        for i, (loop_variable, loop_range) in enumerate(
            zip(loop_variables, loop_ranges)
        ):
            assert loop_range.step is None or loop_range.step == "1"
            # Look the loop variable up by its name in the list of names
            # (not in the list of symbols, which would rely on symbols
            # comparing equal to plain strings)
            j = variable_names.index(loop_variable.name.lower())

            # Create inequality from lower bound
            lhs, rhs = cls.generate_entries_for_lower_bound(
                loop_range.start, variable_names, j
            )
            A[2 * i, :] = lhs
            b[2 * i] = rhs

            # Create inequality from upper bound
            lhs, rhs = cls.generate_entries_for_lower_bound(
                loop_range.stop, variable_names, j
            )
            A[2 * i + 1, :] = -lhs
            b[2 * i + 1] = -rhs

        return cls(A, b, variables)

    @classmethod
    def from_nested_loops(cls, nested_loops: List[Loop]):
        """
        Helper function, for creating a polyhedron from a list of loops.
        """
        return cls.from_loop_ranges(
            [l.variable for l in nested_loops], [l.bounds for l in nested_loops]
        )
loki-ecmwf-0.3.6/loki/module.py0000664000175000017500000002532415167130205016544 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Contains the declaration of :any:`Module` to represent Fortran modules.
"""
from loki.frontend import (
    get_fparser_node, parse_omni_ast, parse_fparser_ast,
    parse_regex_source
)
from loki.ir import (
    VariableDeclaration, pragmas_attached, process_dimension_pragmas
)
from loki.program_unit import ProgramUnit
from loki.subroutine import Subroutine
from loki.tools import as_tuple
from loki.types import ModuleType, SymbolAttributes, Scope


__all__ = ['Module']


class Module(ProgramUnit):
    """
    Class to handle and manipulate source modules.

    Parameters
    ----------
    name : str
        Name of the module.
    docstring : :any:`CommentBlock` or list of :any:`Comment`
        The module docstring
    spec : :any:`Section`, optional
        The spec section of the module.
    contains : tuple of :any:`Subroutine`, optional
        The module-subprogram part following a ``CONTAINS`` statement declaring
        member procedures.
    default_access_spec : str, optional
        The default access attribute for variables as defined by an access-spec
        statement without access-id-list, i.e., ``public`` or ``private``.
        Default value is `None` corresponding to the absence of an access-spec
        statement for default accessibility (which is equivalent to ``public``).
        The value is stored in lower case.
    public_access_spec : tuple of str, optional
        List of identifiers that are declared ``public`` in an access-spec statement.
        Default value is `None` which is stored as an empty tuple. Identifiers
        are stored in lower case.
    private_access_spec : tuple of str, optional
        List of identifiers that are declared ``private`` in an access-spec statement.
        Default value is `None` which is stored as an empty tuple. Identifiers
        are stored in lower case.
    ast : optional
        The node for this module from the parse tree produced by the frontend.
    source : :any:`Source`, optional
        Object representing the raw source string information from the read file.
    parent : :any:`Scope`, optional
        The enclosing parent scope of the module. Declarations from the parent
        scope remain valid within the module's scope (unless shadowed by local
        declarations).
    rescope_symbols : bool, optional
        Ensure that the type information for all :any:`TypedSymbol` in the
        module's IR exist in the module's scope. Defaults to `False`.
    symbol_attrs : :any:`SymbolTable`, optional
        Symbol table entries to add to the module's own symbol table on
        construction (the entries are copied in via ``update``).
    incomplete : bool, optional
        Mark the object as incomplete, i.e. only partially parsed. This is
        typically the case when it was instantiated using the :any:`Frontend.REGEX`
        frontend and a full parse using one of the other frontends is pending.
    parser_classes : :any:`RegexParserClass`, optional
        Provide the list of parser classes used during incomplete regex parsing
    """

    def __init__(
            self, name=None, docstring=None, spec=None, contains=None,
            default_access_spec=None, public_access_spec=None, private_access_spec=None,
            ast=None, source=None, parent=None, symbol_attrs=None, rescope_symbols=False,
            incomplete=False, parser_classes=None
    ):
        super().__init__(parent=parent)

        if symbol_attrs:
            self.symbol_attrs.update(symbol_attrs)

        self.__initialize__(
            name=name, docstring=docstring, spec=spec, contains=contains,
            default_access_spec=default_access_spec, public_access_spec=public_access_spec,
            private_access_spec=private_access_spec, ast=ast, source=source,
            rescope_symbols=rescope_symbols, incomplete=incomplete, parser_classes=parser_classes
        )

    def __initialize__(
            self, name=None, docstring=None, spec=None, contains=None,
            ast=None, source=None, rescope_symbols=False, incomplete=False, parser_classes=None,
            default_access_spec=None, public_access_spec=None, private_access_spec=None
    ):
        """
        Set up the module's state from the given components.

        Dimension pragmas are applied to the declarations in ``spec``, the
        access-spec properties are normalised to lower case, and the remaining
        arguments are forwarded to :meth:`ProgramUnit.__initialize__`.
        """
        # Apply dimension pragma annotations to declarations
        if spec:
            with pragmas_attached(self, VariableDeclaration):
                spec = process_dimension_pragmas(spec)

        # Store the access spec properties
        self.default_access_spec = None if not default_access_spec else default_access_spec.lower()
        if not public_access_spec:
            self.public_access_spec = ()
        else:
            self.public_access_spec = tuple(v.lower() for v in as_tuple(public_access_spec))
        if not private_access_spec:
            self.private_access_spec = ()
        else:
            self.private_access_spec = tuple(v.lower() for v in as_tuple(private_access_spec))

        super().__initialize__(
            name=name, docstring=docstring, spec=spec, contains=contains, ast=ast,
            source=source, rescope_symbols=rescope_symbols, incomplete=incomplete, parser_classes=parser_classes
        )

    @classmethod
    def from_omni(cls, ast, raw_source, definitions=None, parent=None, type_map=None):
        """
        Create :any:`Module` from :any:`OMNI` parse tree

        Parameters
        ----------
        ast :
            The OMNI parse tree
        raw_source : str
            Fortran source string
        definitions : list, optional
            List of external :any:`Module` to provide derived-type and procedure declarations
        parent : :any:`Scope`, optional
            The enclosing parent scope of the module
        type_map : dict, optional
            A mapping from type hash identifiers to type definitions, as provided in
            OMNI's ``typeTable`` parse tree node
        """
        type_map = type_map or {}
        if ast.tag != 'FmoduleDefinition':
            ast = ast.find('globalDeclarations/FmoduleDefinition')
        return parse_omni_ast(
            ast=ast, definitions=definitions, raw_source=raw_source,
            type_map=type_map, scope=parent
        )

    @classmethod
    def from_fparser(cls, ast, raw_source, definitions=None, pp_info=None, parent=None):
        """
        Create :any:`Module` from :any:`FP` parse tree

        Parameters
        ----------
        ast :
            The FParser parse tree
        raw_source : str
            Fortran source string
        definitions : list
            List of external :any:`Module` to provide derived-type and procedure declarations
        pp_info :
            Preprocessing info as obtained by :any:`sanitize_input`
        parent : :any:`Scope`, optional
            The enclosing parent scope of the module.
        """
        if ast.__class__.__name__ != 'Module':
            ast = get_fparser_node(ast, 'Module')
        # Note that our Fparser interface returns a tuple with the
        # Module object always last but potentially containing
        # comments before the Module object
        return parse_fparser_ast(
            ast, pp_info=pp_info, definitions=definitions,
            raw_source=raw_source, scope=parent
        )[-1]

    @classmethod
    def from_regex(cls, raw_source, parser_classes=None, parent=None):
        """
        Create :any:`Module` from source regex'ing

        Parameters
        ----------
        raw_source : str
            Fortran source string
        parser_classes : :any:`RegexParserClass`, optional
            The parser classes to use with the regex frontend
        parent : :any:`Scope`, optional
            The enclosing parent scope of the module.

        Returns
        -------
        :any:`Module`
            The first :any:`Module` found in the parsed source.
        """
        ir_ = parse_regex_source(raw_source, parser_classes=parser_classes, scope=parent)
        return [node for node in ir_.body if isinstance(node, cls)][0]

    def register_in_parent_scope(self):
        """
        Insert the type information for this object in the parent's symbol table

        If :attr:`parent` is `None`, this does nothing.
        """
        if self.parent:
            self.parent.symbol_attrs[self.name] = SymbolAttributes(self.module_type)

    def clone(self, **kwargs):
        """
        Create a copy of the module with the option to override individual
        parameters.

        Parameters
        ----------
        **kwargs :
            Any parameters from the constructor of :any:`Module`.

        Returns
        -------
        :any:`Module`
            The cloned module object.
        """
        # Collect all properties bespoke to Module
        if self.default_access_spec and 'default_access_spec' not in kwargs:
            kwargs['default_access_spec'] = self.default_access_spec
        if self.public_access_spec and 'public_access_spec' not in kwargs:
            kwargs['public_access_spec'] = self.public_access_spec
        if self.private_access_spec and 'private_access_spec' not in kwargs:
            kwargs['private_access_spec'] = self.private_access_spec

        # Escalate to parent class
        return super().clone(**kwargs)

    @property
    def module_type(self):
        """
        Return the :any:`ModuleType` of this module
        """
        return ModuleType(module=self)

    @property
    def _canonical(self):
        """
        Base definition for comparing :any:`Module` objects.
        """
        return (
            self.name, self.docstring, self.spec, self.contains, self.symbol_attrs,
            self.default_access_spec, self.public_access_spec, self.private_access_spec,
        )

    def __eq__(self, other):
        """Compare two modules by their :attr:`_canonical` representation."""
        if isinstance(other, Module):
            return self._canonical == other._canonical
        return super().__eq__(other)

    def __hash__(self):
        """Hash consistent with :meth:`__eq__`, based on :attr:`_canonical`."""
        return hash(self._canonical)

    def __getstate__(self):
        """Return the pickle state, excluding the frontend parse tree."""
        s = self.__dict__.copy()
        # TODO: We need to remove the AST, as certain AST types
        # (eg. FParser) are not pickle-safe.
        # Use ``pop`` so that pickling does not fail if no AST was stored.
        s.pop('_ast', None)
        return s

    def __setstate__(self, s):
        """Restore the pickle state and re-link contained scopes to ``self``."""
        self.__dict__.update(s)

        # Re-register all contained procedures in symbol table and update parentage
        if self.contains:
            for node in self.contains.body:
                if isinstance(node, Subroutine):
                    node._reset_parent(self)
                    node.register_in_parent_scope()

                if isinstance(node, Scope):
                    node._reset_parent(self)

        # Ensure that we are attaching all symbols to the newly created ``self``.
        self.rescope_symbols()

    @property
    def definitions(self):
        """
        The list of IR nodes defined by this module

        This is the concatenation of :attr:`subroutines`, :attr:`typedefs`,
        :attr:`variables` and :attr:`interfaces` of this module.
        """
        return self.subroutines + self.typedefs + self.variables + self.interfaces
loki-ecmwf-0.3.6/loki/jit_build/0000775000175000017500000000000015167130205016644 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/jit_build/__init__.py0000664000175000017500000000200015167130205020745 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Just-in-Time compilation utilities used in the Loki test base.

These allow compilation and wrapping of generated Fortran source code
using `f90wrap <https://github.com/jameskermode/f90wrap>`_ for
execution from Python tests.
"""

from loki.logging import * # noqa

from loki.jit_build.binary import * # noqa
from loki.jit_build.builder import * # noqa
from loki.jit_build.compiler import * # noqa  # pylint: disable=redefined-builtin
from loki.jit_build.header import * # noqa
from loki.jit_build.jit import * # noqa
from loki.jit_build.lib import * # noqa
from loki.jit_build.obj import * # noqa
from loki.jit_build.workqueue import * # noqa
loki-ecmwf-0.3.6/loki/jit_build/tests/0000775000175000017500000000000015167130205020006 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/jit_build/tests/__init__.py0000664000175000017500000000057015167130205022121 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/jit_build/tests/c_util.c0000664000175000017500000000020215167130205021423 0ustar  alastairalastair
/*
 * Compute a * b + c and store the result in *sum.
 *
 * This is called from Fortran through a bind(c, name='mult_add_c')
 * subroutine interface, so it returns nothing: the inputs are passed by
 * value and the result is written through the pointer argument.
 * (Previously declared as returning int without returning a value.)
 */
void mult_add_c(double a, double b, double c, double *sum)
{
  // Result value `sum` is passed by reference
  *sum = a * b + c;
}
loki-ecmwf-0.3.6/loki/jit_build/tests/base.f900000664000175000017500000000127615167130205021246 0ustar  alastairalastairmodule base

  implicit none

  save

  ! TODO: Using this in this module causes issues in f90wrap,
  ! so we leave it here for other modules to use, but don't
  ! use it in this module itself.
  integer, parameter :: jprb = selected_real_kind(13,300)

  real(kind=8) :: a, b
  integer :: i, j

contains

  function a_plus_b()
    ! Test to verify module-level variables work
    ! Returns the sum of the module variables a and b of module base.
    real(kind=8) :: a_plus_b
    a_plus_b = a + b
  end function a_plus_b

  function a_times_b_plus_c(a, b, c)
    ! Simple test to verify that module functions work
    ! Returns a * b + c. Note that the dummy arguments a and b shadow the
    ! module-level variables of the same name inside this function.
    real(kind=8) :: a_times_b_plus_c
    real(kind=8), intent(in) :: a, b, c
    a_times_b_plus_c = a * b + c
  end function a_times_b_plus_c

end module base
loki-ecmwf-0.3.6/loki/jit_build/tests/extension.f900000664000175000017500000000042515167130205022343 0ustar  alastairalastairsubroutine extended_fma(a, b, c, sum)
  ! Add number from an imported module
  ! Computes sum = a * b + c via a_times_b_plus_c from module base, so
  ! module base has to be compiled before this file.
  use base, only: jprb, a_times_b_plus_c

  implicit none
  ! NOTE(review): jprb is selected_real_kind(13,300) while a_times_b_plus_c
  ! declares real(kind=8) arguments; these coincide on common compilers
  ! but are not guaranteed to - confirm if porting.
  real(kind=jprb), intent(in) :: a, b, c
  real(kind=jprb), intent(out) :: sum

  sum = a_times_b_plus_c(a, b, c)
end subroutine extended_fma
loki-ecmwf-0.3.6/loki/jit_build/tests/wrapper.f900000664000175000017500000000126715167130205022014 0ustar  alastairalastairmodule wrapper
  implicit none
  integer, parameter :: jprb = selected_real_kind(13,300)

contains

  subroutine mult_add_external(a, b, c, sum)
    ! Compute sum = a * b + c by calling the C routine mult_add_c
    ! through the explicit bind(c) interface declared below.
    implicit none

    interface
       subroutine mult_add_fc(a, b, c, sum) &
            & bind(c, name='mult_add_c')
         use iso_c_binding, only: c_double
         implicit none

         ! Pass values in by value, out by reference
         real(kind=c_double), value :: a, b, c
         real(kind=c_double) :: sum
       end subroutine mult_add_fc
    end interface

    ! NOTE(review): the call below assumes kind jprb matches c_double;
    ! otherwise it would not match the explicit interface - confirm if porting.
    real(kind=jprb), intent(in) :: a, b, c
    real(kind=jprb), intent(out) :: sum

    call mult_add_fc(a, b, c, sum)
  end subroutine mult_add_external

end module wrapper
loki-ecmwf-0.3.6/loki/jit_build/tests/test_build.py0000664000175000017500000001566515167130205022533 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from loki.jit_build import Obj, Lib, Builder
from loki.jit_build.compiler import  (
    Compiler, GNUCompiler, NvidiaCompiler, get_compiler_from_env,
    _default_compiler
)


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory containing this test module and its Fortran/C test sources."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(scope='function', name='builder')
def fixture_builder(here, tmp_path):
    """
    Provide a :any:`Builder` that takes sources from ``here`` and builds into
    a per-test temporary directory; the :any:`Obj` cache is cleared afterwards.
    """
    jit_builder = Builder(source_dirs=here, build_dir=tmp_path)
    yield jit_builder
    Obj.clear_cache()


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """The package-level ``tests`` directory, which holds the f90wrap ``kind_map``."""
    package_root = here.parent.parent
    return package_root/'tests'


def test_build_clean(builder):
    """
    Check that :meth:`Builder.clean` removes build artefacts matching the
    given glob patterns, similar to ``make clean``.
    """
    # Populate the build directory with junk files before cleaning it
    junk_files = ('xxx_a.o', 'xxx_b.o', 'xxx_a.mod', 'xxx_a.so', 'f90wrap_xxx_a.f90')
    for fname in junk_files:
        (builder.build_dir/fname).touch(exist_ok=True)

    builder.clean('*.o *.mod *.so f90wrap*.f90')
    remaining = [str(entry) for entry in builder.build_dir.iterdir()]
    assert all('xxx' not in entry for entry in remaining)


def test_build_object(here, testdir, builder):
    """
    Compile a single Fortran source into an object file and wrap it via f90wrap.
    """
    base_obj = Obj(source_path=here/'base.f90')
    base_obj.build(builder=builder)
    assert (builder.build_dir/'base.o').exists()

    kind_map = testdir/'kind_map'
    wrapped = base_obj.wrap(builder=builder, kind_map=kind_map)
    assert wrapped.Base.a_times_b_plus_c(a=2, b=3, c=1) == 7


@pytest.mark.parametrize('workers', [None, 1, 3])
def test_build_lib(here, tmp_path, testdir, workers):
    """
    Test basic library compilation and wrapping via f90wrap
    from a specific list of source objects.

    This test creates its own :any:`Builder` (to vary ``workers``) instead of
    using the ``builder`` fixture, so it clears the :any:`Obj` cache itself
    to avoid leaking cached objects into other tests.
    """
    builder = Builder(source_dirs=here, build_dir=tmp_path, workers=workers)

    try:
        # Create library with explicit dependencies
        base = Obj(source_path=here/'base.f90')
        extension = Obj(source_path=here/'extension.f90')
        # Note: Need to compile statically to avoid LD_LIBRARY_PATH lookup
        lib = Lib(name='library', objs=[base, extension], shared=False)
        lib.build(builder=builder)
        assert (builder.build_dir/'liblibrary.a').exists()

        test = lib.wrap(modname='test', sources=[here/'extension.f90'], builder=builder,
                        kind_map=testdir/'kind_map')
        assert test.extended_fma(2., 3., 1.) == 7.
    finally:
        Obj.clear_cache()


def test_build_lib_with_c(here, testdir, builder):
    """
    Build a static library from mixed Fortran and C sources and wrap the
    Fortran module that calls the C routine via f90wrap.
    """
    # The Fortran wrapper module calls into the C routine ``mult_add_c``
    objs = [Obj(source_path=here/'wrapper.f90'), Obj(source_path=here/'c_util.c')]
    lib = Lib(name='library', objs=objs, shared=False)
    lib.build(builder=builder)
    assert (builder.build_dir/'liblibrary.a').exists()

    wrapped = lib.wrap(modname='wrap', sources=[here/'wrapper.f90'], builder=builder,
                       kind_map=testdir/'kind_map')
    assert wrapped.wrapper.mult_add_external(2., 3., 1.) == 7.


@pytest.mark.skip(reason='Dependency resolution in a non-trivial module tree is not tested yet')
def test_build_obj_dependencies():
    """
    Placeholder for testing dependency resolution in a non-trivial module tree.

    The test is not implemented yet. It is marked as skipped so that it shows
    up as such, instead of silently passing without asserting anything.
    """


def test_build_binary(builder):
    """
    Placeholder for binary compilation from objects and libs; for now this
    only checks that the ``builder`` fixture yields a truthy object.
    """
    assert bool(builder)


@pytest.mark.parametrize('env,cls,attrs', [
    # Overwrite completely custom
    (
        {'CC': 'my-weird-compiler', 'FC': 'my-other-weird-compiler', 'F90': 'weird-fortran', 'FCFLAGS': '-my-flag  '},
        Compiler,
        {'CC': 'my-weird-compiler', 'FC': 'my-other-weird-compiler', 'F90': 'weird-fortran', 'FCFLAGS': ['-my-flag']},
    ),
    # GNUCompiler
    ({'CC': 'gcc'}, GNUCompiler, {'CC': 'gcc', 'FC': 'gfortran', 'F90': 'gfortran'}),
    ({'CC': 'gcc-13'}, GNUCompiler, None),
    ({'CC': '/path/to/my/gcc'}, GNUCompiler, None),
    ({'CC': '../../relative/path/to/my/gcc-11'}, GNUCompiler, None),
    ({'CC': 'C:\\windows\\path\\to\\gcc'}, GNUCompiler, None),
    ({'FC': 'gfortran'}, GNUCompiler, None),
    ({'FC': 'gfortran-13', 'FCFLAGS': '-O3 -g'}, GNUCompiler, {'FC': 'gfortran-13', 'FCFLAGS': ['-O3', '-g']}),
    ({'FC': '/path/to/my/gfortran'}, GNUCompiler, None),
    ({'FC': '../../relative/path/to/my/gfortran'}, GNUCompiler, None),
    ({'FC': 'C:\\windows\\path\\to\\gfortran'}, GNUCompiler, None),
    # NvidiaCompiler
    ({'FC': 'nvfortran'}, NvidiaCompiler, {'CC': 'nvc', 'FC': 'nvfortran', 'F90': 'nvfortran'}),
    ({'CC': 'nvc'}, NvidiaCompiler, None),
    ({'CC': '/path/to/my/nvc'}, NvidiaCompiler, None),
    ({'CC': '../../relative/path/to/my/nvc'}, NvidiaCompiler, None),
    ({'CC': 'C:\\windows\\path\\to\\nvc'}, NvidiaCompiler, None),
    ({'FC': 'pgf90'}, NvidiaCompiler, None),
    ({'FC': 'pgf95'}, NvidiaCompiler, None),
    ({'FC': 'pgfortran'}, NvidiaCompiler, None),
    ({'FC': '/path/to/my/nvfortran'}, NvidiaCompiler, None),
    ({'FC': '../../relative/path/to/my/pgfortran'}, NvidiaCompiler, None),
    ({'FC': 'C:\\windows\\path\\to\\nvfortran'}, NvidiaCompiler, None),
])
def test_get_compiler_from_env(env, cls, attrs):
    """
    Check that :func:`get_compiler_from_env` selects the expected compiler
    class for ``env`` and exposes the expected attribute values.

    When ``attrs`` is not given, the values from ``env`` itself are expected.
    """
    compiler = get_compiler_from_env(env)
    assert type(compiler) == cls  # pylint: disable=unidiomatic-typecheck

    expected = attrs if attrs else env
    for key, value in expected.items():
        # NB: the lower-case attribute holds the runtime value
        assert getattr(compiler, key.lower()) == value


def test_default_compiler():
    """
    ``_default_compiler`` must have the same type as the compiler returned
    by :func:`get_compiler_from_env` when called without an environment.
    """
    from_env = get_compiler_from_env()
    assert type(_default_compiler) == type(from_env)  # pylint: disable=unidiomatic-typecheck


def test_obj_dependencies(tmp_path):
    """
    Check that :attr:`Obj.dependencies` lists the modules imported via USE
    statements in order of first appearance, skipping imports with an
    ``INTRINSIC`` module-nature and handling the various spacing and
    double-colon variants of the statement.
    """
    fcode = """
module import_mod
    use module_mod
    implicit none
contains
    subroutine proc1
        use  , non_intrinsic :: iso_fortran_env, only: int8, int16
        use  , intrinsic :: iso_c_binding
        use other_Mod
        use :: third_mod
        use    fourth_mod , only:some
        use,non_intrinsic::very_condensed
    end subroutine proc1
    subroutine proc2
        use, intrinsic :: iso_fortran_env, only: int8, int16
        use::fifth_mod
    end subroutine proc2
end module import_mod
    """.strip()

    filepath = tmp_path/'import_mod.f90'
    filepath.write_text(fcode)

    try:
        obj = Obj(name='import_mod', source_path=filepath)
        assert obj.dependencies == (
            'module_mod', 'iso_fortran_env', 'other_Mod', 'third_mod', 'fourth_mod',
            'very_condensed', 'fifth_mod'
        )
    finally:
        # Always reset the cache so that a failing assertion cannot leak the
        # cached ``import_mod`` object into other tests
        Obj.clear_cache()
loki-ecmwf-0.3.6/loki/jit_build/workqueue.py0000664000175000017500000001215215167130205021246 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from contextlib import contextmanager
from multiprocessing import Manager
from logging.handlers import QueueListener, QueueHandler
from concurrent.futures import ProcessPoolExecutor

from loki.logging import default_logger
from loki.tools import execute


__all__ = ['workqueue', 'wait_and_check', 'MEMORY_URL', 'DEFAULT_TIMEOUT']


# NOTE(review): not referenced within this module; presumably consumed by
# users of the work queue - confirm before changing.
MEMORY_URL = 'memory://'

# TODO: REALLY NEED TO MAKE THIS USER CONFIGURABLE!
# Default timeout in seconds that :func:`wait_and_check` waits for a task.
DEFAULT_TIMEOUT = 60


class DummyQueue:
    """
    Dummy queue object to funnel workqueue requests to the current
    main process.
    """

    @staticmethod
    def execute(*args, **kwargs):
        execute(*args, **kwargs)

    @staticmethod
    def call(fn, *args, **kwargs):
        return fn(*args, **kwargs)


"""
A global flag to make worker initialization happen once only.
This is hacky, but it requires a stable Python3.7 to fix.
"""
_initialized = False


def init_worker(log_queue=None):
    """
    Process-local initialization of the worker process. This sets up
    the queue-based logging, etc.

    If ``log_queue`` is given, all handlers of :data:`default_logger` are
    replaced by a single :class:`QueueHandler` that forwards records to the
    main process through ``log_queue``. Both the handler and the logger use
    the ``log-level`` from :data:`loki.config`. Without ``log_queue`` this
    does nothing.

    Parameters
    ----------
    log_queue : queue-like, optional
        Queue that is read by a :class:`QueueListener` in the main process.
    """
    if log_queue is None:
        return

    from loki import config  # pylint: disable=import-outside-toplevel
    log_level = config['log-level']

    # Set up logger to funnel logs back to master via ``log_queue``
    qh = QueueHandler(log_queue)
    qh.setLevel(log_level)

    # Wipe all local handlers, since we dispatch to the master.
    # Iterate over a copy: removing handlers from the live list while
    # iterating over it would skip every other handler.
    for handler in list(default_logger.handlers):
        default_logger.removeHandler(handler)
    default_logger.addHandler(qh)
    default_logger.setLevel(log_level)


def init_call(fn, *args, **kwargs):
    """
    Run ``fn(*args, **kwargs)`` in a worker, calling :func:`init_worker`
    first if this worker process has not been initialized yet.

    The optional ``log_queue`` keyword argument is consumed here and handed
    to :func:`init_worker`; it is never forwarded to ``fn``.

    Hack alert: this wrapper emulates a per-worker initializer through the
    module-level ``_initialized`` flag, instead of relying on the
    ``initializer`` argument of :class:`ProcessPoolExecutor`.
    """
    global _initialized  # pylint: disable=global-statement
    worker_log_queue = kwargs.pop('log_queue', None)
    needs_init = not _initialized
    if needs_init:
        init_worker(log_queue=worker_log_queue)
        _initialized = True

    result = fn(*args, **kwargs)
    return result


def wait_and_check(task, timeout=DEFAULT_TIMEOUT, logger=None):
    """
    Wait for :param:`task` to complete and check for possible exceptions.

    Parameters
    ----------
    task : :class:`concurrent.futures.Future` or None
        The task to wait for; `None` is accepted and ignored.
    timeout : float, optional
        Maximum time in seconds to wait for the task.
    logger : :class:`logging.Logger`, optional
        Logger for error reports; defaults to :data:`default_logger`.

    Raises
    ------
    TimeoutError
        If the task does not complete within ``timeout`` (on Python < 3.11
        this is :class:`concurrent.futures.TimeoutError`).
    Exception
        Any exception raised by the task itself is logged and re-raised.
    """
    # Before Python 3.11, ``concurrent.futures.TimeoutError`` is not the
    # builtin ``TimeoutError``, so both have to be caught.
    from concurrent.futures import TimeoutError as FuturesTimeoutError  # pylint: disable=import-outside-toplevel

    logger = logger or default_logger

    if task is None:
        return

    try:
        # ``exception()`` waits for completion without re-raising, so a
        # failed task can be reported before its exception is propagated
        error = task.exception(timeout=timeout)
    except (TimeoutError, FuturesTimeoutError):
        logger.error('Compilation task timed out: %s', task)
        raise

    if error is not None:
        logger.error('Failed compilation task: %s', task)
        raise error


class ParallelQueue:
    """
    Work queue that submits requests to a pool of worker processes.

    Parameters
    ----------
    executor : :class:`concurrent.futures.Executor`
        The executor (e.g. :class:`ProcessPoolExecutor`) that runs submitted work.
    logger : :class:`logging.Logger`, optional
        If given, log records emitted in the workers are funnelled back through
        :attr:`log_queue` to this logger's handlers by :attr:`listener`, which
        has to be started and stopped by the owner of this queue.
    manager : :class:`multiprocessing.managers.SyncManager`, optional
        Manager used to create :attr:`log_queue`; a new one is created if
        omitted and ``logger`` is given.
    """

    def __init__(self, executor, logger=None, manager=None):
        self.executor = executor

        self.manager = None
        self.listener = None
        self.log_queue = None

        if logger is not None:
            # Initialize a listener for the logging queue that dispatches
            # to our pre-configured handlers on the master process
            self.manager = manager or Manager()
            self.log_queue = self.manager.Queue()
            self.listener = QueueListener(self.log_queue, *(logger.handlers),
                                          respect_handler_level=True)

    def execute(self, *args, **kwargs):
        """
        Wrapper around the ``tools.execute(cmd)`` function presented by the
        :class:`ParallelQueue` object to its users.

        Returns the :class:`concurrent.futures.Future` of the submitted task.
        Unless ``log_queue`` is passed explicitly, this queue's own
        :attr:`log_queue` is used to set up logging in the worker.
        """
        kwargs.setdefault('log_queue', self.log_queue)
        return self.executor.submit(init_call, execute, *args, **kwargs)

    def call(self, fn, *args, **kwargs):
        """
        Arbitrary interface to submit function calls to the
        :class:`ParallelQueue` object.

        Returns the :class:`concurrent.futures.Future` of the submitted task.
        Unless ``log_queue`` is passed explicitly, this queue's own
        :attr:`log_queue` is used to set up logging in the worker.
        """
        kwargs.setdefault('log_queue', self.log_queue)
        return self.executor.submit(init_call, fn, *args, **kwargs)


@contextmanager
def workqueue(workers=None, logger=None, manager=None):
    """
    Parallel work queue manager that creates a worker pool and exposes
    the ``q.execute(cmd)`` utility to invoke shell commands in parallel.

    Parameters
    ----------
    workers : int, optional
        Number of worker processes. If `None`, a :class:`DummyQueue` is
        yielded that runs all work synchronously in the current process.
    logger : :class:`logging.Logger`, optional
        If given, log records from the workers are funnelled back to this
        logger's handlers while the context is active.
    manager : :class:`multiprocessing.managers.SyncManager`, optional
        Manager providing the shared logging queue. If omitted and one is
        needed, a new manager is created and shut down on exit.

    Yields
    ------
    :class:`DummyQueue` or :class:`ParallelQueue`
    """
    if workers is None:
        yield DummyQueue()
        return

    q = None
    try:
        with ProcessPoolExecutor(max_workers=workers) as executor:
            q = ParallelQueue(executor, logger=logger, manager=manager)

            # We have to manually start and stop the queue listener
            # for our funneled logging setup.
            if q.listener:
                q.listener.start()

            try:
                yield q
            finally:
                # Stop the listener even if the body raised, so that its
                # background thread does not outlive the context
                if q.listener:
                    q.listener.stop()
    finally:
        # Shut down a manager that ParallelQueue created on our behalf, only
        # after the executor has finished so no worker still uses its queue
        if manager is None and q is not None and q.manager is not None:
            q.manager.shutdown()
loki-ecmwf-0.3.6/loki/jit_build/obj.py0000664000175000017500000001562215167130205017776 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from functools import cached_property
from pathlib import Path
import re

from loki.logging import debug
from loki.tools import execute, as_tuple, flatten, cached_func
from loki.jit_build.compiler import _default_compiler
from loki.jit_build.header import Header


__all__ = ['Obj']


_re_use = re.compile(
    r'\s*use\b(?:\s*,\s*non_intrinsic)?\s*(?:::\s*)?(?P\w+)',
    re.IGNORECASE | re.MULTILINE
)
"""
Pattern to match Fortran USE statements

This will intentionally not match module imports that have the module-nature
specified as ``INTRINSIC`` but it does match imports with optional colons or
explicit ``NON_INTRINSIC`` module-nature given.
"""

_re_include = re.compile(r'\#include\s+["\']([\w\.]+)[\"\']', re.IGNORECASE)
# Please note that the below regexes are fairly expensive due to .* with re.DOTALL
_re_module = re.compile(r'module\s+(\w+).*?end module', re.IGNORECASE | re.DOTALL)
_re_subroutine = re.compile(r'subroutine\s+(\w+).*?end subroutine', re.IGNORECASE | re.DOTALL)


class Obj:
    """
    A single source object representing a single C or Fortran source file.

    Instances are cached on their lower-cased name: constructing an ``Obj``
    with a name (or source-path stem) that was seen before returns the cached
    instance, whose :attr:`source_path` was established the first time.
    Call :meth:`clear_cache` to reset the cache.

    Parameters
    ----------
    name : str, optional
        Name of the object; derived from the stem of ``source_path`` if omitted.
    source_path : str or :any:`Path`, optional
        Path of the source file (defaults to :attr:`name`). If the file does
        not exist, :attr:`source_path` is set to `None`.
    """

    MODEMAP = {'.f90': 'f90', '.f': 'f', '.c': 'c', '.cc': 'c', '.cpp': 'cpp',
               '.CC': 'cpp', '.cxx': 'cpp'}

    # Default source and header extension recognized
    # TODO: Make configurable!
    _ext = ['.f90', '.F90', '.f', '.F', '.c', '.cpp', '.CC', '.cc', '.cxx']

    def __new__(cls, *args, name=None, **kwargs):  # pylint: disable=unused-argument
        # Accept ``name`` and ``source_path`` positionally, matching ``__init__``
        if name is None and args:
            name = args[0]
        source_path = kwargs.get('source_path', args[1] if len(args) > 1 else None)

        # Name is either provided or inferred from source_path
        name = name or Path(source_path).stem
        name = name.lower()  # Ensure no-caps!

        # Return an instance cached on the derived or provided name
        # TODO: We could make the path relative to a "cache path" here...
        return Obj.__xnew_cached_(cls, name)

    def __new_stage2_(self, name):  # pylint: disable=unused-private-member
        obj = super().__new__(self)
        obj.name = name
        return obj

    __xnew_cached_ = staticmethod(cached_func(__new_stage2_))

    @classmethod
    def clear_cache(cls):
        """Drop all cached :class:`Obj` instances."""
        debug('Clearing Obj cache')
        cls._Obj__xnew_cached_.cache_clear()

    def __init__(self, name=None, source_path=None):  # pylint: disable=unused-argument
        self.q_task = None  # The parallel worker task

        if not hasattr(self, 'source_path'):
            # If this is the first time, establish the source path
            self.source_path = Path(source_path or self.name)  # pylint: disable=no-member

            if not self.source_path.exists():
                debug('Could not find source file for %s', self)
                self.source_path = None

    def __repr__(self):
        return f'Obj<{self.name}>'  # pylint: disable=no-member

    @cached_property
    def source(self):
        """The raw content of the source file, or `None` without a source file."""
        if self.source_path is None:
            return None
        # TODO: Make encoding a global config item.
        with self.source_path.open(encoding='latin1') as f:
            return f.read()

    def _findall(self, pattern):
        """Return all matches of ``pattern`` in :attr:`source` (``[]`` without source)."""
        if self.source is None:
            return []
        return pattern.findall(self.source)

    @cached_property
    def modules(self):
        """Names of the modules defined in the source file."""
        return self._findall(_re_module)

    @cached_property
    def subroutines(self):
        """Names of the subroutines defined in the source file."""
        return self._findall(_re_subroutine)

    @cached_property
    def uses(self):
        """Names of the (non-intrinsic) modules imported via USE statements."""
        return self._findall(_re_use)

    @cached_property
    def includes(self):
        """File names referenced by ``#include`` directives."""
        return self._findall(_re_include)

    @property
    def dependencies(self):
        """
        Names of build items that this item depends on.

        These are the modules imported by this source file, followed by the
        modules imported by any included headers that could be found, without
        duplicates and in order of first appearance.
        """
        if self.source is None:
            return ()

        # Pick out the header object from imports
        # (e.g. ``foo.intfb.h`` -> ``foo.intfb`` -> ``foo``)
        includes = [Path(incl).stem for incl in self.includes]
        includes = [Path(incl).stem if '.intfb' in incl else incl
                    for incl in includes]
        headers = [Header(name=i) for i in includes]

        # Add transitive module dependencies through header imports
        transitive = flatten(h.uses for h in headers if h.source_path is not None)
        # ``dict.fromkeys`` removes duplicates while keeping the first-seen order
        return as_tuple(dict.fromkeys(self.uses + transitive))

    @property
    def definitions(self):
        """
        Names of provided subroutine and modules.
        """
        return as_tuple(self.modules + self.subroutines)

    def build(self, builder=None, logger=None, compiler=None,
              workqueue=None, force=False, include_dirs=None):
        """
        Compile the source file into an object file in the build directory.

        The object file is named after :attr:`name` with suffix ``.o`` and is
        only rebuilt if it is not newer than the source file, unless ``force``
        is given. Please note that this does not build any dependencies.

        Parameters
        ----------
        builder : :any:`Builder`
            Provides the build directory and the default logger, compiler
            and include directories.
        logger : optional
            Logger to use instead of ``builder.logger``.
        compiler : :any:`Compiler`, optional
            Compiler to use instead of ``builder.compiler``; the default
            compiler is used if neither is set.
        workqueue : optional
            If given, the compile command is submitted to this queue and the
            resulting task is stored in :attr:`q_task`; otherwise the command
            is executed directly.
        force : bool, optional
            Rebuild even if the object file is up-to-date.
        include_dirs : list, optional
            Additional include directories, listed before ``builder.include_dirs``.

        Raises
        ------
        RuntimeError
            If no source file was found for this object.
        """
        logger = logger or builder.logger
        compiler = compiler or builder.compiler or _default_compiler
        build_dir = builder.build_dir
        builder_dirs = (builder.include_dirs if builder else None) or []
        include_dirs = list(include_dirs or []) + list(builder_dirs)
        include_dirs = include_dirs or None

        if self.source_path is None:
            raise RuntimeError(f'No source file found for {self}')

        # Try the exact suffix first so case-sensitive entries (e.g. '.CC')
        # are honoured, then fall back to a case-insensitive lookup
        suffix = self.source_path.suffix
        mode = self.MODEMAP.get(suffix) or self.MODEMAP[suffix.lower()]
        source = self.source_path.absolute()
        target = (build_dir/self.name).with_suffix('.o')  # pylint: disable=no-member
        t_time = target.stat().st_mtime if target.exists() else None
        s_time = source.stat().st_mtime if source.exists() else None

        if not force and t_time is not None and s_time is not None \
           and t_time > s_time:
            logger.debug('%s up-to-date, skipping...', self)
            return

        args = compiler.compile_args(source=source, include_dirs=include_dirs,
                                     target=target, mode=mode, mod_dir=build_dir)

        if workqueue is not None:
            self.q_task = workqueue.execute(args, log_queue=workqueue.log_queue)
        else:
            execute(args)

    def wrap(self, builder=None, kind_map=None):
        """
        Wrap the compiled object using ``f90wrap`` and return the loaded module.

        Parameters
        ----------
        builder : :any:`Builder`
            Provides the build directory, the compiler and the module loader.
        kind_map : str or :any:`Path`, optional
            Kind map file passed on to f90wrap.
        """
        build_dir = str(builder.build_dir)
        compiler = builder.compiler or _default_compiler

        module = self.source_path.stem
        source = [str(self.source_path)]
        compiler.f90wrap(modname=module, source=source, cwd=build_dir, kind_map=kind_map)

        # Execute the second-level wrapper (f2py-f90wrap). Sources defining
        # modules use ``f90wrap_<stem>.f90``, otherwise ``f90wrap_toplevel.f90``.
        if self.modules:
            wrapper = f'f90wrap_{module}.f90'
        else:
            wrapper = 'f90wrap_toplevel.f90'
        compiler.f2py(modname=module, source=[wrapper, self.source_path],
                      cwd=build_dir)

        return builder.load_module(module)
loki-ecmwf-0.3.6/loki/jit_build/jit.py0000664000175000017500000001073015167130205020005 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Utilities to facilitate Just-in-Time compilation for testing purposes.
"""
from pathlib import Path

from loki.backend import fgen
from loki.jit_build.builder import Builder
from loki.jit_build.compiler import compile_and_load
from loki.jit_build.lib import Lib
from loki.jit_build.obj import Obj
from loki.ir import Section
from loki.module import Module
from loki.sourcefile import Sourcefile
from loki.subroutine import Subroutine
from loki.tools import as_tuple, gettempdir, filehash


__all__ = ['jit_compile', 'jit_compile_lib', 'clean_test']


# f90wrap KIND_MAP file, resolved relative to this module:
# ``loki/jit_build/jit.py`` -> ``loki/tests/kind_map``
_f90wrap_kind_map = Path(__file__).parent.parent / 'tests' / 'kind_map'


def jit_compile(source, filepath=None, objname=None):
    """
    Generate, Just-in-Time compile and load a given item
    for interactive execution.

    Parameters
    ----------
    source : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
        The item to compile and load
    filepath : str or :any:`Path`, optional
        Path of the source file to write. The default is the path of a
        :any:`Sourcefile` (if it has one); otherwise it is a name in
        :any:`gettempdir()` derived from a hash of the generated Fortran code.
    objname : str, optional
        Return a specific object (module or subroutine) in :attr:`source`

    Returns
    -------
    The loaded Python module, or its attribute ``objname`` if given.
    """
    if isinstance(source, Sourcefile):
        filepath = source.path if filepath is None else Path(filepath)
        if filepath is None:
            # Hash the generated Fortran text, as done for Module/Subroutine
            # below, rather than the Sourcefile object itself.
            # NOTE(review): filehash appears to expect a string.
            filepath = Path(gettempdir()/filehash(fgen(source), prefix='', suffix='.f90'))
        source.write(path=filepath)
    else:
        source = fgen(source)
        if filepath is None:
            filepath = gettempdir()/filehash(source, prefix='', suffix='.f90')
        else:
            filepath = Path(filepath)
        Sourcefile(filepath).write(source=source)

    # Normalise so that ``.parent`` is available below
    filepath = Path(filepath)
    pymod = compile_and_load(filepath, cwd=str(filepath.parent), f90wrap_kind_map=_f90wrap_kind_map)

    if objname:
        return getattr(pymod, objname)
    return pymod


def jit_compile_lib(sources, path, name, wrap=None, builder=None):
    """
    Generate, just-in-time compile and load a set of items into a
    library and import dynamically into the Python runtime.

    Parameters
    ----------
    sources : list
        Source items (:any:`Sourcefile`, :any:`Module`, :any:`Subroutine`)
        or filepaths to compile and add to lib
    path : str or :any:`Path`
        Basepath for on-the-fly creation of source files
    name : str
        Name of created lib
    wrap : list, optional
        File names to pass to ``f90wrap``. Defaults to list of source files.
    builder : :any:`Builder`, optional
        Builder object to use for lib compilation and linking

    Returns
    -------
    The Python module produced by wrapping the library.
    """
    # Accept plain strings as documented; ``path/...`` below requires a Path
    path = Path(path)

    builder_provided = builder is not None
    if not builder_provided:
        builder = Builder(source_dirs=path, build_dir=path)

    sourcefiles = []
    for source in sources:
        if isinstance(source, (str, Path)):
            sourcefiles.append(source)

        elif isinstance(source, Sourcefile):
            filepath = source.path or path/f'{source.name}.f90'
            source.write(path=filepath)
            # Record the path actually written to: ``source.path`` may be None
            sourcefiles.append(filepath)

        elif isinstance(source, (Module, Subroutine)):
            filepath = path/f'{source.name}.f90'
            source = Sourcefile(filepath, ir=Section(body=as_tuple(source)))
            source.write(path=filepath)
            sourcefiles.append(filepath)

    objects = [Obj(source_path=s) for s in sourcefiles]
    lib = Lib(name=name, objs=objects, shared=False)
    lib.build(builder=builder)
    wrap = wrap or sourcefiles
    pymod = lib.wrap(modname=name, sources=wrap, builder=builder, kind_map=_f90wrap_kind_map)

    # Only reset the global object cache if we created the builder ourselves
    if not builder_provided:
        Obj.clear_cache()
    return pymod


def clean_test(filepath):
    """
    Clean test directory based on JIT'ed source file.

    Removes the following files, if present:

    * the source file;
    * its ``.o``, ``.py``, ``.mod`` and ``.xmod`` siblings;
    * the f90wrap wrapper sources;
    * any ``_<stem>.*.so`` shared libraries next to it.

    Missing files are ignored.

    Parameters
    ----------
    filepath : str or :any:`Path`
        Path of the JIT-compiled source file
    """
    filepath = Path(filepath)
    directory = filepath.parent

    candidates = [
        filepath.with_suffix('.f90'), filepath.with_suffix('.o'),
        filepath.with_suffix('.py'), directory/'f90wrap_toplevel.f90',
        filepath.with_suffix('.mod'), filepath.with_suffix('.xmod'),
        directory/f'f90wrap_{filepath.name}',
    ]
    candidates += directory.glob(f'_{filepath.stem}.*.so')

    for f in candidates:
        f.unlink(missing_ok=True)
loki-ecmwf-0.3.6/loki/jit_build/builder.py0000664000175000017500000002027715167130205020654 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
from collections import deque
from operator import attrgetter
import networkx as nx

from loki.config import config
from loki.logging import default_logger
from loki.tools import as_tuple, delete, load_module
from loki.jit_build.compiler import _default_compiler
from loki.jit_build.obj import Obj
from loki.jit_build.header import Header


__all__ = ['Builder']


class Builder:
    """
    A :class:`Builder` that compiles binaries or libraries, while performing
    automated dependency discovery from one or more source paths.

    On construction, :class:`Obj` and :class:`Header` objects are created for
    every matching file found below ``source_dirs`` and ``include_dirs``.

    :param source_dirs: One or more paths to search for source files
    :param include_dirs: One or more paths that contain header files
    :param root_dir: Optional root directory for out-of-source builds
    :param build_dir: Directory for build artefacts; defaults to the current
                      working directory and is created if it does not exist
    :param compiler: Compiler to use (default: ``_default_compiler``)
    :param logger: Logger to use (default: ``default_logger``)
    :param workers: Number of parallel build workers (default: the
                    ``jit-build-workers`` configuration value)
    """

    def __init__(self, source_dirs=None, include_dirs=None, root_dir=None,
                 build_dir=None, compiler=None, logger=None, workers=None):
        self.compiler = compiler or _default_compiler
        self.logger = logger or default_logger
        if workers is None:
            self.workers = config['jit-build-workers']
        else:
            self.workers = workers

        # Source dirs for auto-detection and include dirs for preprocessing
        self.source_dirs = [Path(p).resolve() for p in as_tuple(source_dirs)]
        self.include_dirs = [Path(p).resolve() for p in as_tuple(include_dirs)]

        # Root and build directories for out-of-source builds
        self.root_dir = None if root_dir is None else Path(root_dir)
        self.build_dir = Path.cwd() if build_dir is None else Path(build_dir)
        # Also create missing parents of a nested build directory
        self.build_dir.mkdir(parents=True, exist_ok=True)

        # Populate _object_cache for everything in source_dirs
        for source_dir in self.source_dirs:
            for ext in Obj._ext:
                _ = [Obj(source_path=f) for f in source_dir.glob(f'**/*{ext}')]

        for include_dir in self.include_dirs:
            for ext in Header._ext:
                _ = [Header(source_path=f) for f in include_dir.glob(f'**/*{ext}')]

    def __getitem__(self, *args, **kwargs):
        return Obj(*args, **kwargs)

    def get_item(self, key):
        return self[key]

    @staticmethod
    def get_dependency_graph(objs, depgen=None):
        """
        Construct a :class:`networkx.DiGraph` that represents the dependency graph.

        :param objs: List of :class:`Obj` to use as the root of the graph.
        :param depgen: Generator object to generate the next level of dependencies
                       from an item. Defaults to ``operator.attrgetter('dependencies')``.
        """
        depgen = depgen or attrgetter('dependencies')

        # De-duplicate the roots while preserving their order
        roots = tuple(dict.fromkeys(as_tuple(objs)))
        q = deque(roots)
        # Track everything that has been queued, roots included, in a set.
        # This avoids quadratic list lookups and ensures roots that depend on
        # each other are only processed once. Graph nodes must be hashable
        # anyway.
        seen = set(roots)
        nodes = {}  # dict used as an insertion-ordered set
        edges = []

        while q:
            item = q.popleft()
            nodes.setdefault(item, None)

            # Record the actual :class:`Obj` dependency objects
            item.obj_dependencies = []

            for dep in depgen(item):
                # Note, we always create an `Obj` node, even
                # if it has no source attached.
                node = Obj(name=dep)

                item.obj_dependencies.append(node)

                if node not in seen:
                    seen.add(node)
                    nodes[node] = None
                    q.append(node)

                edges.append((item, node))

        # Create a nw.DiGraph from nodes/edges
        g = nx.DiGraph()
        g.add_nodes_from(nodes)
        g.add_edges_from(edges)

        return g

    def clean(self, rules=None, path=None):
        """
        Clean up a build directory, either according to
        globbing rules or via explicit file paths.

        :param rules: String or list of strings with either explicit
                      filepaths or globbing rules; default is
                      ``'*.o *.mod *.so *.a f90wrap*.f90'``.
        :param path: Optional directory path to clean; defaults
                     first to ``self.build_dir``, then simply ``./``.
        """
        # Derive defaults, split string rules and ensure iterability
        rules = rules or '*.o *.mod *.so *.a f90wrap*.f90'
        if isinstance(rules, str):
            # split() without argument ignores repeated whitespace, which
            # would otherwise yield empty (invalid) glob patterns
            rules = rules.split()
        rules = as_tuple(rules)

        # Accept str paths as well as Path objects
        path = Path(path or self.build_dir or '.')

        for r in rules:
            for f in path.glob(r):
                delete(f)

    def build(self, filename, target=None, shared=True, include_dirs=None, external_objs=None):  # pylint: disable=unused-argument
        """
        Build ``filename`` and all its dependencies, and optionally link the
        resulting objects into ``target``.

        :param filename: Name used to look up the source :class:`Obj`.
        :param target: Optional name of the binary/library to link.
        :param shared: Currently unused.
        :param include_dirs: Currently unused; ``self.include_dirs`` is used.
        :param external_objs: Additional object files to include when linking.
        """
        item = self.get_item(filename)
        self.logger.info("Building %s", item)

        build_dir = str(self.build_dir) if self.build_dir else None

        # Include optional external objects in the build
        objs = [Path(o).resolve() for o in external_objs or []]

        # Build the entire dependency graph, including the source object
        dependencies = self.get_dependency_graph(item)
        for dep in reversed(list(nx.topological_sort(dependencies))):
            dep.build(compiler=self.compiler, build_dir=build_dir,
                      include_dirs=self.include_dirs)
            objs += [f'{dep.path.stem}.o']

        if target is not None:
            self.logger.info('Linking target: %s', target)
            self.compiler.link(objs=objs, target=target, cwd=build_dir)

    def load_module(self, module):
        """
        Handle import paths and load the compiled module
        """
        return load_module(module, path=self.build_dir.absolute())

    def wrap_and_load(self, sources, modname=None, build=True,
                      libs=None, lib_dirs=None, incl_dirs=None,
                      kind_map=None):
        """
        Performs the necessary build steps to compile and wrap a set
        of sources using ``f90wrap``

        This method returns a dynamically loaded Python module
        containing wrappers for each Fortran
        procedure and module specified in :data:`sources`.

        Parameters
        ----------
        sources : str or list of str
            Name(s) of source files to wrap
        modname : str, optional
            Optional module name for f90wrap to use
        build : bool, optional
            Flag to force building the source first; default: True.
        libs : list of str, optional
            Override for library names to link
        lib_dirs : list of str, optional
            Override for library paths to link from
        incl_dirs : list of str, optional
            Override for include directories
        kind_map : str, optional
            Path to ``f90wrap`` KIND_MAP file, containing a Python dictionary
            in f2py_f2cmap format.

        Returns
        -------
        The loaded Python module.
        """
        items = as_tuple(self.get_item(s) for s in as_tuple(sources))
        build_dir = str(self.build_dir) if self.build_dir else None
        modname = modname or items[0].path.stem

        # Invoke build to ensure all base objects are built
        # TODO: Could automate this via timestamps/hashes, etc.
        if build:
            for item in items:
                target = f'lib{item.path.stem}.a'
                self.build(item.path.name, target=target, shared=False)

        # Execute the first-level wrapper (f90wrap)
        self.logger.info('Python-wrapping %s', items[0])
        sourcepaths = [str(i.path) for i in items]
        self.compiler.f90wrap(modname=modname, source=sourcepaths, cwd=build_dir, kind_map=kind_map)

        # Execute the second-level wrapper (f2py-f90wrap)
        wrappers = [f'f90wrap_{item.path.stem}.f90' for item in items]
        wrappers += ['f90wrap_toplevel.f90']  # Include the generic wrapper
        wrappers = [w for w in wrappers if (self.build_dir/w).exists()]

        # Resolve final compilation libraries and include dirs
        libs = libs or [modname]
        lib_dirs = lib_dirs or [str(self.build_dir.absolute())]
        incl_dirs = incl_dirs or []

        self.compiler.f2py(modname=modname, source=wrappers,
                           libs=libs, lib_dirs=lib_dirs,
                           incl_dirs=incl_dirs, cwd=build_dir)

        # Return the module, as documented above
        return self.load_module(module=modname)
loki-ecmwf-0.3.6/loki/jit_build/header.py0000664000175000017500000000565715167130205020463 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re
from pathlib import Path

try:
    from functools import cached_property
except ImportError:
    try:
        from cached_property import cached_property
    except ImportError:
        # Last resort: a plain, uncached property. Returning the bare function
        # instead would turn attributes such as ``Header.source`` into bound
        # methods and break checks like ``self.source is None``.
        cached_property = property

from loki.logging import debug
from loki.tools import cached_func


__all__ = ['Header']


# Name of each module imported via a ``use`` statement at line start.
# A named group needs an explicit name: a bare ``(?P\w+)`` fails to compile.
_re_use = re.compile(r'^\s*use\s+(?P<use>\w+)', re.IGNORECASE | re.MULTILINE)
# File name of each ``#include "..."`` directive
_re_include = re.compile(r'\#include\s+["\']([\w\.]+)[\"\']', re.IGNORECASE)
# Please note that the below regexes are fairly expensive due to .* with re.DOTALL
_re_module = re.compile(r'module\s+(\w+).*end module', re.IGNORECASE | re.DOTALL)
_re_subroutine = re.compile(r'subroutine\s+(\w+).*end subroutine', re.IGNORECASE | re.DOTALL)


class Header:
    """
    A header file (``.intfb.h`` or ``.h``) found on an include path.

    Headers are keyed by a lower-cased name. If no explicit ``name`` is
    given, it is taken from the stem of ``source_path``, with an ``.intfb``
    part stripped. Instances are cached on that name.

    :param name: Optional explicit header name.
    :param source_path: Optional path to the header file.
    """

    _ext = ['.intfb.h', '.h']

    def __new__(cls, *args, name=None, **kwargs):  # pylint: disable=unused-argument
        # Name is either provided or inferred from source_path
        name = name or Path(kwargs.get('source_path')).stem
        # Hack: Remove the .intfb from the name
        if 'intfb' in name:
            name = Path(name).stem
        name = name.lower()

        # Return an instance cached on the derived or provided name
        # TODO: We could make the path relative to a "cache path" here...
        return Header.__xnew_cached_(cls, name)

    def __new_stage2_(self, name):
        obj = super().__new__(self)
        obj.name = name
        return obj

    __xnew_cached_ = staticmethod(cached_func(__new_stage2_))

    def __init__(self, name=None, source_path=None):  # pylint: disable=unused-argument
        if not hasattr(self, 'source_path'):
            # If this is the first time, establish the source path
            self.source_path = Path(source_path or self.name)  # pylint: disable=no-member

            if not self.source_path.exists():
                debug('Could not find source file for %s', self)
                self.source_path = None

    def __repr__(self):
        return f'Header<{self.name}>'  # pylint: disable=no-member

    @cached_property
    def source(self):
        """Raw text of the header file, or ``None`` if it was not found."""
        if self.source_path is not None:
            # TODO: Make encoding a global config item.
            with self.source_path.open(encoding='latin1') as f:
                source = f.read()
            return source
        return None

    @cached_property
    def uses(self):
        """Lower-cased names of modules imported via ``use`` statements."""
        if self.source is None:
            return []
        return [m.lower() for m in _re_use.findall(self.source)]

    @cached_property
    def includes(self):
        """Lower-cased file names from ``#include`` directives."""
        # Guard like ``uses``: findall(None) would raise a TypeError
        if self.source is None:
            return []
        return [m.lower() for m in _re_include.findall(self.source)]
loki-ecmwf-0.3.6/loki/jit_build/lib.py0000664000175000017500000001435415167130205017773 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from operator import attrgetter
from pathlib import Path
import networkx as nx

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, *args, **kwargs):  # pylint: disable=unused-argument
        return iterable

from loki.logging import warning
from loki.tools import as_tuple, find_paths
from loki.jit_build.compiler import _default_compiler
from loki.jit_build.obj import Obj
from loki.jit_build.workqueue import workqueue, wait_and_check


__all__ = ['Lib']


class Lib:
    """
    A library object linked from multiple compiled objects (:class:`Obj`).

    Note, either :param objs: or the arguments :param pattern: and
    :param source_dir: are required to generate the necessary dependencies.

    :param name: Name of the resulting library (without leading ``lib``).
    :param shared: Flag indicating a shared library build.
    :param objs: List of :class:`Obj` objects that define the objects to link.
    :param pattern: A glob pattern that determines the objects to link.
    :param source_dir: A file path to find objects on when resolving glob patterns.
    :param ignore: A (list of) glob patterns defining files to ignore when
                   generating dependencies from a glob pattern.
    """

    def __init__(self, name, shared=True, objs=None, pattern=None, source_dir=None, ignore=None):
        self.name = name
        self.shared = shared

        if objs is not None:
            self.objs = objs

        else:
            # Generate object list by globbing the source_dir according to pattern
            if source_dir is None:
                raise RuntimeError(f'No source directory found for pattern expansion in {self}')

            obj_paths = find_paths(directory=source_dir, pattern=pattern, ignore=ignore)
            self.objs = [Obj(source_path=p) for p in obj_paths]

        if len(self.objs) == 0:
            warning(f'{self}:: Empty dependency list: {self.objs}')

    def __repr__(self):
        return f'Lib<{self.name}>'

    def build(self, builder=None, logger=None, compiler=None, shared=None,
              force=False, include_dirs=None, external_objs=None):
        """
        Build the source objects and create target library.

        Objects are compiled in reverse topological order of their dependency
        graph, in parallel when ``builder.workers > 1``. They are then linked
        into ``lib<name>.so`` or ``lib<name>.a`` inside the build directory.
        The build is skipped when the library is newer than every object
        source, unless ``force`` is set.

        :param builder: :any:`Builder` that supplies ``build_dir``, ``workers``
                        and the defaults for ``compiler`` and ``logger``;
                        it must be given.
        :param shared: Override for :attr:`shared`.
        :param force: Rebuild even if the library looks up-to-date.
        :param include_dirs: Include directories passed to each object build.
        :param external_objs: Additional object files to link into the library.
        """
        compiler = compiler or builder.compiler
        logger = logger or builder.logger
        shared = self.shared if shared is None else shared
        build_dir = builder.build_dir
        workers = builder.workers

        # Append the suffix directly: ``with_suffix`` would replace anything
        # after a dot that is part of the library name itself
        suffix = '.so' if shared else '.a'
        target = build_dir/f'lib{self.name}{suffix}'

        # Establish file-modified times. Without any (locatable) sources no
        # up-to-date check is possible (and max() would fail on an empty list).
        t_time = target.stat().st_mtime if target.exists() else None
        o_paths = [o.source_path for o in self.objs]
        if not o_paths or any(p is None for p in o_paths):
            o_time = None
        else:
            o_time = max(p.stat().st_mtime for p in o_paths)

        # Skip the build if up-to-date...
        if not force and t_time is not None and o_time is not None \
           and t_time > o_time:
            logger.info(f'{self} up-to-date, skipping...')
            return

        logger.info(f'Building {self} (workers={workers})')

        # Generate the dependency graph implied by .mod files
        dep_graph = builder.get_dependency_graph(self.objs, depgen=attrgetter('dependencies'))

        def _build_objs(queue=None):
            # Traverse the dependency tree in reverse topological order
            topo_nodes = list(reversed(list(nx.topological_sort(dep_graph))))
            for obj in tqdm(topo_nodes):
                if obj.source_path and obj.q_task is None:

                    # Wait for dependencies to complete before scheduling item
                    if queue:
                        for dep in obj.obj_dependencies:
                            wait_and_check(dep.q_task, logger=logger)

                    # Schedule object compilation on the workqueue
                    obj.build(builder=builder, compiler=compiler, logger=logger,
                              workqueue=queue, force=force, include_dirs=include_dirs)

            if queue:
                # Ensure all build tasks have finished
                for obj in dep_graph.nodes:
                    if obj.q_task is not None:
                        wait_and_check(obj.q_task, logger=logger)

        if workers > 1:
            # Execute the object build in parallel via a queue of worker processes
            with workqueue(workers=workers, logger=logger) as q:
                _build_objs(q)
        else:
            _build_objs()

        # Link the final library
        objs = [Path(o).resolve() for o in external_objs or []]
        objs += [(build_dir/obj.name).with_suffix('.o') for obj in self.objs]
        logger.debug(f'Linking {self} ({len(objs)} objects)')
        compiler.link(target=target, objs=objs, shared=shared)

    def wrap(self, modname, builder, sources=None, libs=None, lib_dirs=None, kind_map=None):
        """
        Wrap the compiled library using ``f90wrap`` and return the loaded module.

        :param modname: Name of the Python module to generate.
        :param builder: :any:`Builder` that supplies ``build_dir`` and ``compiler``.
        :param sources: List of source files to wrap for Python access.
        :param libs: Additional libraries to link after this one.
        :param lib_dirs: Additional library search directories.
        :param kind_map: Path to an ``f90wrap`` KIND_MAP file.
        """
        items = as_tuple(Obj(source_path=s) for s in as_tuple(sources))
        build_dir = builder.build_dir
        compiler = builder.compiler or _default_compiler

        sourcepaths = [str(i.source_path) for i in items]
        compiler.f90wrap(modname=modname, source=sourcepaths, cwd=str(build_dir), kind_map=kind_map)

        # Execute the second-level wrapper (f2py-f90wrap)
        wrappers = [f'f90wrap_{item.source_path.stem}.f90' for item in items]
        wrappers += ['f90wrap_toplevel.f90']  # Include the generic wrapper
        wrappers = [w for w in wrappers if (build_dir/w).exists()]

        libs = [self.name] + (libs or [])
        lib_dirs = [str(build_dir.absolute())] + (lib_dirs or [])
        compiler.f2py(modname=modname, source=wrappers,
                      libs=libs, lib_dirs=lib_dirs, cwd=str(build_dir))

        return builder.load_module(modname)
loki-ecmwf-0.3.6/loki/jit_build/compiler.py0000664000175000017500000003416415167130205021040 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from importlib import import_module, reload
import os
from pathlib import Path
import re
import shutil
import sys

from loki.logging import info, debug
from loki.tools import execute, as_tuple, delete


__all__ = [
    'clean', 'compile', 'compile_and_load', '_default_compiler',
    'Compiler', 'get_compiler_from_env', 'GNUCompiler', 'NvidiaCompiler'
]


def _which(cmd):
    """
    Convenience wrapper around :any:`shutil.which` that adds the binary
    directory of the Python interpreter to the search path

    This is useful when called from a script that is installed
    in a virtual environment without having explicitly enabled that environment.
    In that case, utilities like f90wrap may be installed inside the virtual
    environment but the binary dir will not be part of the search path.

    Returns the full path of ``cmd``, or ``None`` if it cannot be found.
    """
    # Use the platform's path separator, and tolerate an unset PATH
    search_dirs = [os.environ.get('PATH', ''), str(Path(sys.executable).parent)]
    search_path = os.pathsep.join(d for d in search_dirs if d)
    return shutil.which(cmd, path=search_path)


def compile(filename, include_dirs=None, compiler=None, cwd=None):
    """
    Compile a single (free-form Fortran) source file.

    Parameters
    ----------
    filename : str or :any:`Path`
        Source file to compile
    include_dirs : list of str, optional
        Include directories to pass to the compiler
    compiler : :any:`Compiler`, optional
        Compiler to use (default: :any:`_default_compiler`)
    cwd : str, optional
        Working directory in which to run the compiler
    """
    # Stop complaints about `compile` in this function
    # pylint: disable=redefined-builtin
    filepath = Path(filename)
    compiler = compiler or _default_compiler
    # :class:`Compiler` provides ``compile_args``; there is no ``build_args``
    args = compiler.compile_args(source=filepath.absolute(),
                                 include_dirs=include_dirs)
    execute(args, cwd=cwd)


def clean(filename, pattern=None):
    """
    Delete files matching one or more glob patterns.

    Parameters
    ----------
    filename : str or :any:`Path`
        Path against which the patterns are globbed.
        NOTE(review): :meth:`pathlib.Path.glob` resolves patterns relative to
        this path itself, so in practice matches are only found when
        ``filename`` is a directory.
    pattern : str or list of str, optional
        Glob pattern(s) of files to delete
        (default: ``*.f90.cache``, ``*.o``, ``*.mod``)
    """
    base = Path(filename)
    patterns = as_tuple(pattern or ['*.f90.cache', '*.o', '*.mod'])
    for match in (f for p in patterns for f in base.glob(p)):
        delete(match)


def compile_and_load(filename, cwd=None, f90wrap_kind_map=None, compiler=None):
    """
    Just-in-time compile Fortran source code and load the respective
    module or class.

    Both paths, classic subroutine-only and modern module-based are
    supported via the ``f2py`` and ``f90wrap`` packages.

    Parameters
    ----------
    filename : str
        The source file to be compiled.
    cwd : str, optional
        Working directory to use for calls to compiler. All build artefacts,
        including the generated Python module, are placed here. Defaults to
        the directory containing ``filename``.
    f90wrap_kind_map : str, optional
        Path to ``f90wrap`` KIND_MAP file, containing a Python dictionary
        in f2py_f2cmap format.
    compiler : :any:`Compiler`, optional
        Use the specified compiler to compile the Fortran source code. Defaults
        to :any:`_default_compiler`

    Returns
    -------
    The imported (or reloaded) Python module named after ``filename``.
    """
    info(f'Compiling: {filename}')
    filepath = Path(filename)
    modname = filepath.stem

    # Build next to the source file unless told otherwise. The existence
    # checks and the module search path below must refer to the directory
    # the tools actually write into.
    if cwd is None:
        cwd = str(filepath.parent)
    workdir = Path(cwd).absolute()

    # This pattern list already covers the defaults of ``clean``
    pattern = ['*.f90.cache', '*.o', '*.mod', 'f90wrap_*.f90',
               f'{modname}.cpython*.so', f'{modname}.py']
    clean(filename, pattern=pattern)

    # Select a default compiler if none specified
    compiler = compiler or _default_compiler

    # Compile the true sources into a small static lib first
    compiler.compile(source=filepath.absolute(), cwd=cwd)
    compiler.link([f'{modname}.o'], f'lib{modname}.a', shared=False, cwd=cwd)

    # Generate the Python interfaces
    compiler.f90wrap(modname=modname, source=[filepath.absolute()], kind_map=f90wrap_kind_map, cwd=cwd)

    # Compile the dynamic library and link the static source lib; only pass
    # on the wrapper files that f90wrap produced in the working directory
    f2py_source = [
        s for s in (f'f90wrap_{modname}.f90', 'f90wrap_toplevel.f90')
        if (workdir/s).exists()
    ]
    compiler.f2py(
        modname=modname, source=f2py_source,
        libs=[modname], lib_dirs=[str(workdir)], cwd=cwd
    )

    # Add the build directory to the module search path
    moddir = str(workdir)
    if moddir not in sys.path:
        sys.path.append(moddir)

    if modname in sys.modules:
        # Reload module if already imported
        return reload(sys.modules[modname])

    # Import module
    return import_module(modname)


class Compiler:
    """
    Base class for specifying different compiler toolchains.

    Subclasses select a toolchain by overriding the upper-case class
    attributes. Any attribute left as ``None`` falls back to a GNU default
    (``gcc``, ``g++``, ``gfortran``, ``ar``) when the compiler is instantiated.
    """

    CC = None
    CFLAGS = None
    CPP = None
    CPPFLAGS = None
    F90 = None
    F90FLAGS = None
    FC = None
    FCFLAGS = None
    LD = None
    LDFLAGS = None
    LD_STATIC = None
    LDFLAGS_STATIC = None

    def __init__(self):
        # Flag lists are copied, so in-place edits on an instance (e.g.
        # ``compiler.f90flags += ['-O2']``) cannot leak into the class-level
        # lists shared by every instance of a subclass
        self.cc = self.CC or 'gcc'
        self.cflags = list(self.CFLAGS or ['-g', '-fPIC'])
        self.cpp = self.CPP or 'g++'
        self.cppflags = list(self.CPPFLAGS or ['-g', '-fPIC'])
        self.f90 = self.F90 or 'gfortran'
        self.f90flags = list(self.F90FLAGS or ['-g', '-fPIC'])
        self.fc = self.FC or 'gfortran'
        self.fcflags = list(self.FCFLAGS or ['-g', '-fPIC'])
        self.ld = self.LD or 'gfortran'
        self.ldflags = list(self.LDFLAGS or ['-static'])
        self.ld_static = self.LD_STATIC or 'ar'
        self.ldflags_static = list(self.LDFLAGS_STATIC or ['src'])

    def compile_args(self, source, target=None, include_dirs=None, mod_dir=None, mode='f90'):
        """
        Generate arguments for the build line.

        Parameters:
        -----------
        source : str or pathlib.Path
            Path to the source file to compile
        target : str or pathlib.Path, optional
            Path to the output binary to generate
        include_dirs : list of str or pathlib.Path, optional
            Path of include directories to specify during compile
        mod_dir : str or pathlib.Path, optional
            Path to directory containing Fortran .mod files (ignored for C/C++)
        mode : str, optional
            One of ``'f90'`` (free form), ``'f'`` (fixed form), ``'c'`` or ``'cpp'``
        """
        assert mode in ['f90', 'f', 'c', 'cpp']
        include_dirs = include_dirs or []
        cc = {'f90': self.f90, 'f': self.fc, 'c': self.cc, 'cpp': self.cpp}[mode]
        args = [cc, '-c']
        args += {'f90': self.f90flags, 'f': self.fcflags, 'c': self.cflags, 'cpp': self.cppflags}[mode]
        args += self._include_dir_args(include_dirs)
        if mode not in ['c', 'cpp']:
            args += self._mod_dir_args(mod_dir)
        args += [] if target is None else ['-o', str(target)]
        args += [str(source)]
        return args

    def _include_dir_args(self, include_dirs):
        """
        Return a list of compile command arguments for adding
        all paths in :data:`include_dirs` as include directories
        """
        return [
            f'-I{incl!s}' for incl in as_tuple(include_dirs)
        ]

    def _mod_dir_args(self, mod_dir):
        """
        Return a list of compile command arguments for setting
        :data:`mod_dir` as search and output directory for module files
        """
        if mod_dir is None:
            return []
        return [f'-J{mod_dir!s}']

    def compile(self, source, target=None, include_dirs=None, use_c=False, use_cpp=False, cwd=None):
        """
        Execute a build command for a given source.

        ``use_cpp`` takes precedence over ``use_c``; otherwise free-form
        Fortran (``'f90'`` mode) is assumed.
        """
        kwargs = {'target': target, 'include_dirs': include_dirs}
        if use_c:
            kwargs['mode'] = 'c'
        if use_cpp:
            kwargs['mode'] = 'cpp'
        args = self.compile_args(source, **kwargs)
        execute(args, cwd=cwd)

    @staticmethod
    def _is_archiver(linker):
        """
        Return ``True`` if :data:`linker` names an ``ar``-style archiver.

        Matches ``ar`` as well as a path to it (``/usr/bin/ar``) and
        prefixed variants such as ``gcc-ar`` or ``llvm-ar``.
        """
        name = Path(str(linker)).name
        return name == 'ar' or name.endswith('-ar')

    def linker_args(self, objs, target, shared=True):
        """
        Generate arguments for the linker line.
        """
        linker = self.ld if shared else self.ld_static
        args = [linker]
        args += self.ldflags if shared else self.ldflags_static
        if self._is_archiver(linker):
            # ``ar`` takes the archive as a positional argument, not via ``-o``
            args += [str(target)]
        else:
            args += ['-o', str(target)]
        args += [str(o) for o in objs]
        return args

    def link(self, objs, target, shared=True, cwd=None):
        """
        Execute a link command for a given source.
        """
        args = self.linker_args(objs=objs, target=target, shared=shared)
        execute(args, cwd=cwd)

    @staticmethod
    def f90wrap_args(modname, source, kind_map=None):
        """
        Generate arguments for the ``f90wrap`` utility invocation line.
        """
        args = [_which('f90wrap')]
        args += ['-m', str(modname)]
        if kind_map is not None:
            args += ['-k', str(kind_map)]
        args += [str(s) for s in source]
        return args

    def f90wrap(self, modname, source, cwd=None, kind_map=None):
        """
        Invoke f90wrap command to create wrappers for a given module.
        """
        args = self.f90wrap_args(modname=modname, source=source, kind_map=kind_map)
        execute(args, cwd=cwd)

    def f2py_args(self, modname, source, libs=None, lib_dirs=None, incl_dirs=None, cwd=None):
        """
        Generate arguments for the ``f2py-f90wrap`` utility invocation line.
        """
        # Copy the inputs so that appending below never mutates (or fails on,
        # for tuples) a sequence owned by the caller
        libs = list(libs or [])
        lib_dirs = list(lib_dirs or [])
        incl_dirs = list(incl_dirs or [])

        # Due to f90wrap's recent switch to Meson as a build backend, the current working
        # directory is no longer automatically "included" because of the out-of-tree build
        # this implies. To make sure .mod files are still found as before, we need to add
        # the current working directory to the include paths as a workaround, see
        # https://github.com/jameskermode/f90wrap/issues/226 for more details.
        if cwd and cwd not in incl_dirs:
            incl_dirs += [cwd]

        args = [_which('f2py-f90wrap'), '-c']
        args += [f'--f77exec={self.fc}']
        args += [f'--f90exec={self.f90}']
        args += ['--backend=meson']
        args += ['-m', f'_{modname}']
        for incl_dir in incl_dirs:
            args += [f'-I{incl_dir}']
        for lib in libs:
            args += [f'-l{lib}']
        for lib_dir in lib_dirs:
            args += [f'-L{lib_dir}']
        args += [str(s) for s in source]
        return args

    def f2py_env(self):
        """
        Return a copy of the environment with ``CC``, ``FC`` and ``F90``
        set to this toolchain's compilers.
        """
        env = os.environ.copy()
        env['CC'] = self.cc
        env['FC'] = self.fc
        env['F90'] = self.f90
        return env

    def f2py(self, modname, source, libs=None, lib_dirs=None, incl_dirs=None, cwd=None):
        """
        Invoke ``f2py-f90wrap`` to compile the wrapper sources into the
        Python extension module ``_<modname>``.
        """
        args = self.f2py_args(modname=modname, source=source, libs=libs,
                              lib_dirs=lib_dirs, incl_dirs=incl_dirs, cwd=cwd)
        execute(args, cwd=cwd, env=self.f2py_env())


class GNUCompiler(Compiler):
    """
    Compiler configuration for the GNU toolchain (``gcc``, ``g++``, ``gfortran``).

    All compilers default to debug symbols and position-independent code.
    """

    # Executables for compiling and linking
    CC = 'gcc'
    CPP = 'g++'
    F90 = 'gfortran'
    FC = 'gfortran'
    LD = 'gfortran'
    LD_STATIC = 'ar'

    # Default flags for each of the executables above
    CFLAGS = ['-g', '-fPIC']
    CPPFLAGS = ['-g', '-fPIC']
    F90FLAGS = ['-g', '-fPIC']
    FCFLAGS = ['-g', '-fPIC']
    LDFLAGS = ['-static']
    LDFLAGS_STATIC = ['src']

    # Fortran compiler type name for f2py (presumably numpy's ``gnu95`` fcompiler)
    F2PY_FCOMPILER_TYPE = 'gnu95'

    # Patterns matched by ``get_compiler_from_env`` against executable names or
    # paths to recognise this toolchain
    CC_PATTERN = re.compile(r'(^|/|\\)gcc\b')
    CPP_PATTERN = re.compile(r'(^|/|\\)g\+\+\b')
    FC_PATTERN = re.compile(r'(^|/|\\)gfortran\b')


class NvidiaCompiler(Compiler):
    """
    Compiler configuration for the NVIDIA HPC SDK (``nvc``, ``nvc++``, ``nvfortran``).

    Legacy PGI Fortran binary names (``pgf90``, ``pgf95``, ``pgfortran``) are
    also recognised as belonging to this family.
    """

    # Executables for compiling and linking
    CC = 'nvc'
    CPP = 'nvc++'
    F90 = 'nvfortran'
    FC = 'nvfortran'
    LD = 'nvfortran'
    LD_STATIC = 'ar'

    # Default flags for each of the executables above
    CFLAGS = ['-g', '-fPIC']
    CPPFLAGS = ['-g', '-fPIC']
    F90FLAGS = ['-g', '-fPIC']
    FCFLAGS = ['-g', '-fPIC']
    LDFLAGS = ['-static']
    LDFLAGS_STATIC = ['src']

    # Fortran compiler type name for f2py
    F2PY_FCOMPILER_TYPE = 'nv'

    # Patterns matched by ``get_compiler_from_env`` against executable names or
    # paths to recognise this toolchain
    CC_PATTERN = re.compile(r'(^|/|\\)nvc\b')
    CPP_PATTERN = re.compile(r'(^|/|\\)nvc\+\+\b')
    FC_PATTERN = re.compile(r'(^|/|\\)(pgf9[05]|pgfortran|nvfortran)\b')

    def _mod_dir_args(self, mod_dir):
        """
        Return the arguments selecting ``mod_dir`` as module output directory
        (``-module <dir>``), or an empty list if no directory is given.
        """
        return [] if mod_dir is None else ['-module', str(mod_dir)]


def get_compiler_from_env(env=None):
    """
    Utility function to determine what compiler to use

    This takes the following environment variables in the given order
    into account to determine the most likely compiler family:
    ``F90``, ``FC``, ``CC``, ``CPP``. The first non-empty variable whose
    value matches one of the known families decides; if none matches,
    a generic :any:`Compiler` is used.

    Currently, :any:`GNUCompiler` and :any:`NvidiaCompiler` are available.

    The compiler binary and flags can be further overwritten by setting
    the corresponding environment variables:

    - ``CC``, ``CPP``, ``FC``, ``F90``, ``LD`` for compiler/linker binary name or path
      (empty or whitespace-only values are ignored)
    - ``CFLAGS``, ``CPPFLAGS``, ``FCFLAGS``, ``LDFLAGS`` for compiler/linker flags to use
      (split on whitespace; an empty value clears the flags)

    Parameters
    ----------
    env : dict, optional
        Use the specified environment (default: :any:`os.environ`)

    Returns
    -------
    :any:`Compiler`
        A compiler object
    """
    if env is None:
        env = os.environ

    candidates = (GNUCompiler, NvidiaCompiler)
    compiler = None

    # "guess" the most likely compiler choice from the first variable that matches
    var_pattern_map = {
        'F90': 'FC_PATTERN',
        'FC': 'FC_PATTERN',
        'CC': 'CC_PATTERN',
        'CPP': 'CPP_PATTERN'
    }
    for var, pattern in var_pattern_map.items():
        value = env.get(var)
        if not value:
            continue
        match = next((c for c in candidates if getattr(c, pattern).search(value)), None)
        if match is not None:
            compiler = match()
            debug(f'Environment variable {var}={value} set, using {match}')
            break

    if compiler is None:
        compiler = Compiler()

    # overwrite compiler executable with environment values; an empty or
    # whitespace-only value can never name an executable, so it is skipped
    # instead of blanking out the compiler's default binary
    var_compiler_map = {
        'CC': 'cc',
        'CPP': 'cpp',
        'FC': 'fc',
        'F90': 'f90',
        'LD': 'ld',
    }
    for var, attr in var_compiler_map.items():
        executable = (env.get(var) or '').strip()
        if executable:
            setattr(compiler, attr, executable)
            debug(f'Environment variable {var} set, using custom compiler executable {env[var]}')

    # overwrite compiler flags with environment values; here an empty value is
    # meaningful and deliberately clears the default flags
    var_flag_map = {
        'CFLAGS': 'cflags',
        'CPPFLAGS': 'cppflags',
        'FCFLAGS': 'fcflags',
        'LDFLAGS': 'ldflags',
    }
    for var, attr in var_flag_map.items():
        if var in env:
            setattr(compiler, attr, env[var].strip().split())
            debug(f'Environment variable {var} set, overwriting compiler flags as {env[var]}')

    return compiler


# TODO: Properly integrate with a config dict (with callbacks)
# Module-level default compiler, determined once from ``os.environ`` when this
# module is imported; later changes to the environment are not picked up here.
_default_compiler = get_compiler_from_env()
loki-ecmwf-0.3.6/loki/jit_build/binary.py0000664000175000017500000000172215167130205020504 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.tools import flatten

__all__ = ['Binary']


class Binary:
    """
    A binary build target to generate executables.

    Parameters
    ----------
    name : str
        Name of the binary target.
    objs : list, optional
        Object dependencies; each entry must provide ``build(builder=...)``.
    libs : list, optional
        Library dependencies; each entry must provide ``build(builder=...)``.

    Note that linking of the final executable is not implemented yet.
    """

    def __init__(self, name, objs=None, libs=None):
        self.name = name
        # Any falsy argument (``None``, empty tuple, ...) becomes a fresh empty list
        self.objs, self.libs = objs or [], libs or []

    def build(self, builder):
        """
        Trigger the build of all dependencies with the given ``builder``:
        first every object, then every library.
        """
        for dependencies in (self.objs, self.libs):
            for dependency in flatten(dependencies):
                dependency.build(builder=builder)

        # TODO: Link the final binary
loki-ecmwf-0.3.6/loki/sourcefile.py0000664000175000017500000005647415167130205017431 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Contains the declaration of :any:`Sourcefile` that is used to represent and
manipulate (Fortran) source code files.
"""
from pathlib import Path
from codetiming import Timer

from loki.backend.fgen import fgen
from loki.backend.cufgen import cufgen
from loki.frontend import (
    Frontend, OMNI, FP, REGEX, sanitize_input, Source, read_file,
    preprocess_cpp, parse_omni_source, parse_fparser_source,
    parse_omni_ast, parse_fparser_ast, parse_regex_source,
    RegexParserClass

)
from loki.ir import Section, RawSource, Comment, PreprocessorDirective
from loki.logging import debug, detail, perf
from loki.module import Module
from loki.program_unit import ProgramUnit
from loki.subroutine import Subroutine
from loki.tools import flatten, as_tuple


__all__ = ['Sourcefile']


class Sourcefile:
    """
    Class to handle and manipulate source files, storing :any:`Module` and
    :any:`Subroutine` objects.

    Reading existing source code from file or string can be done via
    :meth:`from_file` or :meth:`from_source`.

    Parameters
    ----------
    path : str or :any:`pathlib.Path`
        The name of the source file (stored as a :any:`pathlib.Path`, or `None`).
    ir : :any:`Section`, optional
        The IR of the file content (including :any:`Subroutine`, :any:`Module`,
        :any:`Comment` etc.). Anything that is not a :any:`Section` is wrapped
        into one.
    ast : optional
        Parser-AST of the original source file.
    source : :any:`Source`, optional
        Raw source string and line information about the original source file.
    incomplete : bool, optional
        Mark the object as incomplete, i.e. only partially parsed. This is
        typically the case when it was instantiated using the :any:`Frontend.REGEX`
        frontend and a full parse using one of the other frontends is pending.
    parser_classes : :any:`RegexParserClass`, optional
        Provide the list of parser classes used during incomplete regex parsing
    """

    def __init__(self, path, ir=None, ast=None, source=None, incomplete=False, parser_classes=None):
        self.path = Path(path) if path is not None else path
        if ir is not None and not isinstance(ir, Section):
            ir = Section(body=ir)
        self.ir = ir
        self._ast = ast
        self._source = source
        self._incomplete = incomplete
        self._parser_classes = parser_classes

    def clone(self, **kwargs):
        """
        Replicate the object with the provided overrides.

        Any constructor argument not given in ``kwargs`` is carried over from
        this object. A carried-over IR is deep-copied, with program units
        rescoping their symbols.
        """
        kwargs.setdefault('path', self.path)
        if self.ir is not None and 'ir' not in kwargs:
            kwargs['ir'] = self.ir
            ir_needs_clone = True
        else:
            ir_needs_clone = False
        if self._ast is not None and 'ast' not in kwargs:
            kwargs['ast'] = self._ast
        if self.source is not None and 'source' not in kwargs:
            kwargs['source'] = self._source.clone(file=kwargs['path'])
        kwargs.setdefault('incomplete', self._incomplete)
        if self._parser_classes is not None and 'parser_classes' not in kwargs:
            kwargs['parser_classes'] = self._parser_classes

        obj = type(self)(**kwargs)

        # When the IR has been carried over from the current sourcefile
        # we need to make sure we perform a deep copy
        if obj.ir and ir_needs_clone:
            ir_body = tuple(
                node.clone(rescope_symbols=True) if isinstance(node, ProgramUnit)
                else node.clone() for node in obj.ir.body
            )
            obj.ir = obj.ir.clone(body=ir_body)
        return obj

    @classmethod
    def from_file(cls, filename, definitions=None, preprocess=False,
                  includes=None, defines=None, omni_includes=None,
                  xmods=None, frontend=FP, parser_classes=None):
        """
        Constructor from raw source files that can apply a
        C-preprocessor before invoking frontend parsers.

        Parameters
        ----------
        filename : str
            Name of the file to parse into a :any:`Sourcefile` object.
        definitions : list of :any:`Module`, optional
            :any:`Module` object(s) that may supply external type or procedure
            definitions.
        preprocess : bool, optional
            Flag to trigger CPP preprocessing (by default `False`).

            .. attention::
                Please note that, when using the OMNI frontend, C-preprocessing
                will always be applied, so :data:`includes` and :data:`defines`
                may have to be defined even when disabling :data:`preprocess`.

        includes : list of str, optional
            Include paths to pass to the C-preprocessor.
        defines : list of str, optional
            Symbol definitions to pass to the C-preprocessor.
        xmods : str, optional
            Path to directory to find and store ``.xmod`` files when using the
            OMNI frontend.
        omni_includes: list of str, optional
            Additional include paths to pass to the preprocessor run as part of
            the OMNI frontend parse. If set, this **replaces** (!)
            :data:`includes`, otherwise :data:`omni_includes` defaults to the
            value of :data:`includes`.
        frontend : :any:`Frontend`, optional
            Frontend to use for producing the AST (default :any:`FP`).
        parser_classes : :any:`RegexParserClass`, optional
            Parser classes to use with the :any:`REGEX` frontend
            (default: ``RegexParserClass.AllClasses``).
        """
        if isinstance(frontend, str):
            frontend = Frontend[frontend.upper()]

        # Log full parses at INFO and regex scans at DETAIL level
        log = f'[Loki::Sourcefile] Constructed from {filename}' + ' in {:.2f}s'
        with Timer(logger=detail if frontend is REGEX else perf, text=log):

            filepath = Path(filename)
            raw_source = read_file(filepath)

            if preprocess:
                # Trigger CPP-preprocessing explicitly, as includes and
                # defines can also be used by our OMNI frontend
                source = preprocess_cpp(source=raw_source, filepath=filepath,
                                        includes=includes, defines=defines)
            else:
                source = raw_source

            if frontend == REGEX:
                return cls.from_regex(source, filepath, parser_classes=parser_classes)

            if frontend == OMNI:
                return cls.from_omni(source, filepath, definitions=definitions,
                                     includes=includes, defines=defines,
                                     xmods=xmods, omni_includes=omni_includes)

            if frontend == FP:
                return cls.from_fparser(source, filepath, definitions=definitions)

            raise NotImplementedError(f'Unknown frontend: {frontend}')

    @classmethod
    def from_omni(cls, raw_source, filepath, definitions=None, includes=None,
                  defines=None, xmods=None, omni_includes=None):
        """
        Parse a given source string using the OMNI frontend

        Parameters
        ----------
        raw_source : str
            Fortran source string
        filepath : str or :any:`pathlib.Path`
            The filepath of this source file
        definitions : list
            List of external :any:`Module` to provide derived-type and procedure declarations
        includes : list of str, optional
            Include paths to pass to the C-preprocessor.
        defines : list of str, optional
            Symbol definitions to pass to the C-preprocessor.
        xmods : str, optional
            Path to directory to find and store ``.xmod`` files when using the
            OMNI frontend.
        omni_includes: list of str, optional
            Additional include paths to pass to the preprocessor run as part of
            the OMNI frontend parse. If set, this **replaces** (!)
            :data:`includes`, otherwise :data:`omni_includes` defaults to the
            value of :data:`includes`.
        """
        # Always CPP-preprocess source files for OMNI, but optionally
        # use a different set of include paths if specified that way.
        # (It's a hack, I know, but OMNI sucks, so what can I do...?)
        if omni_includes is not None and len(omni_includes) > 0:
            includes = omni_includes
        source = preprocess_cpp(raw_source, filepath=filepath,
                                includes=includes, defines=defines)

        # Parse the file content into an OMNI Fortran AST
        ast = parse_omni_source(source=source, filepath=filepath, xmods=xmods)
        typetable = ast.find('typeTable')
        return cls._from_omni_ast(ast=ast, path=filepath, raw_source=raw_source,
                                  definitions=definitions, typetable=typetable)

    @classmethod
    def _from_omni_ast(cls, ast, path=None, raw_source=None, definitions=None, typetable=None):
        """
        Generate the full set of `Subroutine` and `Module` members of the `Sourcefile`.
        """
        type_map = {t.attrib['type']: t for t in typetable}
        if ast.find('symbols') is not None:
            symbol_map = {s.attrib['type']: s for s in ast.find('symbols')}
        else:
            symbol_map = None

        ir = parse_omni_ast(
            ast=ast, definitions=definitions, raw_source=raw_source,
            type_map=type_map, symbol_map=symbol_map
        )

        lines = (1, raw_source.count('\n') + 1)
        source = Source(lines, string=raw_source, file=path)
        return cls(path=path, ir=ir, ast=ast, source=source)

    @classmethod
    def from_fparser(cls, raw_source, filepath, definitions=None):
        """
        Parse a given source string using the fparser frontend

        Parameters
        ----------
        raw_source : str
            Fortran source string
        filepath : str or :any:`pathlib.Path`
            The filepath of this source file
        definitions : list
            List of external :any:`Module` to provide derived-type and procedure declarations
        """
        # Preprocess using internal frontend-specific PP rules
        # to sanitize input and work around known frontend problems.
        source, pp_info = sanitize_input(source=raw_source, frontend=FP)

        # Parse the file content into a Fortran AST
        ast = parse_fparser_source(source)

        return cls._from_fparser_ast(path=filepath, ast=ast, definitions=definitions,
                                     pp_info=pp_info, raw_source=raw_source)

    @classmethod
    def _from_fparser_ast(cls, ast, path=None, raw_source=None, definitions=None, pp_info=None):
        """
        Generate the full set of :any:`Subroutine` and :any:`Module` members
        in the :any:`Sourcefile`.
        """
        ir = parse_fparser_ast(ast, pp_info=pp_info, definitions=definitions, raw_source=raw_source)

        lines = (1, raw_source.count('\n') + 1)
        source = Source(lines, string=raw_source, file=path)
        ir._update(source=source)
        return cls(path=path, ir=ir, ast=ast, source=source)

    @classmethod
    def from_regex(cls, raw_source, filepath, parser_classes=None):
        """
        Parse a given source string using the REGEX frontend

        The resulting object is marked as incomplete; use :meth:`make_complete`
        to trigger a full parse later. If :data:`parser_classes` is not given,
        ``RegexParserClass.AllClasses`` is used.
        """
        source, _ = sanitize_input(source=raw_source, frontend=REGEX)

        if parser_classes is None:
            parser_classes = RegexParserClass.AllClasses
        ir = parse_regex_source(source, parser_classes=parser_classes)
        lines = (1, raw_source.count('\n') + 1)
        source = Source(lines, string=raw_source, file=filepath)
        return cls(path=filepath, ir=ir, source=source, incomplete=True, parser_classes=parser_classes)

    @classmethod
    def from_source(cls, source, definitions=None, preprocess=False,
                    includes=None, defines=None, omni_includes=None,
                    xmods=None, frontend=FP, parser_classes=None):
        """
        Constructor from raw source string that invokes specified frontend parser

        Parameters
        ----------
        source : str
            Fortran source string
        definitions : list of :any:`Module`, optional
            :any:`Module` object(s) that may supply external type or procedure
            definitions.
        preprocess : bool, optional
            Flag to trigger CPP preprocessing (by default `False`).

            .. attention::
                Please note that, when using the OMNI frontend, C-preprocessing
                will always be applied, so :data:`includes` and :data:`defines`
                may have to be defined even when disabling :data:`preprocess`.

        includes : list of str, optional
            Include paths to pass to the C-preprocessor.
        defines : list of str, optional
            Symbol definitions to pass to the C-preprocessor.
        xmods : str, optional
            Path to directory to find and store ``.xmod`` files when using the
            OMNI frontend.
        omni_includes: list of str, optional
            Additional include paths to pass to the preprocessor run as part of
            the OMNI frontend parse. If set, this **replaces** (!)
            :data:`includes`, otherwise :data:`omni_includes` defaults to the
            value of :data:`includes`.
        frontend : :any:`Frontend`, optional
            Frontend to use for producing the AST (default :any:`FP`).
        parser_classes : :any:`RegexParserClass`, optional
            Parser classes to use with the :any:`REGEX` frontend
            (default: ``RegexParserClass.AllClasses``).
        """
        if isinstance(frontend, str):
            frontend = Frontend[frontend.upper()]

        if preprocess:
            # Trigger CPP-preprocessing explicitly, as includes and
            # defines can also be used by our OMNI frontend
            source = preprocess_cpp(source=source, includes=includes, defines=defines)

        if frontend == REGEX:
            return cls.from_regex(source, filepath=None, parser_classes=parser_classes)

        if frontend == OMNI:
            return cls.from_omni(source, filepath=None, definitions=definitions, includes=includes,
                                 defines=defines, xmods=xmods, omni_includes=omni_includes)

        if frontend == FP:
            return cls.from_fparser(source, filepath=None, definitions=definitions)

        raise NotImplementedError(f'Unknown frontend: {frontend}')

    def make_complete(self, **frontend_args):
        """
        Trigger a re-parse of the source file if incomplete to produce a full Loki IR

        If the source file is marked to be incomplete, i.e. when using the `lazy` constructor
        option, this triggers a new parsing of all :any:`ProgramUnit` objects and any
        :any:`RawSource` nodes in the :attr:`Sourcefile.ir`.

        Existing :any:`Module` and :any:`Subroutine` objects continue to exist and references
        to them stay valid, as they will only be updated instead of replaced.

        Parameters
        ----------
        **frontend_args :
            ``frontend`` (:any:`Frontend` or its name, default :any:`FP`) selects the
            frontend. The remaining arguments are forwarded to the program units'
            ``make_complete``. For :any:`RawSource` nodes, only the subset relevant to
            the chosen frontend is used: ``parser_classes`` (REGEX); ``definitions``,
            ``type_map``, ``symbol_map``, ``scope`` and ``xmods`` (OMNI);
            ``definitions`` and ``scope`` (FP).
        """
        if not self._incomplete:
            return

        frontend = frontend_args.pop('frontend', FP)
        # Resolve frontend names before choosing the log level below, so that a
        # string such as 'regex' is treated the same as the enum member
        if isinstance(frontend, str):
            frontend = Frontend[frontend.upper()]

        log = f'[Loki::Sourcefile] Finished constructing from {self.path}' + ' in {:.2f}s'
        with Timer(logger=debug if frontend == REGEX else perf, text=log):

            # Sanitize frontend_args
            if frontend == REGEX:
                frontend_argnames = ['parser_classes']
            elif frontend == OMNI:
                frontend_argnames = ['definitions', 'type_map', 'symbol_map', 'scope']
                xmods = frontend_args.get('xmods')
            elif frontend == FP:
                frontend_argnames = ['definitions', 'scope']
            else:
                raise NotImplementedError(f'Unknown frontend: {frontend}')
            sanitized_frontend_args = {k: frontend_args.get(k) for k in frontend_argnames}

            body = []
            for node in self.ir.body:
                if isinstance(node, ProgramUnit):
                    node.make_complete(frontend=frontend, **frontend_args)
                    body += [node]
                elif isinstance(node, RawSource):
                    # Sanitize the input code to ensure non-supported features
                    # do not break frontend parsing outside of program units
                    raw_source = node.source.string
                    source, pp_info = sanitize_input(source=raw_source, frontend=frontend)

                    # Typically, this should only be comments, PP statements etc., therefore
                    # we are not bothering with type tables, definitions or similar to parse them
                    if frontend == REGEX:
                        ir_ = parse_regex_source(source, **sanitized_frontend_args)
                    elif frontend == OMNI:
                        ast = parse_omni_source(source=source, xmods=xmods)
                        ir_ = parse_omni_ast(ast=ast, raw_source=raw_source, **sanitized_frontend_args)
                    elif frontend == FP:
                        # Fparser is unable to parse comment-only source files/strings,
                        # so we see if this is only comments and convert them ourselves
                        # (https://github.com/stfc/fparser/issues/375)
                        # This can be removed once fparser 0.0.17 is released
                        lines = [l.lstrip() for l in source.splitlines()]
                        if all(not l or l[0] in '!#' for l in lines):
                            ir_ = [
                                PreprocessorDirective(text=line.string, source=line)
                                if line.string.lstrip().startswith('#')
                                else Comment(text=line.string, source=line)
                                for line in node.source.clone_lines()
                            ]
                        else:
                            ast = parse_fparser_source(source)
                            ir_ = parse_fparser_ast(ast, raw_source=raw_source, pp_info=pp_info,
                                                    **sanitized_frontend_args)
                    else:
                        raise NotImplementedError(f'Unknown frontend: {frontend}')
                    if isinstance(ir_, Section):
                        ir_ = ir_.body
                    body += flatten([ir_])
                else:
                    body += [node]

            self.ir._update(body=as_tuple(body))
            # A REGEX pass may still be partial, so only other frontends mark us complete
            self._incomplete = frontend == REGEX
            if frontend == REGEX:
                parser_classes = frontend_args.get('parser_classes', RegexParserClass.AllClasses)
                if self._parser_classes:
                    parser_classes = self._parser_classes | parser_classes
                self._parser_classes = parser_classes

    @property
    def source(self):
        """
        The :any:`Source` object (raw string and line information) of the original file.
        """
        return self._source

    def to_fortran(self, conservative=False, cuf=False, style=None):
        """
        Generate source code for this file: with the CUDA Fortran backend
        (:any:`cufgen`) if :data:`cuf` is set, otherwise with :any:`fgen`.

        Note that :data:`conservative` is only passed on to :any:`fgen`.
        """
        if cuf:
            return cufgen(self, style=style)
        return fgen(self, conservative=conservative, style=style)

    @property
    def modules(self):
        """
        List of :class:`Module` objects that are members of this :class:`Sourcefile`.
        """
        if self.ir is None:
            return ()
        return as_tuple(
            module for module in self.ir.body if isinstance(module, Module)
        )

    @property
    def routines(self):
        """
        List of :class:`Subroutine` objects that are members of this :class:`Sourcefile`.
        """
        if self.ir is None:
            return ()
        return as_tuple(
            routine for routine in self.ir.body if isinstance(routine, Subroutine)
        )

    subroutines = routines

    @property
    def typedefs(self):
        """
        List of :class:`TypeDef` objects that are declared in the :any:`Module` in this :class:`Sourcefile`.
        """
        if self.ir is None:
            return ()
        return as_tuple(flatten(module.typedefs for module in self.modules))

    @property
    def all_subroutines(self):
        """
        All free :class:`Subroutine` objects of this file, followed by the
        subroutines of every :class:`Module` in it.
        """
        routines = self.subroutines
        routines += as_tuple(flatten(m.subroutines for m in self.modules))
        return routines

    @property
    def definitions(self):
        """
        List of all definitions made in this sourcefile, i.e. modules and subroutines
        """
        return self.modules + self.subroutines

    def __contains__(self, name):
        """
        Check if a module, type, or subroutine with the given name is declared
        inside this sourcefile
        """
        return self[name] is not None

    def __getitem__(self, name):
        """
        Look up a definition by case-insensitive name.

        Searches, in this order, modules, all subroutines (including module
        procedures), and module typedefs and interfaces. Returns `None` if
        nothing matches.
        """
        name = name.lower()
        for module in self.modules:
            if name == module.name.lower():
                return module

        for routine in self.all_subroutines:
            if name == routine.name.lower():
                return routine

        for module in self.modules:
            for typedef in module.typedefs:
                if name == typedef.name.lower():
                    return typedef
            for interface in module.interfaces:
                # NOTE(review): relies on interface symbols comparing equal to a
                # lower-case string; confirm symbol comparison is case-insensitive
                if name in interface.symbols:
                    return interface

        return None

    def __iter__(self):
        raise TypeError('Sourcefiles alone cannot be traversed! Try traversing "Sourcefile.ir".')

    def __bool__(self):
        """
        Ensure existing objects register as True in boolean checks, despite
        raising exceptions in `__iter__`.
        """
        return True

    @property
    def _canonical(self):
        """
        Base definition for comparing and hashing :any:`Sourcefile` objects.
        """
        return (self.path, self.ir, self.source, )

    def __eq__(self, other):
        if isinstance(other, Sourcefile):
            return self._canonical == other._canonical
        return super().__eq__(other)

    def __hash__(self):
        return hash(self._canonical)

    def __getstate__(self):
        # Do not pickle the AST, as it is not pickle-safe for certain frontends
        _ignore = ('_ast',)
        return dict((k, v) for k, v in self.__dict__.items() if k not in _ignore)

    def apply(self, op, **kwargs):
        """
        Apply a given transformation to the source file object.

        Note that the dispatch routine `op.apply(source)` will ensure
        that all entities of this `Sourcefile` are correctly traversed.
        """
        # TODO: Should type-check for an `Operation` object here
        op.apply(self, **kwargs)

    def write(self, path=None, source=None, conservative=False, cuf=False, style=None):
        """
        Write content as Fortran source code to file

        Parameters
        ----------
        path : str, optional
            Filepath of target file; if not provided, :attr:`Sourcefile.path` is used
        source : str, optional
            Write the provided string instead of generating via :any:`Sourcefile.to_fortran`
        conservative : bool, optional
            Enable conservative output in the backend, aiming to be as much string-identical
            as possible (default: False)
        cuf: bool, optional
            To use either Cuda Fortran or Fortran backend
        style : optional
            Output style, passed on to :any:`Sourcefile.to_fortran`
        """
        path = self.path if path is None else Path(path)
        source = self.to_fortran(conservative, cuf, style=style) if source is None else source
        self.to_file(source=source, path=path)

    @classmethod
    def to_file(cls, source, path):
        """
        Same as :meth:`write` but can be called from a static context.

        Parameters
        ----------
        source : str
            The string to write; a trailing newline is appended if it is missing.
        path : str or :any:`pathlib.Path`
            Filepath of the target file.
        """
        # Accept plain strings as well, since this is meant to be usable statically
        path = Path(path)
        detail(f'[Loki::Sourcefile] Writing to {path}')
        with path.open('w') as f:
            f.write(source)
            # Ensure a trailing newline; the emptiness guard avoids the
            # IndexError that ``source[-1]`` raised for an empty string
            if source and not source.endswith('\n'):
                f.write('\n')
loki-ecmwf-0.3.6/loki/ir/0000775000175000017500000000000015167130205015311 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/ir/__init__.py0000664000175000017500000000134515167130205017425 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
The Loki internal representation (IR) and associated APIs for tree traversal.
"""

from loki.ir.expr_visitors import *  # noqa
from loki.ir.find import *  # noqa
from loki.ir.ir_graph import *  # noqa
from loki.ir.nodes import *  # noqa
from loki.ir.pragma_utils import *  # noqa
from loki.ir.transformer import *  # noqa
from loki.ir.visitor import *  # noqa
loki-ecmwf-0.3.6/loki/ir/tests/0000775000175000017500000000000015167130205016453 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/ir/tests/__init__.py0000664000175000017500000000057015167130205020566 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/ir/tests/test_pragma_utils.py0000664000175000017500000006166615167130205022572 0ustar  alastairalastairfrom io import StringIO
import pytest

from loki import Module, Subroutine, FindNodes, flatten, pprint, fgen
from loki.frontend import available_frontends, FP
from loki.ir import Pragma, Loop, VariableDeclaration, PragmaRegion
from loki.ir.pragma_utils import (
    is_loki_pragma, get_pragma_parameters, attach_pragmas, detach_pragmas,
    pragmas_attached, pragma_regions_attached, SubstitutePragmaStrings
)


@pytest.mark.parametrize('keyword, content, starts_with, ref', [
    ('foo', None, None, False),
    ('foo', 'bar', None, False),
    ('foo', 'loki', None, False),
    ('foo', 'loki', 'loki', False),
    ('loki', None, None, True),
    ('loki', None, 'foo', False),
    ('loki', 'dataflow', None, True),
    ('loki', 'dataflow', 'dataflow', True),
    ('loki', 'dataflow', 'foobar', False),
    ('loki', 'fusion group(1)', None, True),
    ('loki', 'fusion group(1)', 'fusion', True),
    ('loki', 'fusion group(1)', 'group', False),
])
def test_is_loki_pragma(keyword, content, starts_with, ref):
    """
    Check that Loki pragmas are recognised both for a single pragma and for a
    tuple wrapping it, optionally filtered by the start of the pragma content.
    """
    pragma = Pragma(keyword, content)
    # Only pass ``starts_with`` when the test case actually specifies it
    kwargs = {} if starts_with is None else {'starts_with': starts_with}
    for candidate in (pragma, (pragma,)):
        assert is_loki_pragma(candidate, **kwargs) == ref


@pytest.mark.parametrize('content, starts_with, ref', [
    (None, None, {}),
    ('', None, {}),
    ('', 'foo', {}),
    ('dataflow', None, {'dataflow': None}),
    ('dataflow', 'dataflow', {}),
    ('dataflow group(1)', None, {'dataflow': None, 'group': '1'}),
    ('dataflow group(1)', 'dataflow', {'group': '1'}),
    ('dataflow group(1)', 'foo', {}),
    ('dataflow group(1) group(2)', 'dataflow', {'group': ['1', '2']}),
    ('foo bar(^£!$%*[]:@+-_=~#/?.,<>;) baz foobar(abc_123")', 'foo',
     {'bar':'^£!$%*[]:@+-_=~#/?.,<>;', 'baz': None, 'foobar': 'abc_123"'}),
    ('target map(a) map(to: b) map(from: c)', None, {'target': None, 'map': ['a', 'to: b', 'from: c']}),
    ('arg1(val1) arg2(val2/val3) arg3((val1 + val2)/(val3))', None, {'arg1': 'val1',
        'arg2': 'val2/val3', 'arg3': '(val1 + val2)/(val3)'})
])
def test_get_pragma_parameters(content, starts_with, ref):
    """
    Check extraction of parameters from Loki pragma content with
    ``get_pragma_parameters``.

    Words become keys; a parenthesised value belongs to the word right
    before it; repeated keys gather their values in a list. A matching
    ``starts_with`` prefix is removed from the result, and a non-matching
    one yields an empty result. Both a bare :any:`Pragma` and a one-element
    tuple of pragmas are checked.
    """
    node = Pragma('loki', content)
    # Only pass ``starts_with`` when the case provides one, so that the
    # default-argument call is covered too
    extra_args = {} if starts_with is None else {'starts_with': starts_with}
    for candidate in (node, (node,)):
        assert get_pragma_parameters(candidate, **extra_args) == ref


@pytest.mark.parametrize('frontend', available_frontends())
def test_get_pragma_parameters_multiline(frontend):
    """
    Test correct extraction of pragma parameters from pragmas
    with line-continuation.

    An ``!$OMP`` directive spread over several continuation lines must be
    parsed into a single :any:`Pragma` node. Its parameters are extracted
    with the Loki-only filter disabled, and ``fgen`` regenerates it on a
    single line.
    """
    fcode = """
subroutine test_pragmas_map(a)
    implicit none
    real, intent(in) :: a(:,:)
    integer :: i, j, k

!$OMP PARALLEL &
!$OMP &  PRIVATE(i, j) &
!$OMP &  FIRSTPRIVATE( &
!$OMP &        n, a, b &
!$OMP &  )

end subroutine test_pragmas_map
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    pragmas = FindNodes(Pragma).visit(routine.body)

    # All continuation lines must be merged into one pragma node
    assert len(pragmas) == 1
    assert pragmas[0].keyword == 'OMP'
    # This is an OMP pragma, so the Loki-only filter has to be switched off
    params = get_pragma_parameters(pragmas[0], only_loki_pragmas=False)
    assert len(params) == 3
    assert params['PARALLEL'] is None
    # Values may retain whitespace from the continuation layout, hence strip()
    assert params['PRIVATE'].strip() == 'i, j'
    assert params['FIRSTPRIVATE'].strip() == 'n, a, b'

    assert fgen(pragmas[0]) == '!$OMP PARALLEL PRIVATE( i, j ) FIRSTPRIVATE( n, a, b )'


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragma_inlining(frontend):
    """
    Check that a pragma placed directly before the first executable
    statement of a routine body is left standalone by the frontend.
    ``attach_pragmas`` must then move it onto the following :any:`Loop`.
    """
    fcode = """
subroutine test_tools_pragma_inlining (in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  integer, intent(in) :: n
  integer :: i
  !$loki some pragma
  do i=1,n
    out(i) = in(i)
  end do
end subroutine test_tools_pragma_inlining
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Freshly parsed: the loop has no pragma attached to it
    found = FindNodes(Loop).visit(routine.body)
    assert len(found) == 1
    assert found[0].pragma is None

    # After attaching, the loop carries exactly the single Loki pragma
    routine.body = attach_pragmas(routine.body, Loop)
    found = FindNodes(Loop).visit(routine.body)
    assert len(found) == 1
    attached = found[0].pragma
    assert attached is not None
    assert isinstance(attached, tuple) and len(attached) == 1
    first = attached[0]
    assert first.keyword == 'loki' and first.content == 'some pragma'


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragma_inlining_multiple(frontend):
    """
    A short test that verifies that multiple pragmas are inlined
    and kept in the right order.

    Three consecutive pragmas precede the loop: one ``!$blub`` and two
    ``!$loki``. After ``attach_pragmas``, all three must sit on
    ``Loop.pragma`` in source order. The test also exercises
    ``is_loki_pragma`` and ``get_pragma_parameters`` on the resulting
    tuple of pragmas.
    """
    fcode = """
subroutine test_tools_pragma_inlining_multiple (in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  integer, intent(in) :: n
  integer :: i
  !$blub other pragma
  !$loki some pragma(5)
  !$loki more
  do i=1,n
    out(i) = in(i)
  end do
end subroutine test_tools_pragma_inlining_multiple
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Check that pragmas are not inlined
    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 1
    assert loops[0].pragma is None

    # Now inline pragmas and see if everything matches
    routine.body = attach_pragmas(routine.body, Loop)
    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 1
    assert loops[0].pragma is not None
    assert isinstance(loops[0].pragma, tuple) and len(loops[0].pragma) == 3
    assert [p.keyword for p in loops[0].pragma] == ['blub', 'loki', 'loki']
    assert loops[0].pragma[0].content == 'other pragma'
    assert loops[0].pragma[1].content == 'some pragma(5)'
    assert loops[0].pragma[2].content == 'more'

    # A few checks for the pragma utility functions
    # (on a tuple, a match on any of its Loki pragmas is sufficient, while the
    # non-Loki ``blub`` pragma is never considered)
    assert is_loki_pragma(loops[0].pragma)
    assert is_loki_pragma(loops[0].pragma, starts_with='some')
    assert is_loki_pragma(loops[0].pragma, starts_with='more')
    assert not is_loki_pragma(loops[0].pragma, starts_with='other')
    # By default only the two Loki pragmas contribute, merged into one dict
    assert get_pragma_parameters(loops[0].pragma) == {'some': None, 'pragma': '5', 'more': None}
    assert get_pragma_parameters(loops[0].pragma, starts_with='some') == {'pragma': '5'}
    # Note: the following is really unexpected behaviour. With the Loki-only
    # filter disabled, the words of ``!$blub other pragma`` are merged in too,
    # so its bare ``pragma`` collides with ``pragma(5)`` from the Loki line and
    # both values end up collected in the list ``[None, '5']``
    assert get_pragma_parameters(loops[0].pragma, only_loki_pragmas=False) == \
            {'some': None, 'pragma': [None, '5'], 'more': None, 'other': None}


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragma_detach(frontend):
    """
    Verify that ``detach_pragmas`` reverses ``attach_pragmas``.

    Pragmas on two nested loops go through a round trip: attached, then
    detached, then attached again. After detaching, all four pragmas must be
    standalone :any:`Pragma` nodes again. After re-attaching, the IR must
    pretty-print like ``routine.body``, and every loop must carry the same
    pragmas, in the same order, as after the first attach.
    """
    fcode = """
subroutine test_tools_pragma_detach (in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  integer, intent(in) :: n
  integer :: i, j

!$blub other pragma
!$loki some pragma(5)
!$loki more
  do i=1,n
    out(i) = in(i)

!$loki inner pragma
    do j=1,n
      out(i) = out(i) + 1.0
    end do

  end do
end subroutine test_tools_pragma_detach
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Originally, pragmas shouldn't be inlined
    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 2
    assert all(loop.pragma is None for loop in loops)
    pragmas = FindNodes(Pragma).visit(routine.body)
    assert len(pragmas) == 4

    # Inline pragmas
    ir = attach_pragmas(routine.body, Loop)
    orig_loops = FindNodes(Loop).visit(ir)
    assert len(orig_loops) == 2
    assert all(loop.pragma is not None for loop in orig_loops)
    assert not FindNodes(Pragma).visit(ir)

    # Detach pragmas again: all four must reappear as standalone nodes
    ir = detach_pragmas(ir, Loop)

    loops = FindNodes(Loop).visit(ir)
    assert len(loops) == 2
    assert all(loop.pragma is None for loop in loops)
    pragmas = FindNodes(Pragma).visit(ir)
    assert len(pragmas) == 4

    # Inline pragmas again
    ir = attach_pragmas(ir, Loop)

    stream_ir = StringIO()
    stream_body = StringIO()
    pprint(ir, stream=stream_ir)
    pprint(routine.body, stream=stream_body)
    assert stream_ir.getvalue() == stream_body.getvalue()

    loops = FindNodes(Loop).visit(ir)
    assert len(loops) == 2
    assert all(loop.pragma is not None for loop in loops)
    assert not FindNodes(Pragma).visit(ir)

    # The re-attached pragmas must match those from the first attach, loop by loop
    for loop, orig_loop in zip(loops, orig_loops):
        pragma = [p.keyword + ' ' + p.content for p in loop.pragma]
        orig_pragma = [p.keyword + ' ' + p.content for p in orig_loop.pragma]
        assert '\n'.join(pragma) == '\n'.join(orig_pragma)


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragmas_attached_loop(frontend):
    """
    A short test that verifies that the context manager to attach
    pragmas works as expected.

    Inside ``pragmas_attached(routine, Loop)``, the pragmas must sit on the
    loops and no standalone :any:`Pragma` nodes may remain. On exit, the
    original layout must be restored. Loop references obtained before and
    inside the context must observe these changes.
    """
    fcode = """
subroutine test_tools_pragmas_attached_loop(in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  integer, intent(in) :: n
  integer :: i, j

!$blub other pragma
!$loki some pragma(5)
!$loki more
  do i=1,n
    out(i) = in(i)

!$loki inner pragma
    do j=1,n
      out(i) = out(i) + 1.0
    end do

  end do
end subroutine test_tools_pragmas_attached_loop
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 2
    assert all(loop.pragma is None for loop in loops)
    assert len(FindNodes(Pragma).visit(routine.body)) == 4

    with pragmas_attached(routine, Loop):
        # Verify that pragmas are attached
        attached_loops = FindNodes(Loop).visit(routine.body)
        assert len(attached_loops) == 2
        assert all(loop.pragma is not None for loop in attached_loops)
        assert not FindNodes(Pragma).visit(routine.body)

        # Make sure that existing references to nodes still work
        # (and have been changed, too)
        assert all(loop.pragma is not None for loop in loops)

    # Check that the original state is restored
    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 2
    assert all(loop.pragma is None for loop in loops)
    assert len(FindNodes(Pragma).visit(routine.body)) == 4

    # Make sure that reference from inside the context still work
    # (and have their pragmas detached)
    # ``attached_loops`` is intentionally used after the ``with`` block here
    assert all(loop.pragma is None for loop in attached_loops)


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragmas_attached_example(frontend):
    """
    A short test that verifies that the example from the docstring works.

    NOTE(review): presumably this is the usage example in the
    ``pragmas_attached`` docstring. The loop search below is kept in the
    same form, so the two should stay in sync.
    """
    fcode = """
subroutine test_tools_pragmas_attached_example (in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  integer, intent(in) :: n
  integer :: i

  do i=1,n
    out(i) = 0.0
  end do

!$loki foobar
  do i=1,n
    out(i) = in(i)
  end do
end subroutine test_tools_pragmas_attached_example
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Only the second loop is preceded by ``!$loki foobar``
    loop_of_interest = None
    with pragmas_attached(routine, Loop):
        for loop in FindNodes(Loop).visit(routine.body):
            if is_loki_pragma(loop.pragma, starts_with='foobar'):
                loop_of_interest = loop
                break

    # The loop was found inside the context; once the context is left,
    # its pragma has been detached again
    assert loop_of_interest is not None
    assert loop_of_interest.pragma is None


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragmas_attached_post(frontend):
    """
    Verify the inlining of pragma_post.

    The ``!$acc end parallel`` after the loop nest is expected to become
    the ``pragma_post`` of the outermost loop. This happens only when
    ``attach_pragma_post`` is enabled, which is the default.
    """
    fcode = """
subroutine test_tools_pragmas_attached_post(a, jtend, iend, jend)
  ! Code snippet example adapted from CLAW manual
  integer, intent(out) :: a(jend, iend, jtend)
  integer, intent(in) :: jtend, iend, jend
  integer :: jt, i, j

!$acc parallel loop gang vector collapse(2)
  DO jt=1,jtend
    DO i=1,iend
      DO j=1,jend
        a(j, i, jt) = j + i + jt
      END DO
    END DO
  END DO
!$acc end parallel
end subroutine test_tools_pragmas_attached_post
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Pragma).visit(routine.body)) == 2
    loop = FindNodes(Loop).visit(routine.body)
    assert len(loop) == 3
    # Keep only the first (outermost, ``jt``) loop, which the pragmas surround
    loop = loop[0]
    assert loop.pragma is None and loop.pragma_post is None

    # Without ``attach_pragma_post`` only the leading pragma is attached;
    # the trailing ``end parallel`` stays a standalone node
    with pragmas_attached(routine, Loop, attach_pragma_post=False):
        assert isinstance(loop.pragma, tuple) and len(loop.pragma) == 1
        assert loop.pragma[0].keyword.lower() == 'acc'
        assert loop.pragma_post is None
        assert len(FindNodes(Pragma).visit(routine.body)) == 1

    assert loop.pragma is None and loop.pragma_post is None

    # default behaviour: attach_pragma_post=True
    with pragmas_attached(routine, Loop):
        assert isinstance(loop.pragma, tuple) and len(loop.pragma) == 1
        assert loop.pragma[0].keyword.lower() == 'acc'
        assert isinstance(loop.pragma_post, tuple) and len(loop.pragma_post) == 1
        assert loop.pragma_post[0].keyword.lower() == 'acc'
        assert not FindNodes(Pragma).visit(routine.body)

    # Both pragmas are standalone nodes again after leaving the context
    assert loop.pragma is None and loop.pragma_post is None
    assert len(FindNodes(Pragma).visit(routine.body)) == 2


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragmas_attached_module(frontend, tmp_path):
    """
    Verify pragmas_attached works for Module objects.

    The ``!$loki dimension(...)`` pragma in the module spec must attach to
    the following :any:`VariableDeclaration` (of ``b``) inside the context,
    and become standalone again on exit.
    """
    fcode = """
module test_tools_pragmas_attached_module
  integer, allocatable :: a(:)
!$loki dimension(10, 20)
  integer, allocatable :: b(:,:)
end module test_tools_pragmas_attached_module
    """
    # NOTE(review): ``xmods`` points at a temporary directory, presumably for
    # module files written by frontends that need them — confirm
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    assert len(FindNodes(Pragma).visit(module.spec)) == 1
    # The second declaration is the one directly after the pragma
    decl = FindNodes(VariableDeclaration).visit(module.spec)[1]
    assert len(decl.symbols) == 1 and decl.symbols[0].name.lower() == 'b'
    assert decl.pragma is None

    with pragmas_attached(module, VariableDeclaration):
        assert not FindNodes(Pragma).visit(module.spec)
        assert isinstance(decl.pragma, tuple) and is_loki_pragma(decl.pragma, starts_with='dimension')

    assert decl.pragma is None
    assert len(FindNodes(Pragma).visit(module.spec)) == 1


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragma_regions_attached(frontend):
    """
    Verify ``pragma_regions_attached`` creates and removes `PragmaRegion` objects.

    The source has five pragmas. Two matching start/end pairs
    (``!$loki whatever`` / ``!$loki end whatever`` and ``!$foo bar`` /
    ``!$foo end bar``) must become regions inside the context, leaving only
    ``!$loki do_something`` as a standalone pragma.
    """
    fcode = """
subroutine test_tools_pragmas_attached_region (in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  integer, intent(in) :: n
  integer :: i

  out(0) = -1.0

!$loki whatever

  out(0) = -2.0

  !$loki do_something
  do i=1,n
    out(i) = 0.0
  end do
!$loki end whatever

  do i=1,n
    out(i) = 1.0
  end do

!$foo bar
  do i=1,n
    out(i) = in(i)
  end do
!$foo end bar
end subroutine test_tools_pragmas_attached_region
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 3
    assert all(loop.pragma is None for loop in loops)
    assert len(FindNodes(Pragma).visit(routine.body)) == 5

    with pragma_regions_attached(routine):
        # Four pragmas are consumed as region delimiters, one remains
        assert len(FindNodes(Pragma).visit(routine.body)) == 1
        assert len(FindNodes(PragmaRegion).visit(routine.body)) == 2
        # Find loops inside regions
        regions = FindNodes(PragmaRegion).visit(routine.body)
        region_loops = flatten(FindNodes(Loop).visit(r) for r in regions)
        assert len(region_loops) == 2
        assert all(l in loops for l in region_loops)

    # Verify that loops from context are still valid
    assert all(l in loops for l in region_loops)

    # Ensure that everything is back to where it was
    loops_after = FindNodes(Loop).visit(routine.body)
    assert len(loops_after) == 3
    assert loops_after == loops
    assert all(loop.pragma is None for loop in loops_after)
    assert len(FindNodes(Pragma).visit(routine.body)) == 5


@pytest.mark.parametrize('frontend', available_frontends())
def test_tools_pragma_regions_attached_nested(frontend):
    """
    Verify ``pragma_regions_attached`` handles nested `PragmaRegion` objects.

    An outer ``!$loki data`` / ``!$loki end data`` region encloses two
    inner ``data`` regions. Inside the context, all three regions must be
    created with the inner ones nested in the outer one. Only
    ``!$loki do_nothing`` remains a standalone pragma. On exit, the original
    IR must be restored.
    """
    fcode = """
subroutine test_tools_pragmas_attached_region (in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  integer, intent(in) :: n
  integer :: i

  out(0) = -1.0

!$loki data foo

  out(0) = -2.0

  !$loki data nofoo endfoo
  do i=1,n
    !$loki do_nothing
    out(i) = 0.0
  end do
  !$loki end data

  do i=1,n
    out(i) = 1.0
  end do

  !$loki data tofu
  do i=1,n
    out(i) = in(i)
  end do
  !$loki end data

!$loki end data

end subroutine test_tools_pragmas_attached_region
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    loops = FindNodes(Loop).visit(routine.body)
    assert len(loops) == 3
    assert all(loop.pragma is None for loop in loops)
    assert len(FindNodes(Pragma).visit(routine.body)) == 7

    with pragma_regions_attached(routine):
        # Six pragmas form three start/end pairs; ``do_nothing`` remains
        assert len(FindNodes(Pragma).visit(routine.body)) == 1
        assert len(FindNodes(PragmaRegion).visit(routine.body)) == 3

        # Check that we are finding the right loops for each region
        # (regions[0] is the outer region, regions[1:] the nested ones)
        regions = FindNodes(PragmaRegion).visit(routine.body)
        assert len(FindNodes(PragmaRegion).visit(regions[0].body)) == 2
        assert len(FindNodes(Loop).visit(regions[0])) == 3
        assert len(FindNodes(Loop).visit(regions[1])) == 1
        assert len(FindNodes(Loop).visit(regions[2])) == 1

        # Check that all loops in outer region are unchanged
        region_loops = FindNodes(Loop).visit(regions[0])
        assert all(l in loops for l in region_loops)

    # Verify that loops from context are still valid
    assert all(l in loops for l in region_loops)

    # Ensure that everything is back to where it was
    loops_after = FindNodes(Loop).visit(routine.body)
    assert len(loops_after) == 3
    assert loops_after == loops
    assert all(loop.pragma is None for loop in loops_after)
    assert len(FindNodes(Pragma).visit(routine.body)) == 7


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('keyword,num_pragmas,num_pragma_regions', [
    (None, 4, 2), # Default/legacy behaviour: this fails to attach the first Loki region because
                  # a second pragma region begins within the first, and cannot close the
                  # ACC region for the same reason
    ('loki', 4, 2), # Creates two Loki pragma regions, each containing an ACC pragma
    ('acc', 6, 1), # Create only the ACC kernels region
])
def test_tools_pragma_regions_attached_keyword(frontend, keyword, num_pragmas, num_pragma_regions):
    """
    Verify that the ``keyword`` argument of ``pragma_regions_attached``
    restricts region matching to pragmas with that keyword.

    The source interleaves two Loki ``remove`` regions with an ACC
    ``kernels`` region, in mixed case. It also has a ``!$lokiish`` region,
    whose keyword merely starts with ``loki``. The parameters give the
    expected number of standalone pragmas and regions inside the context.
    Outside the context, all eight pragmas must be standalone again.
    """
    fcode = """
subroutine nested_regions(arg)
implicit none
real, intent(inout) :: arg

!$LoKi remove
!$acc kernels
!$loki end remove
arg = 5
!$loki remove
!$ACC end kernels
!$LOKI end remove

!$lokiish print
print *,'hello world'
!$LOKIish end print

end subroutine nested_regions
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    assert len(FindNodes(Pragma).visit(routine.body)) == 8
    assert not FindNodes(PragmaRegion).visit(routine.body)

    with pragma_regions_attached(routine, keyword=keyword):
        assert len(FindNodes(Pragma).visit(routine.body)) == num_pragmas
        assert len(FindNodes(PragmaRegion).visit(routine.body)) == num_pragma_regions

    # Leaving the context restores the original, region-free IR
    assert len(FindNodes(Pragma).visit(routine.body)) == 8
    assert not FindNodes(PragmaRegion).visit(routine.body)


@pytest.mark.parametrize('frontend', available_frontends())
def test_long_pragmas(frontend):
    """
    Check that code generation keeps every line shorter than 135 characters.
    The input includes a multi-line ``!$acc data`` pragma whose combined
    clauses would exceed that length if emitted on a single line.
    """
    fcode = """
subroutine test_long_pragmas(in, out, n)
  implicit none
  real, intent(in) :: in(:)
  real, intent(out) :: out(:)
  real :: some_very_long_temporary_variable_name, some_even_longer_variable_name, another_variable_name
  real :: even_more_variable_name, really_really_i_mean_we_need_like_a_lot
  integer, intent(in) :: n
  integer :: i

  !$acc data &
  !$acc   copyin(some_very_long_temporary_variable_name, some_even_longer_variable_name) &
  !$acc   copy(another_variable_name, even_more_variable_name, really_really_i_mean_we_need_like_a_lot) &
  !$acc   copyout(out)

  do i=1,n
    out(i) = in(i)
  end do

  !$acc end data
end subroutine test_long_pragmas
    """
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Gather all offending lines so that a failure reports every one of them
    generated = fgen(routine).splitlines()
    overlong = [text for text in generated if len(text) >= 135]
    assert not overlong


@pytest.mark.parametrize('frontend', available_frontends())
def test_pragmas_map(frontend):
    """
    Test correct handling of pragmas with multiple occurrences of the same keyword.

    The ``!$omp target`` pragma repeats the ``map`` clause three times.
    Every occurrence must be kept in the pragma content and reproduced by
    ``fgen``.
    """
    fcode = """
subroutine test_pragmas_map(n, a, b, c)
    implicit none
    integer, intent(in) :: n
    real, intent(in) :: a(:,:), b(:,:)
    real, intent(inout) :: c(:,:)
    integer :: i, j, k

!$omp target map(to: a) map(b) map(tofrom: c)
!$omp parallel do private(j,i,k)
    do j=1,n
        do i=1,n
            do k=1,n
                c(i,j) = c(i,j) + a(i,k) * b(k,j)
            enddo
        enddo
    enddo
!$omp end parallel do
!$omp end target
end subroutine test_pragmas_map
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    pragmas = FindNodes(Pragma).visit(routine.body)

    assert len(pragmas) == 4
    assert all(p.keyword.lower() == 'omp' for p in pragmas)
    assert all(v in pragmas[0].content for v in ['target', 'map(to: a)', 'map(b)', 'map(tofrom: c)'])

    # Generated code renders clause values as ``name( value )``
    fgen_code = fgen(pragmas[0]).lower()
    assert '!$omp' in fgen_code
    assert 'target' in fgen_code
    assert 'map( to: a )' in fgen_code
    assert 'map( b )' in fgen_code
    assert 'map( tofrom: c )' in fgen_code


@pytest.mark.parametrize('frontend', available_frontends())
def test_pragmas_mixed_key_value_attrs(frontend):
    """
    Check pragma parameter parsing when valued attributes (``name(value)``)
    and bare attributes appear together in one pragma, in either order
    (reported in #317). Code generation must keep both kinds.
    """
    fcode = """
SUBROUTINE TEST()
IMPLICIT NONE
END SUBROUTINE TEST
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    name = routine.name

    def acc_params(node):
        # ACC pragmas are not Loki pragmas, so the Loki-only filter is off
        return get_pragma_parameters(node, only_loki_pragmas=False)

    kernels = Pragma(keyword='acc', content='kernels num_gangs ( 1 ) async wait')
    assert acc_params(kernels) == {
        'kernels': None,
        'num_gangs': ' 1 ',
        'async': None,
        'wait': None
    }

    seq_first = Pragma(keyword='acc', content=f'seq routine ({name})')
    assert acc_params(seq_first) == {'seq': None, 'routine': name}

    seq_last = Pragma(keyword='acc', content=f'routine ({name}) seq')
    assert acc_params(seq_last) == {'seq': None, 'routine': name}

    # The valued attribute is regenerated as ``name( value )``, followed by the bare one
    routine.spec.prepend(seq_last)
    generated = routine.to_fortran()
    assert f'!$acc routine( {name} ) seq' in generated


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_pragma_strings(frontend):
    """
    Test pragma string mapping using the :any:`SubstitutePragmaStrings` transformer.

    Each key of ``str_map`` must be replaced by its value in every pragma,
    at every nesting depth. That includes the opening pragma of a Loki
    region, which is checked again via ``pragma_regions_attached``.
    """

    fcode = """
subroutine test()
integer :: i

!$loki region sOMe(stupIdly rIDICULOusly) long(: .OK. :) /we aRe stIll g01Ng ... m*AY+b3% &
!$loki &lEt's sT0P

if (.true.) then
   !$loki sOMe(stupIdly rIDICULOusly) long(: .OK. :) /we aRe stIll g01Ng ... m*AY+b3% &
   !$loki &lEt's sT0P
   if (.not. .false.) then
      !$loki sOMe(stupIdly rIDICULOusly) long(: .OK. :) /we aRe stIll g01Ng ... m*AY+b3% &
      !$loki &lEt's sT0P
      do i = 1,10
        !$loki sOMe(stupIdly rIDICULOusly) long(: .OK. :) /we aRe stIll g01Ng ... m*AY+b3% &
        !$loki &lEt's sT0P
      enddo
   endif
endif

!$loki end region

end subroutine test
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    str_map = {
        'stupIdly': 'A bIt lESs stUPid nOW',
        'rIDICULOusly': ', eVeN bEttEr Now!',
        'long(: .OK. :)': 'take care TO ONLY PUT COLONS in a parameter',
        '/we': 'foRWARDslash test pASSEd',
        'aRe stIll g01Ng': 'and we insERTed (01)'
        }

    # OMNI adds the keyword for continued lines to the
    # pragma content, so pragma string substitutions across
    # lines only work for FP
    fp_only_str_map = {
        "... m*AY+b3% lEt's sT0P": "thanks I'm tired now"
    }

    if frontend == FP:
        str_map.update(fp_only_str_map)

    transform = SubstitutePragmaStrings(str_map)
    # In-place transformer: the return value is not needed
    transform.visit(routine.body)
    pragmas = FindNodes(Pragma).visit(routine.body)

    assert len(pragmas) == 5

    for pragma in pragmas:
        if pragma.content == 'end region':
            continue
        for k, v in str_map.items():
            assert k not in pragma.content
            assert v in pragma.content

    # double check that the pragma region was also updated
    with pragma_regions_attached(routine):
        pragma_region = FindNodes(PragmaRegion).visit(routine.body)
        assert 'sOMe(A bIt lESs stUPid nOW , eVeN bEttEr Now!)' in pragma_region[0].pragma.content
loki-ecmwf-0.3.6/loki/ir/tests/test_transformer.py0000664000175000017500000006142015167130205022431 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, fgen
from loki.frontend import available_frontends, OMNI, SourceStatus
from loki.ir import (
    nodes as ir, FindNodes, Transformer, NestedTransformer,
    MaskedTransformer, NestedMaskedTransformer, SubstituteExpressions
)
from loki.expression import symbols as sym


@pytest.mark.parametrize('frontend', available_frontends())
def test_transformer_source_invalidation_replace(frontend):
    """
    Test basic transformer functionality and verify source invalidation
    when replacing nodes.

    The ``if`` branch assignment is replaced twice. The first replacement
    has an explicitly invalidated source and runs with
    ``invalidate_source=True``: the replaced node must report
    ``INVALID_NODE`` and its enclosing loops and conditional must report
    ``INVALID_CHILDREN``. The second replacement has ``source=None`` and runs
    with ``invalidate_source=False``: all other nodes must keep ``VALID``
    sources.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i, j

  do i=1, x
    vector(i) = vector(i) + scalar
    do j=1, y
      if (j > i) then
        matrix(i, j) = real(i * j, kind=jprb) + 1.
      else
        matrix(i, j) = i * vector(j)
      end if
    end do
  end do
end subroutine routine_simple
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Replace the innermost statement in the body of the conditional
    # (the only ``matrix`` assignment whose right-hand side is a sum)
    def get_innermost_statement(nodes):
        for stmt in FindNodes(ir.Assignment).visit(nodes):
            if 'matrix' in str(stmt.lhs) and isinstance(stmt.rhs, sym.Sum):
                return stmt
        return None

    stmt = get_innermost_statement(routine.ir)
    # Same sum, but with the last summand replaced by the literal ``2.``
    new_expr = sym.Sum((*stmt.rhs.children[:-1], sym.FloatLiteral(2.)))

    # Check source invalidation via status flags
    mapper = {stmt: stmt.clone(rhs=new_expr, source=stmt.source.clone().invalidate())}
    body_invalid_source = Transformer(mapper, invalidate_source=True).visit(routine.body)

    # Check that original source has not been modified
    assert stmt.source and stmt.source.status == SourceStatus.VALID

    # Check that directly and indirectly affected nodes have been invalidated
    # (assigns[1] is the replaced statement; the other two are untouched)
    assigns = FindNodes(ir.Assignment).visit(body_invalid_source)
    assert len(assigns) == 3
    assert assigns[0].source.status == SourceStatus.VALID
    assert assigns[1].source.status == SourceStatus.INVALID_NODE
    assert assigns[2].source.status == SourceStatus.VALID

    loops = FindNodes(ir.Loop).visit(body_invalid_source)
    assert len(loops) == 2
    assert loops[0].source.status == SourceStatus.INVALID_CHILDREN
    assert loops[1].source.status == SourceStatus.INVALID_CHILDREN

    conds = FindNodes(ir.Conditional).visit(body_invalid_source)
    assert len(conds) == 1
    assert conds[0].source.status == SourceStatus.INVALID_CHILDREN

    # Check manual source removal without invalidation
    mapper = {stmt: stmt.clone(rhs=new_expr, source=None)}
    body_valid_source = Transformer(mapper, invalidate_source=False).visit(routine.body)

    assigns = FindNodes(ir.Assignment).visit(body_valid_source)
    assert len(assigns) == 3
    assert assigns[0].source.status == SourceStatus.VALID
    assert assigns[1].source is None
    assert assigns[2].source.status == SourceStatus.VALID

    # Without invalidation, the enclosing nodes keep their valid sources
    loops = FindNodes(ir.Loop).visit(body_valid_source)
    assert len(loops) == 2
    assert loops[0].source.status == SourceStatus.VALID
    assert loops[1].source.status == SourceStatus.VALID

    conds = FindNodes(ir.Conditional).visit(body_valid_source)
    assert len(conds) == 1
    assert conds[0].source.status == SourceStatus.VALID


@pytest.mark.parametrize('frontend', available_frontends())
def test_transformer_source_invalidation_prepend(frontend):
    """
    Test basic transformer functionality and verify source invalidation
    when adding items to a loop body.

    The conditional inside the inner loop is mapped to a tuple: a new
    assignment without source, followed by a clone of the conditional.
    With an invalidated clone and ``invalidate_source=True``, the
    conditional must report ``INVALID_NODE`` and both loops
    ``INVALID_CHILDREN``. With a plain clone and ``invalidate_source=False``,
    all pre-existing sources must stay ``VALID``.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i, j

  do i=1, x
    vector(i) = vector(i) + scalar
    do j=1, y
      if (j > i) then
        matrix(i, j) = real(i * j, kind=jprb) + 1.
      else
        matrix(i, j) = i * vector(j)
      end if
    end do
  end do
end subroutine routine_simple
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Insert a new statement before the conditional
    def get_conditional(nodes):
        return FindNodes(ir.Conditional).visit(nodes)[0]

    cond = get_conditional(routine.ir)
    new_stmt = ir.Assignment(lhs=routine.arguments[0], rhs=routine.arguments[1])
    mapper = {cond: (new_stmt, cond.clone(source=cond.source.clone().invalidate()))}

    body_invalid_source = Transformer(mapper, invalidate_source=True).visit(routine.body)

    # Check that original source has not been modified
    assert cond.source and cond.source.status == SourceStatus.VALID
    assert not new_stmt.source

    # Check that directly and indirectly affected nodes have been invalidated.
    # Assignment order: [0] vector update, [1] inserted statement,
    # [2] and [3] the two branches of the conditional
    assigns = FindNodes(ir.Assignment).visit(body_invalid_source)
    assert len(assigns) == 4
    assert assigns[0].source.status == SourceStatus.VALID
    assert not assigns[1].source
    assert assigns[2].source.status == SourceStatus.VALID
    assert assigns[3].source.status == SourceStatus.VALID

    loops = FindNodes(ir.Loop).visit(body_invalid_source)
    assert len(loops) == 2
    assert loops[0].source.status == SourceStatus.INVALID_CHILDREN
    assert loops[1].source.status == SourceStatus.INVALID_CHILDREN

    conds = FindNodes(ir.Conditional).visit(body_invalid_source)
    assert len(conds) == 1
    assert conds[0].source.status == SourceStatus.INVALID_NODE

    # Check manual source removal without invalidation
    mapper = {cond: (new_stmt, cond.clone())}
    body_valid_source = Transformer(mapper, invalidate_source=False).visit(routine.body)

    assigns = FindNodes(ir.Assignment).visit(body_valid_source)
    assert len(assigns) == 4
    assert assigns[0].source.status == SourceStatus.VALID
    assert not assigns[1].source
    assert assigns[2].source.status == SourceStatus.VALID
    assert assigns[3].source.status == SourceStatus.VALID

    # Without invalidation, the enclosing nodes keep their valid sources
    loops = FindNodes(ir.Loop).visit(body_valid_source)
    assert len(loops) == 2
    assert loops[0].source.status == SourceStatus.VALID
    assert loops[1].source.status == SourceStatus.VALID

    conds = FindNodes(ir.Conditional).visit(body_valid_source)
    assert len(conds) == 1
    assert conds[0].source.status == SourceStatus.VALID


@pytest.mark.parametrize('frontend', available_frontends())
def test_transformer_rebuild(frontend):
    """
    Verify that :any:`Transformer` rebuilds the loops and conditionals enclosing
    a replaced node when ``inplace=False``, and keeps the original objects when
    ``inplace=True`` (both with and without source invalidation).
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i, j

  do i=1, x
    vector(i) = vector(i) + scalar
    do j=1, y
      if (j > i) then
        matrix(i, j) = real(i * j, kind=jprb) + 1.
      else
        matrix(i, j) = i * vector(j)
      end if
    end do
  end do
end subroutine routine_simple
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # The assignment to ``matrix`` whose right-hand side is a sum sits in the
    # then-branch of the innermost conditional
    target = next(
        (
            assign for assign in FindNodes(ir.Assignment).visit(routine.ir)
            if 'matrix' in str(assign.lhs) and isinstance(assign.rhs, sym.Sum)
        ),
        None
    )

    # Replace the last summand by the literal ``2.``
    replacement_rhs = sym.Sum((*target.rhs.children[:-1], sym.FloatLiteral(2.)))
    replacement = ir.Assignment(target.lhs, replacement_rhs)
    node_map = {target: replacement}

    orig_loops = FindNodes(ir.Loop).visit(routine.body)
    orig_conds = FindNodes(ir.Conditional).visit(routine.body)

    def summarize(body):
        """Return stringified assignments, loops and conditionals found in ``body``."""
        assign_strs = [str(assign) for assign in FindNodes(ir.Assignment).visit(body)]
        return assign_strs, FindNodes(ir.Loop).visit(body), FindNodes(ir.Conditional).visit(body)

    # Out-of-place: every loop and conditional must be a freshly rebuilt node
    new_strs, new_loops, new_conds = summarize(Transformer(node_map, inplace=False).visit(routine.body))
    assert str(target) not in new_strs
    assert str(replacement) in new_strs
    assert all(loop not in orig_loops for loop in new_loops)
    assert all(cond not in orig_conds for cond in new_conds)

    # In-place, with the default source invalidation and with it switched off:
    # the original loops and conditionals must be retained
    for options in ({'inplace': True}, {'invalidate_source': False, 'inplace': True}):
        kept_strs, kept_loops, kept_conds = summarize(Transformer(node_map, **options).visit(routine.body))
        assert str(target) not in kept_strs
        assert str(replacement) in kept_strs
        assert all(loop in orig_loops for loop in kept_loops)
        assert all(cond in orig_conds for cond in kept_conds)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transformer_multinode_keys(frontend):
    """
    Test basic transformer functionality with multi-node keys, i.e. a mapper
    key that is a tuple of consecutive nodes which is replaced as a whole.
    """
    fcode = """
subroutine routine_simple (x, y, a, b, c, d, e)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: a(x), b(x), c(x), d(x), e(x)
  integer :: i

  b(i) = a(i) + 1.
  c(i) = a(i) + 2.
  d(i) = c(i) + 3.
  e(i) = d(i) + 4.
end subroutine routine_simple
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    assigns = FindNodes(ir.Assignment).visit(routine.body)
    bounds = sym.LoopRange((sym.IntLiteral(1), routine.variable_map['x']))

    # Filter out only the two middle assignments to wrap in a loop.
    # Note that we need to be careful to clone loop body nodes to
    # avoid infinite recursion.
    assigns = tuple(a for a in assigns if a.lhs in ['c(i)', 'd(i)'])
    loop = ir.Loop(variable=routine.variable_map['i'], bounds=bounds,
                   body=tuple(a.clone() for a in assigns))
    # Need to use NestedTransformer here, since replacement contains
    # the original nodes.
    transformed = NestedTransformer({assigns: loop}).visit(routine.body)

    # Exactly one loop is introduced, it holds the two wrapped assignments,
    # and the total number of assignments is unchanged
    new_loops = FindNodes(ir.Loop).visit(transformed)
    assert len(new_loops) == 1
    assert len(FindNodes(ir.Assignment).visit(new_loops)) == 2
    assert len(FindNodes(ir.Assignment).visit(transformed)) == 4


@pytest.mark.parametrize('frontend', available_frontends())
def test_masked_transformer(frontend):
    """
    Sanity checks for :any:`MaskedTransformer` on a flat sequence of ten
    assignments, covering start/stop markers, ``active`` and ``mapper``.
    """
    fcode = """
subroutine masked_transformer(a)
  integer, intent(inout) :: a

  a = a + 1
  a = a + 2
  a = a + 3
  a = a + 4
  a = a + 5
  a = a + 6
  a = a + 7
  a = a + 8
  a = a + 9
  a = a + 10
end subroutine masked_transformer
    """

    routine = Subroutine.from_source(fcode, frontend=frontend)
    stmts = FindNodes(ir.Assignment).visit(routine.body)

    def masked(**kwargs):
        """Apply a MaskedTransformer to the routine body and return the remaining assignments."""
        return FindNodes(ir.Assignment).visit(MaskedTransformer(**kwargs).visit(routine.body))

    # Inactive with no start marker: nothing survives
    assert len(masked(start=None, stop=None)) == 0

    # Active from the outset with no stop marker: everything survives
    assert len(masked(start=None, stop=None, active=True)) == 10

    # Only activated by the final assignment
    assert len(masked(start=stmts[-1], stop=None)) == 1

    # Active from the outset until the final assignment
    assert len(masked(start=None, stop=stmts[-1], active=True)) == len(stmts) - 1

    # Two start markers and a single stop: first two and last two survive
    assert len(masked(start=[stmts[0], stmts[-2]], stop=stmts[2])) == 4

    # Two start and two stop markers: first two and the second-to-last survive
    assert len(masked(start=[stmts[0], stmts[-2]], stop=[stmts[2], stmts[-1]])) == 3

    # A single window in the middle
    assert len(masked(start=stmts[3], stop=stmts[6])) == 3

    # Window over assignments two to four, with the third replaced by the first
    kept = masked(start=stmts[1], stop=stmts[4], mapper={stmts[2]: stmts[0]})
    assert len(kept) == 3
    assert str(kept[1]) == str(stmts[0])

    # Same window, but now the start marker itself is replaced by the first
    kept = masked(start=stmts[1], stop=stmts[4], mapper={stmts[1]: stmts[0]})
    assert len(kept) == 3
    assert str(kept[0]) == str(stmts[0])


@pytest.mark.parametrize('frontend', available_frontends())
def test_masked_transformer_minimum_set(frontend):
    """
    Sanity checks for :any:`MaskedTransformer` when combined with the
    ``require_all_start`` or ``greedy_stop`` options.
    """
    fcode = """
subroutine masked_transformer_minimum_set(a)
  integer, intent(inout) :: a

  a = a + 1
  a = a + 2
  a = a + 3
  a = a + 4
  a = a + 5
  a = a + 6
  a = a + 7
  a = a + 8
  a = a + 9
  a = a + 10
end subroutine masked_transformer_minimum_set
    """

    routine = Subroutine.from_source(fcode, frontend=frontend)
    stmts = FindNodes(ir.Assignment).visit(routine.body)

    # Each case: transformer options and the index of the single assignment
    # expected to survive
    cases = (
        # All assignments must have been seen, so only the last one is kept
        ({'start': stmts, 'require_all_start': True}, -1),
        # Both of the first two must have been seen before the stop at the third
        ({'start': stmts[:2], 'stop': stmts[2], 'require_all_start': True}, 1),
        # Any stop marker deactivates immediately, leaving only the first
        ({'start': stmts, 'stop': stmts[1], 'greedy_stop': True}, 0),
    )
    for options, survivor in cases:
        body = MaskedTransformer(**options).visit(routine.body)
        assert len(FindNodes(ir.Assignment).visit(body)) == 1
        assert fgen(body) == fgen(stmts[survivor])


@pytest.mark.parametrize('frontend', available_frontends())
def test_masked_transformer_associates(frontend):
    """
    Check :any:`MaskedTransformer` on assignments nested inside an
    ``ASSOCIATE`` block, including the ``inplace=True`` mode.
    """
    fcode = """
subroutine masked_transformer(a)
  integer, intent(inout) :: a

associate(b=>a)
  b = b + 1
  b = b + 2
  b = b + 3
  b = b + 4
  b = b + 5
  b = b + 6
  b = b + 7
  b = b + 8
  b = b + 9
  b = b + 10
end associate
end subroutine masked_transformer
    """

    routine = Subroutine.from_source(fcode, frontend=frontend)
    stmts = FindNodes(ir.Assignment).visit(routine.body)
    assert len(stmts) == 10
    assert len(FindNodes(ir.Associate).visit(routine.body)) == 1

    def survivors(**kwargs):
        """Return (#assignments, #associates) left after masking the routine body."""
        body = MaskedTransformer(**kwargs).visit(routine.body)
        return len(FindNodes(ir.Assignment).visit(body)), len(FindNodes(ir.Associate).visit(body))

    # Nothing survives, not even the associate
    assert survivors(start=None, stop=None) == (0, 0)

    # Only the last assignment survives; the associate is dropped
    assert survivors(start=stmts[-1], stop=None) == (1, 0)

    # All but the last assignment survive, together with their associate
    assert survivors(start=None, stop=stmts[-1], active=True) == (len(stmts) - 1, 1)

    # First two and last two assignments, without the associate
    assert survivors(start=[stmts[0], stmts[-2]], stop=stmts[2]) == (4, 0)

    # First two and second-to-last assignments, without the associate
    assert survivors(start=[stmts[0], stmts[-2]], stop=[stmts[2], stmts[-1]]) == (3, 0)

    # Three assignments in the middle, without the associate
    assert survivors(start=stmts[3], stop=stmts[6]) == (3, 0)

    # All but the last assignment again, now in-place: the associate must
    # survive and hold exactly the remaining assignments
    body = MaskedTransformer(start=None, stop=stmts[-1], active=True, inplace=True).visit(routine.body)
    assert len(FindNodes(ir.Assignment).visit(body)) == len(stmts) - 1
    assoc_nodes = FindNodes(ir.Associate).visit(body)
    assert len(assoc_nodes) == 1
    assert len(assoc_nodes[0].body) == len(stmts) - 1
    assert all(isinstance(node, ir.Assignment) for node in assoc_nodes[0].body)


@pytest.mark.parametrize('frontend', available_frontends())
def test_nested_masked_transformer(frontend):
    """
    Test :any:`MaskedTransformer` and :any:`NestedMaskedTransformer` on
    assignments nested inside loops and an IF / ELSE IF / ELSE construct.
    """
    fcode = """
subroutine nested_masked_transformer
  implicit none
  integer :: a=0, b, c, d
  integer :: i, j

  do i=1,10
    a = a + i
    if (a < 5) then
      b = 0
    else if (a == 5) then
      c = 0
    else
      do j=1,5
        d = a
      end do
    end if
  end do
end subroutine nested_masked_transformer
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    loops = FindNodes(ir.Loop).visit(routine.body)
    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    assert len(assignments) == 4
    assert len(loops) == 2
    # The ELSE IF branch is expected to appear as a second Conditional nested
    # in the else-body of the first. NOTE(review): the previous form
    # ``len(conditionals) == 2 if frontend == OMNI else 1`` parsed as a
    # conditional expression, so for non-OMNI frontends it asserted the
    # constant ``1`` and never checked the count -- confirm 2 for all frontends.
    assert len(conditionals) == 2

    # Drops the outermost loop
    start = [a for a in assignments if a.lhs == 'a']
    body = MaskedTransformer(start=start).visit(routine.body)
    assert len(FindNodes(ir.Assignment).visit(body)) == 4
    assert len(FindNodes(ir.Loop).visit(body)) == 1
    assert len(FindNodes(ir.Conditional).visit(body)) == len(conditionals)

    # Should produce the original version
    body = NestedMaskedTransformer(start=start).visit(routine.body)
    assert len(FindNodes(ir.Assignment).visit(body)) == 4
    assert len(FindNodes(ir.Loop).visit(body)) == 2
    assert len(FindNodes(ir.Conditional).visit(body)) == len(conditionals)
    assert fgen(routine.body).strip() == fgen(body).strip()

    # Should drop the first assignment
    body = NestedMaskedTransformer(start=conditionals[0]).visit(routine.body)
    assert len(FindNodes(ir.Assignment).visit(body)) == 3
    assert len(FindNodes(ir.Loop).visit(body)) == 2
    assert len(FindNodes(ir.Conditional).visit(body)) == len(conditionals)

    # Should leave no more than a single assignment
    start = [a for a in assignments if a.lhs == 'c']
    stop = [l for l in loops if l.variable == 'j']
    body = MaskedTransformer(start=start, stop=stop).visit(routine.body)
    assert fgen(start).strip() == fgen(body).strip()

    # Should leave a single assignment with the hierarchy of nested sections
    # in the else-if branch
    body = NestedMaskedTransformer(start=start, stop=stop).visit(routine.body)
    assert len(FindNodes(ir.Assignment).visit(body)) == 1
    assert len(FindNodes(ir.Loop).visit(body)) == 1
    assert len(FindNodes(ir.Conditional).visit(body)) == 1

    # Should leave no more than a single assignment
    start = [a for a in assignments if a.lhs == 'd']
    body = MaskedTransformer(start=start, stop=start).visit(routine.body)
    assert fgen(start).strip() == fgen(body).strip()

    # Should leave a single assignment with the hierarchy of nested sections
    # in the else branch
    body = NestedMaskedTransformer(start=start, stop=start).visit(routine.body)
    assert len(FindNodes(ir.Assignment).visit(body)) == 1
    assert len(FindNodes(ir.Loop).visit(body)) == 2
    assert len(FindNodes(ir.Conditional).visit(body)) == 0

    # Should produce the original body
    start = [a for a in assignments if a.lhs in ('a', 'd')]
    body = NestedMaskedTransformer(start=start).visit(routine.body)
    assert fgen(routine.body).strip() == fgen(body).strip()

    # Should leave a single assignment with the hierarchy of nested sections
    # in the else branch
    body = NestedMaskedTransformer(start=start, require_all_start=True).visit(routine.body)
    assert [a.lhs == 'd' for a in FindNodes(ir.Assignment).visit(body)] == [True]
    assert len(FindNodes(ir.Loop).visit(body)) == 2
    assert len(FindNodes(ir.Conditional).visit(body)) == 0

    # Drops everything
    stop = [a for a in assignments if a.lhs == 'a']
    body = NestedMaskedTransformer(start=start, stop=stop, greedy_stop=True).visit(routine.body)
    assert not body

    # Should drop the else-if branch
    start = [a for a in assignments if a.lhs in ('b', 'd')]
    stop = [a for a in assignments if a.lhs == 'c']
    body = NestedMaskedTransformer(start=start, stop=stop).visit(routine.body)
    assert len(FindNodes(ir.Assignment).visit(body)) == 2
    assert len(FindNodes(ir.Loop).visit(body)) == 2
    assert len(FindNodes(ir.Conditional).visit(body)) == 1

    # Should drop everything but the if branch
    body = NestedMaskedTransformer(start=start, stop=stop, greedy_stop=True).visit(routine.body)
    assert [a.lhs == 'b' for a in FindNodes(ir.Assignment).visit(body)] == [True]
    assert len(FindNodes(ir.Loop).visit(body)) == 1
    assert len(FindNodes(ir.Conditional).visit(body)) == 1


@pytest.mark.parametrize('invalidate_source', [True, False])
@pytest.mark.parametrize('replacement', ['body', 'self', 'self_tuple', 'duplicate'])
@pytest.mark.parametrize('frontend', available_frontends())
def test_transformer_duplicate_node_tuple_injection(frontend, invalidate_source, replacement):
    """
    Regression test for #41: identical nodes inside a tuple must be handled
    correctly by the tuple-injection mechanism of :any:`Transformer`.
    """
    fcode_kernel = """
SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end
    INTEGER, INTENT(IN) :: nlon, nz
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl
    DO JL = START, END
        Q(JL, NZ) = Q(JL, NZ) * 0.5
    END DO
    DO JL = START, END
        Q(JL, NZ) = Q(JL, NZ) * 0.5
    END DO
END SUBROUTINE compute_column
"""
    kernel = Subroutine.from_source(fcode_kernel, frontend=frontend)

    # No-op substitution pass; with ``invalidate_source`` set it invalidates the source property
    kernel.body = SubstituteExpressions({}, invalidate_source=invalidate_source).visit(kernel.body)

    # Per replacement mode: how a loop's replacement is built, and how many
    # loops must remain afterwards
    strategies = {
        'body': (lambda loop: loop.body, 0),           # loop replaced by its body
        'self': (lambda loop: loop, 2),                # loop replaced by itself
        'self_tuple': (lambda loop: (loop,), 2),       # ... wrapped in a tuple
        'duplicate': (lambda loop: (loop, loop), 4),   # loop duplicated (infinite recursion risk)
    }
    # Any other mode means the parametrization is broken
    assert replacement in strategies
    build_replacement, expected_loop_count = strategies[replacement]

    loops = FindNodes(ir.Loop).visit(kernel.body)
    kernel.body = Transformer({loop: build_replacement(loop) for loop in loops}).visit(kernel.body)

    # A second, empty pass fails on nested tuples or similar leftovers
    kernel.body = Transformer({}).visit(kernel.body)

    # Successful code generation suggests the IR is not badly broken
    assert kernel.to_fortran()
    assert len(FindNodes(ir.Loop).visit(kernel.body)) == expected_loop_count
loki-ecmwf-0.3.6/loki/ir/tests/test_scoped_nodes.py0000664000175000017500000001560115167130205022534 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# A set of tests for the symbol accessor and management API built into `ScopedNode`.

import pytest

from loki import Module
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
from loki.expression import symbols as sym
from loki.types import BasicType


@pytest.mark.parametrize('frontend', available_frontends())
def test_scoped_node_get_symbols(frontend, tmp_path):
    """ Check :method:`get_symbol` lookups on a routine and on a nested associate. """
    fcode = """
module test_scoped_node_symbols_mod
implicit none
integer, parameter :: jprb = 8

contains
  subroutine test_scoped_node_symbols(n, a, b, c)
    integer, intent(in) :: n
    real(kind=jprb), intent(inout) :: a(n), b(n), c
    integer :: i

    a(1) = 42.0_jprb

    associate(d => a)
    do i=1, n
      b(i) = a(i) + c
    end do
    end associate
  end subroutine test_scoped_node_symbols
end module test_scoped_node_symbols_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_scoped_node_symbols']
    associate = FindNodes(ir.Associate).visit(routine.body)[0]

    # Routine-declared symbols and their expected string representation
    routine_decls = (('a', 'a(n)'), ('b', 'b(n)'), ('c', 'c'))

    # Lookups directly from the subroutine
    for name, expected in routine_decls:
        assert routine.get_symbol(name) == expected
        assert routine.get_symbol(name).scope == routine
    assert routine.get_symbol('jprb') == 'jprb'
    assert routine.get_symbol('jprb').scope == module
    assert routine.get_symbol('jprb').initial == 8

    # Lookups through the Associate (a ScopedNode) fall back to the parent scopes,
    # while the associated name itself belongs to the associate
    for name, expected in routine_decls:
        assert associate.get_symbol(name) == expected
        assert associate.get_symbol(name).scope == routine
    assert associate.get_symbol('d') == 'd'
    assert associate.get_symbol('d').scope == associate
    assert associate.get_symbol('jprb') == 'jprb'
    assert associate.get_symbol('jprb').scope == module
    assert associate.get_symbol('jprb').initial == 8


@pytest.mark.parametrize('frontend', available_frontends())
def test_scoped_node_variable_constructor(frontend, tmp_path):
    """ Test the :any:`Variable` constructor on scoped nodes. """
    fcode = """
module test_scoped_nodes_mod
implicit none
integer, parameter :: jprb = 8

contains
  subroutine test_scoped_nodes(n, a, b, c)
    integer, intent(in) :: n
    real(kind=jprb), intent(inout) :: a(n), b(n), c
    integer :: i

    a(1) = 42.0_jprb

    associate(d => a)
    do i=1, n
      b(i) = a(i) + c
    end do
    end associate
  end subroutine test_scoped_nodes
end module test_scoped_nodes_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_scoped_nodes']
    associate = FindNodes(ir.Associate).visit(routine.body)[0]

    # Build some symbols and check their type
    c = routine.Variable(name='c')
    assert c.type.dtype == BasicType.REAL
    # The kind may be kept symbolically or resolved to its value depending on the frontend
    assert c.type.kind in ('jprb', 8)
    i = routine.Variable(name='i')
    assert i.type.dtype == BasicType.INTEGER
    jprb = routine.Variable(name='jprb')
    assert jprb.type.dtype == BasicType.INTEGER
    assert jprb.type.initial == 8

    # Passing dimensions yields an array access
    a_i = routine.Variable(name='a', dimensions=(i,))
    assert a_i == 'a(i)'
    assert isinstance(a_i, sym.Array)
    assert a_i.dimensions == (i,)
    assert a_i.type.dtype == BasicType.REAL
    assert a_i.type.kind in ('jprb', 8)

    # Build another, but from the associate node
    b_i = associate.Variable(name='b', dimensions=(i,))
    assert b_i == 'b(i)'
    assert isinstance(b_i, sym.Array)
    assert b_i.dimensions == (i,)
    assert b_i.type.dtype == BasicType.REAL
    assert b_i.type.kind in ('jprb', 8)


@pytest.mark.parametrize('frontend', available_frontends())
def test_scoped_node_parse_expr(frontend, tmp_path):
    """ Test :method:`parse_expr` functionality on scoped nodes. """
    fcode = """
module test_scoped_nodes_mod
implicit none
integer, parameter :: jprb = 8

contains
  subroutine test_scoped_nodes(n, a, b, c)
    integer, intent(in) :: n
    real(kind=jprb), intent(inout) :: a(n), b(n), c
    integer :: i

    a(1) = 42.0_jprb

    associate(d => a)
    do i=1, n
      b(i) = a(i) + c
    end do
    end associate
  end subroutine test_scoped_nodes
end module test_scoped_nodes_mod
"""
    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['test_scoped_nodes']
    associate = FindNodes(ir.Associate).visit(routine.body)[0]

    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert assigns[0].lhs == 'a(1)'
    assert assigns[1].rhs == 'a(i) + c'

    # Check that all variables are identified (parsing from the routine)
    ai_c = routine.parse_expr('a(i)   +   c')
    assert ai_c == assigns[1].rhs
    assert isinstance(ai_c, sym.Sum)
    assert isinstance(ai_c.children[0], sym.Array)
    assert isinstance(ai_c.children[0].dimensions[0], sym.Scalar)
    assert isinstance(ai_c.children[1], sym.Scalar)
    assert ai_c.children[0].scope == routine
    assert ai_c.children[0].dimensions[0].scope == routine
    assert ai_c.children[1].scope == routine

    # Check that the undeclared k is deferred, while a(i) still resolves
    # to the routine. (These checks previously re-tested ``ai_c`` by mistake.)
    ai_k = routine.parse_expr('a(i) + k')
    assert isinstance(ai_k, sym.Sum)
    assert isinstance(ai_k.children[0], sym.Array)
    assert isinstance(ai_k.children[0].dimensions[0], sym.Scalar)
    assert isinstance(ai_k.children[1], sym.DeferredTypeSymbol)
    assert ai_k.children[0].scope == routine
    assert ai_k.children[0].dimensions[0].scope == routine
    assert ai_k.children[1] == 'k'

    # Check that all variables are identified (parsing from the associate)
    ai_c = associate.parse_expr('a(i)   +   c')
    assert ai_c == assigns[1].rhs
    assert isinstance(ai_c, sym.Sum)
    assert isinstance(ai_c.children[0], sym.Array)
    assert isinstance(ai_c.children[0].dimensions[0], sym.Scalar)
    assert isinstance(ai_c.children[1], sym.Scalar)
    assert ai_c.children[0].scope == routine
    assert ai_c.children[0].dimensions[0].scope == routine
    assert ai_c.children[1].scope == routine

    # Check that k is deferred when parsing from the associate as well
    ai_k = associate.parse_expr('a(i) + k')
    assert isinstance(ai_k, sym.Sum)
    assert isinstance(ai_k.children[0], sym.Array)
    assert isinstance(ai_k.children[0].dimensions[0], sym.Scalar)
    assert isinstance(ai_k.children[1], sym.DeferredTypeSymbol)
    assert ai_k.children[0].scope == routine
    assert ai_k.children[0].dimensions[0].scope == routine
    assert ai_k.children[1] == 'k'
loki-ecmwf-0.3.6/loki/ir/tests/test_expr_visitors.py0000664000175000017500000003153015167130205023006 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Sourcefile, Subroutine, Module, config_override
from loki.expression import symbols as sym, parse_expr
from loki.frontend import available_frontends, OMNI, SourceStatus
from loki.ir import (
    nodes as ir, FindNodes, FindVariables, FindTypedSymbols,
    SubstituteExpressions, SubstituteStringExpressions,
    FindLiterals, FindRealLiterals
)


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_finder_retrieval_function(frontend, tmp_path):
    """
    Expression finder visitors must give consistent results when reused,
    and must not interfere with one another.
    """
    fcode = """
module some_mod
    implicit none
contains
    function some_func() result(ret)
        integer :: ret
        ret = 1
    end function some_func

    subroutine other_routine
        integer :: var, tmp
        var = 5 + some_func()
    end subroutine other_routine
end module some_mod
    """.strip()

    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    body = source['other_routine'].body

    typed_symbol_names = {'var', 'some_func'}
    variable_names = ('var',)

    # The same FindTypedSymbols instance must return identical results on
    # its first and on a repeated invocation
    ts_finder = FindTypedSymbols()
    for _ in range(2):
        assert ts_finder.visit(body) == typed_symbol_names

    # A second, differently configured finder works on its own...
    var_finder = FindVariables(unique=False)
    assert var_finder.visit(body) == variable_names

    # ...and leaves the first finder unaffected
    assert ts_finder.visit(body) == typed_symbol_names


@pytest.mark.parametrize('frontend', available_frontends())
def test_find_variables(frontend, tmp_path):
    """ Check that :any:`FindVariables` reports every symbol use. """

    fcode_external = """
module external_mod
implicit none
contains
subroutine rick(dave, never)
  real(kind=8), intent(inout) :: dave, never
end subroutine rick
end module external_mod
    """
    fcode = """
module test_mod
  use external_mod, only: rick
  implicit none

  type my_type
    real(kind=8) :: never
    real(kind=8), pointer :: give_you(:)
  end type my_type

contains

  subroutine test_routine(n, a, b, gonna)
    integer, intent(in) :: n
    real(kind=8), intent(inout) :: a, b(n)
    type(my_type), intent(inout) :: gonna
    integer :: i

    associate(will=>gonna%never, up=>n)
    do i=1, n
      b(i) = b(i) + a
    end do

    call rick(will, never=gonna%give_you(up))
    end associate
  end subroutine test_routine
end module test_mod
    """
    _ = Sourcefile.from_source(fcode_external, frontend=frontend, xmods=[tmp_path])
    source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = source['test_routine']

    # Spec: with unique=True each variable appears once...
    spec_names = ('n', 'a', 'gonna', 'i', 'b(n)')
    unique_spec = FindVariables(unique=True).visit(routine.spec)
    assert len(unique_spec) == 5
    assert all(name in unique_spec for name in spec_names)

    # ...while unique=False keeps both occurrences of 'n'
    every_spec = FindVariables(unique=False).visit(routine.spec)
    assert len(every_spec) == 6
    assert all(name in every_spec for name in spec_names)
    assert sum(1 for var in every_spec if var == 'n') == 2

    # Body: associates and keyword arguments in calls are included
    body_names = (
        'will', 'gonna', 'gonna%never', 'up', 'n', 'i', 'b(i)', 'a',
        'rick', 'gonna%give_you(up)'
    )
    unique_body = FindVariables(unique=True).visit(routine.body)
    assert len(unique_body) == 10
    assert all(name in unique_body for name in body_names)

@pytest.mark.parametrize('frontend', available_frontends())
def test_find_literals(frontend):
    """
    :any:`FindLiterals` must report every literal in the body, and
    :any:`FindRealLiterals` only the real (float) ones.
    """
    fcode = """
subroutine test_find_literals()
  implicit none
  integer :: n, n1
  real(kind=8) :: x

  n = 1 + 5 + 42
  x = 1.0 / 10.5
  n1 = int(B'00000')
  if (.TRUE.) then
    call some_func(x, some_string='string_kwarg')
  endif

end subroutine test_find_literals
"""
    int_literals = ('1', '5', '42')
    real_literals = ('1.0', '10.5')
    # OMNI evaluates BOZ constants and hence produces an integer literal
    boz_literals = ('0',) if frontend == OMNI else ("B'00000'",)
    all_literals = int_literals + real_literals + boz_literals + ('True',) + ('string_kwarg',)

    routine = Subroutine.from_source(fcode, frontend=frontend)

    def sorted_values(literals):
        """Sorted string forms of the ``value`` attribute of each literal."""
        return sorted(str(lit.value) for lit in literals)

    found = FindLiterals().visit(routine.body)
    assert sorted(all_literals) == sorted_values(found)

    # The dedicated finder and a type filter on all literals must agree
    assert sorted(real_literals) == sorted_values(FindRealLiterals().visit(routine.body))
    assert sorted(real_literals) == sorted_values(lit for lit in found if isinstance(lit, sym.FloatLiteral))


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_expressions(frontend):
    """ Check symbol replacement driven by a map of :any:`Expression` symbols. """

    fcode = """
subroutine test_routine(n, a, b)
  implicit none
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a, b(n)
  real(kind=8) :: c(n)
  integer :: i

  associate(d => a)
  do i=1, n
    c(i) = b(i) + a
  end do

  call another_routine(n, a, c(:), a2=d)

  end associate
end subroutine test_routine
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    assoc = FindNodes(ir.Associate).visit(routine.body)[0]
    assert call.arguments == ('n', 'a', 'c(:)')
    assert call.kwarguments == (('a2', 'd'),)

    var_map = routine.variable_map
    n, i, a = var_map['n'], var_map['i'], var_map['a']
    b_i = parse_expr('b(i)', scope=routine)
    c_r = parse_expr('c(:)', scope=routine)
    d = parse_expr('d', scope=assoc)

    # n -> n-1, shift the b index, narrow the c slice, and swap a <-> d
    substitutions = {
        n: sym.Sum((n, sym.Product((-1, sym.Literal(1))))),
        b_i: b_i.clone(dimensions=sym.Sum((i, sym.Literal(1)))),
        c_r: c_r.clone(dimensions=sym.Range((sym.Literal(1), sym.Literal(2)))),
        a: d,
        d: a,
    }
    routine.body = SubstituteExpressions(substitutions).visit(routine.body)

    assert FindNodes(ir.Loop).visit(routine.body)[0].bounds == '1:n-1'
    assign = FindNodes(ir.Assignment).visit(routine.body)[0]
    assert assign.lhs == 'c(i)' and assign.rhs == 'b(i+1) + d'
    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    assert call.arguments == ('n - 1', 'd', 'c(1:2)')
    assert call.kwarguments == (('a2', 'a'),)


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_string_expressions(frontend):
    """
    Check :any:`SubstituteStringExpressions`, where both sides of the
    substitution map are given as plain strings.
    """

    fcode = """
subroutine test_routine(n, a, b)
  implicit none
  integer, intent(in) :: n
  real(kind=8), intent(inout) :: a, b(n)
  real(kind=8) :: c(n)
  integer :: i

  associate(d => a)
  do i=1, n
    c(i) = b(i) + a
  end do

  call another_routine(n, a, c(:), a2=d)

  end associate
end subroutine test_routine
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    assoc = FindNodes(ir.Associate).visit(routine.body)[0]
    assert call.arguments == ('n', 'a', 'c(:)')
    assert call.kwarguments == (('a2', 'd'),)

    replacements = {
        'n': 'n - 1',
        'b(i)': 'b(i+1)',
        'c(:)': 'c(1:2)',
        # swap ``a`` and its associate alias ``d``
        'a': 'd',
        'd': 'a',
    }
    # The associate block must act as scope, because that is where 'd' is defined
    substituter = SubstituteStringExpressions(replacements, scope=assoc)
    routine.body = substituter.visit(routine.body)

    loop = FindNodes(ir.Loop).visit(routine.body)[0]
    assert loop.bounds == '1:n-1'

    assign = FindNodes(ir.Assignment).visit(routine.body)[0]
    assert assign.lhs == 'c(i)' and assign.rhs == 'b(i+1) + d'

    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    assert call.arguments == ('n - 1', 'd', 'c(1:2)')
    assert call.kwarguments == (('a2', 'a'),)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_string', [True, False])
def test_substitute_expression_source_invalidation(use_string, frontend, tmp_path):
    """
    Check how expression substitution (string- or symbol-based) affects stored
    source. Nodes whose expressions change must be marked ``INVALID_NODE``.
    Their enclosing nodes must be marked ``INVALID_CHILDREN``. Untouched nodes
    must stay ``VALID``.
    """

    fcode_type = """
module type_mod
  integer, parameter :: jprb = 8
end module type_mod
"""

    fcode = """
subroutine test_routine(n, a, b)
  use type_mod, only: jprb
  implicit none
  integer, intent(in) :: n
  real(kind=jprb), intent(inout) :: a, b
  real(kind=jprb) :: c(n)
  integer :: i

  associate(d => b)
  do i=1, n
    if (i > 2) then
      c(i) = b(i) + a
    else
      c(i) = 42.0
    end if

    if (a > 0.5) then
      c(i) = 66.6
    end if
  end do

  if (c(1) > 0.5) then
    call another_routine(n, a, c(:), a2=d)
  end if

  end associate
end subroutine test_routine
"""
    Module.from_source(fcode_type, frontend=frontend, xmods=[tmp_path])
    with config_override({'frontend-store-source': True}):
        routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

    call = FindNodes(ir.CallStatement).visit(routine.body)[0]
    assoc = FindNodes(ir.Associate).visit(routine.body)[0]
    assert call.arguments == ('n', 'a', 'c(:)')
    assert call.kwarguments == (('a2', 'd'),)

    # Replace ``a`` by the associate name ``d`` throughout the body
    if use_string:
        routine.body = SubstituteStringExpressions({'a': 'd'}, scope=assoc).visit(routine.body)
    else:
        var_a = routine.variable_map['a']
        routine.body = SubstituteExpressions({var_a: parse_expr('d', scope=assoc)}).visit(routine.body)

    loops = FindNodes(ir.Loop).visit(routine.body)
    assert len(loops) == 1
    loop = loops[0]
    assert loop.variable == 'i' and loop.bounds == '1:n'
    assert loop.source.status == SourceStatus.INVALID_CHILDREN

    assigns = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns) == 3
    assert all(assign.lhs == 'c(i)' for assign in assigns)
    assert assigns[0].rhs == 'b(i) + d'
    # Only the first assignment referenced ``a``
    assert [assign.source.status for assign in assigns] == [
        SourceStatus.INVALID_NODE, SourceStatus.VALID, SourceStatus.VALID
    ]

    calls = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls) == 1
    call = calls[0]
    assert call.arguments == ('n', 'd', 'c(:)')
    assert call.kwarguments == (('a2', 'd'),)
    assert call.source.status == SourceStatus.INVALID_NODE

    conds = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conds) == 3
    assert conds[0].condition == 'i > 2'
    assert conds[1].condition == 'd > 0.5'
    assert conds[2].condition == 'c(1) > 0.5'
    assert [cond.source.status for cond in conds] == [
        SourceStatus.INVALID_CHILDREN, SourceStatus.INVALID_NODE, SourceStatus.INVALID_CHILDREN
    ]

    # Now rename the kind parameter ``jprb`` in the import and the declarations
    if use_string:
        routine.spec = SubstituteStringExpressions({'jprb': 'dbl'}, scope=assoc).visit(routine.spec)
    else:
        jprb = routine.imported_symbol_map['jprb']
        routine.spec = SubstituteExpressions({jprb: parse_expr('dbl', scope=assoc)}).visit(routine.spec)

    imports = FindNodes(ir.Import).visit(routine.spec)
    assert len(imports) == 1
    imprt = imports[0]
    assert imprt.module == 'type_mod' and imprt.symbols == ('dbl',)
    assert imprt.source.status == SourceStatus.INVALID_NODE

    # OMNI changes declarations too much for these checks
    if frontend != OMNI:
        decls = FindNodes(ir.VariableDeclaration).visit(routine.spec)
        assert len(decls) == 4
        assert decls[0].symbols == ('n',) and decls[0].symbols[0].type.intent == 'in'
        assert decls[1].symbols == ('a', 'b') and decls[1].symbols[0].type.kind == 'dbl'
        assert decls[2].symbols == ('c(n)',) and decls[2].symbols[0].type.kind == 'dbl'
        assert decls[3].symbols == ('i',) and decls[3].symbols[0].type.intent is None
        assert [decl.source.status for decl in decls] == [
            SourceStatus.VALID, SourceStatus.INVALID_NODE,
            SourceStatus.INVALID_NODE, SourceStatus.VALID
        ]
loki-ecmwf-0.3.6/loki/ir/tests/test_ir_graph.py0000664000175000017500000003074015167130205021663 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re
from pathlib import Path
import pytest

from loki import Sourcefile, graphviz_present
from loki.analyse import dataflow_analysis_attached
from loki.ir import Node, FindNodes, ir_graph, GraphCollector


@pytest.fixture(scope="module", name="here")
def fixture_here():
    """Directory that contains this test module."""
    this_module = Path(__file__)
    return this_module.parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """The ``tests`` directory two levels above this module's directory."""
    grandparent = here.parent.parent
    return grandparent / 'tests'


# Fortran sources, relative to the ``testdir`` fixture, that the graph tests
# below are parametrised over. Each entry must also be a key in
# ``solutions_default_parameters`` and ``solutions_node_edge_counts``.
test_files = [
    "sources/trivial_fortran_files/case_statement_subroutine.f90",
    "sources/trivial_fortran_files/if_else_statement_subroutine.f90",
    "sources/trivial_fortran_files/module_with_subroutines.f90",
    "sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90",
]

# Reference output of a default-configured ``GraphCollector`` (and ``ir_graph``)
# for each file in ``test_files``. Each entry holds:
# - the node and edge counts,
# - the label of every node, keyed by node id,
# - for each parent node, the ids of the nodes it points to.
solutions_default_parameters = {
    "sources/trivial_fortran_files/case_statement_subroutine.f90": {
        "node_count": 12,
        "edge_count": 11,
        # every node label is empty
        "node_labels": {str(node_id): "" for node_id in range(12)},
        "connectivity_list": {
            "0": ["1"],
            "1": ["2", "4"],
            "2": ["3"],
            "4": ["5"],
            "5": ["6", "7", "8", "9", "10", "11"],
        },
    },
    "sources/trivial_fortran_files/if_else_statement_subroutine.f90": {
        "node_count": 8,
        "edge_count": 7,
        # all labels empty except node 5
        "node_labels": {**{str(node_id): "" for node_id in range(8)}, "5": "x > 0.0"},
        "connectivity_list": {
            "0": ["1"],
            "1": ["2", "4"],
            "2": ["3"],
            "4": ["5"],
            "5": ["6", "7"],
        },
    },
    "sources/trivial_fortran_files/module_with_subroutines.f90": {
        "node_count": 24,
        "edge_count": 23,
        # every node label is empty
        "node_labels": {str(node_id): "" for node_id in range(24)},
        "connectivity_list": {
            "0": ["1"],
            "1": ["2", "4"],
            "2": ["3"],
            "4": ["5", "6", "12", "18"],
            "6": ["7", "10"],
            "7": ["8", "9"],
            "10": ["11"],
            "12": ["13", "16"],
            "13": ["14", "15"],
            "16": ["17"],
            "18": ["19", "22"],
            "19": ["20", "21"],
            "22": ["23"],
        },
    },
    "sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90": {
        "node_count": 12,
        "edge_count": 11,
        # all labels empty except nodes 5, 6 and 9
        "node_labels": {
            **{str(node_id): "" for node_id in range(12)},
            "5": "x > 0",
            "6": "y > 0",
            "9": "y > 0",
        },
        "connectivity_list": {
            "0": ["1"],
            "1": ["2", "4"],
            "2": ["3"],
            "4": ["5"],
            "5": ["6", "9"],
            "6": ["7", "8"],
            "9": ["10", "11"],
        },
    },
}

# Node and edge counts produced by ``GraphCollector`` for every combination of
# its options, indexed as ``[show_comments][show_expressions]`` (``False`` -> 0,
# ``True`` -> 1). In every configuration the edge count is exactly one less
# than the matching node count, so edge counts are derived below.
solutions_node_edge_counts = {
    path: {
        "node_count": node_counts,
        "edge_count": [[count - 1 for count in row] for row in node_counts],
    }
    for path, node_counts in {
        "sources/trivial_fortran_files/case_statement_subroutine.f90": [[12, 19], [14, 21]],
        "sources/trivial_fortran_files/if_else_statement_subroutine.f90": [[8, 9], [10, 11]],
        "sources/trivial_fortran_files/module_with_subroutines.f90": [[24, 39], [32, 47]],
        "sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90": [[12, 14], [14, 16]],
    }.items()
}


def get_property(node_edge_info, name):
    """
    Yield ``(node_value, edge_value)`` for ``name`` from each collected entry.

    ``node_edge_info`` is an iterable of ``(node_info, edge_info)`` pairs. For
    each pair, ``name`` is looked up in both mappings. A tuple is yielded with
    ``None`` in place of whichever side lacks the key.

    If neither side has the key, the pair is skipped when at least one side is
    empty. If both sides are non-empty, a :class:`KeyError` is raised instead.
    """
    for node_info, edge_info in node_edge_info:
        has_node = name in node_info
        has_edge = name in edge_info
        if not (has_node or has_edge):
            if node_info and edge_info:
                raise KeyError(f"Keyword {name} not found!")
            continue
        yield (
            node_info[name] if has_node else None,
            edge_info[name] if has_edge else None,
        )


@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed")
@pytest.mark.parametrize("test_file", test_files)
@pytest.mark.parametrize("show_comments", [True, False])
@pytest.mark.parametrize("show_expressions", [True, False])
def test_graph_collector_node_edge_count_only(
    testdir, test_file, show_comments, show_expressions, tmp_path
):
    """
    Check the number of nodes and edges collected by :any:`GraphCollector`
    for every combination of ``show_comments`` and ``show_expressions``.
    """
    expected = solutions_node_edge_counts[test_file]
    source = Sourcefile.from_file(testdir / test_file, xmods=[tmp_path])

    collector = GraphCollector(show_comments=show_comments, show_expressions=show_expressions)
    collected = [entry for entry in collector.visit(source.ir) if entry is not None]

    names = [node for node, _ in get_property(collected, "name")]
    labels = [node for node, _ in get_property(collected, "label")]
    expected_nodes = expected["node_count"][show_comments][show_expressions]
    assert len(names) == len(labels) == expected_nodes

    heads = [edge for _, edge in get_property(collected, "head_name")]
    tails = [edge for _, edge in get_property(collected, "tail_name")]
    expected_edges = expected["edge_count"][show_comments][show_expressions]
    assert len(heads) == len(tails) == expected_edges


@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed")
@pytest.mark.parametrize("test_file", test_files)
def test_graph_collector_detail(testdir, test_file, tmp_path):
    """
    Compare the node labels and edge connectivity collected by a default
    :any:`GraphCollector` against the reference solution.
    """
    expected = solutions_default_parameters[test_file]
    source = Sourcefile.from_file(testdir / test_file, xmods=[tmp_path])

    collected = [entry for entry in GraphCollector().visit(source.ir) if entry is not None]

    names = [node for node, _ in get_property(collected, "name")]
    labels = [node for node, _ in get_property(collected, "label")]
    assert len(names) == len(labels) == expected["node_count"]

    expected_labels = expected["node_labels"]
    assert all(expected_labels[name] == label for name, label in zip(names, labels))

    heads = [edge for _, edge in get_property(collected, "head_name")]
    tails = [edge for _, edge in get_property(collected, "tail_name")]
    assert len(heads) == len(tails) == expected["edge_count"]

    connectivity = expected["connectivity_list"]
    assert all(head in connectivity[tail] for head, tail in zip(heads, tails))


@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed")
@pytest.mark.parametrize("test_file", test_files)
@pytest.mark.parametrize("linewidth", [40, 60, 80])
def test_graph_collector_maximum_label_length(testdir, test_file, linewidth, tmp_path):
    """
    Check that no label produced by :any:`GraphCollector` is longer than the
    requested ``linewidth``, even with comments and expressions shown.
    """
    source = Sourcefile.from_file(testdir / test_file, xmods=[tmp_path])

    collector = GraphCollector(show_comments=True, show_expressions=True, linewidth=linewidth)
    collected = [entry for entry in collector.visit(source.ir) if entry is not None]

    # An empty label list trivially satisfies the bound
    longest = max((len(node) for node, _ in get_property(collected, "label")), default=0)
    assert longest <= linewidth


def find_edges(input_text):
    """Return ``(tail_id, head_id)`` string pairs for every ``a -> b`` edge in DOT text."""
    edge_regex = re.compile(r"(\d+)\s*->\s*(\d+)", re.IGNORECASE)
    return edge_regex.findall(input_text)


def find_nodes(input_text):
    """Return every ``<id> [ ... ]`` node statement found in DOT text."""
    node_regex = re.compile(r'\d+ *\[[^\[\]]*(?:"[^"]*"[^\[\]]*)*\]', re.IGNORECASE)
    return node_regex.findall(input_text)


def find_node_id_inside_nodes(input_text):
    """Return the numeric ids that are followed by whitespace and an opening ``[``."""
    id_regex = re.compile(r"(\d+)\s+\[", re.IGNORECASE)
    return id_regex.findall(input_text)


def find_label_content_inside_nodes(input_text):
    """Return the contents of every ``label="..."`` attribute (case-insensitive)."""
    label_regex = re.compile(r'label="([^"]*"|\'[^\']*\'|[^\'"]*)"', re.IGNORECASE)
    return label_regex.findall(input_text)


@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed")
@pytest.mark.parametrize("test_file", test_files)
def test_ir_graph_writes_correct_graphs(testdir, test_file, tmp_path):
    """
    Check the DOT text produced by :any:`ir_graph` against the reference solution.
    Edge connectivity and node/edge counts must match. Every node must have
    exactly one id and one label, and the label must be the expected one.
    """
    expected = solutions_default_parameters[test_file]
    source = Sourcefile.from_file(testdir / test_file, xmods=[tmp_path])

    dot_text = str(ir_graph(source.ir))

    edges = find_edges(dot_text)
    connectivity = expected["connectivity_list"]
    for tail, head in edges:
        assert head in connectivity[tail]

    nodes = find_nodes(dot_text)
    assert len(edges) == expected["edge_count"]
    assert len(nodes) == expected["node_count"]

    node_ids = [find_node_id_inside_nodes(node) for node in nodes]
    assert all(len(ids) == 1 for ids in node_ids)

    node_labels = [find_label_content_inside_nodes(node) for node in nodes]
    assert all(len(labels) == 1 for labels in node_labels)

    assert len(node_labels) == len(node_ids)

    # Each inner list holds exactly one element (checked above)
    for (node_id,), (label,) in zip(node_ids, node_labels):
        assert expected["node_labels"][node_id] == label


@pytest.mark.parametrize("test_file", test_files)
def test_ir_graph_dataflow_analysis_attached(testdir, test_file, tmp_path):
    """
    With dataflow analysis attached, check that the ``live``/``defines``/``uses``
    sets rendered in each node's graph label match the node's own symbol sets.
    """
    source = Sourcefile.from_file(testdir / test_file, xmods=[tmp_path])

    def find_lives_defines_uses(text):
        # Capture the bracketed lists after 'live:', 'defines:' and 'uses:'
        pattern = r"live:\s*\[([^\]]*?)\],\s*defines:\s*\[([^\]]*?)\],\s*uses:\s*\[([^\]]*?)\]"
        matches = re.search(pattern, text)
        assert matches

        def to_name_set(group_text):
            # Drop spaces and newlines, split on commas and discard empty entries
            compact = group_text.replace(" ", "").replace("\n", "")
            return {entry for entry in compact.split(",") if entry != ""}

        return tuple(to_name_set(matches.group(idx)) for idx in (1, 2, 3))

    for routine in source.all_subroutines:
        with dataflow_analysis_attached(routine):
            for node in FindNodes(Node).visit(routine.body):
                node_info, _ = GraphCollector(show_comments=True).visit(node)[0]
                lives, defines, uses = find_lives_defines_uses(node_info["label"])
                assert node.live_symbols == set(lives)
                assert node.uses_symbols == set(uses)
                assert node.defines_symbols == set(defines)
loki-ecmwf-0.3.6/loki/ir/tests/test_visitor.py0000664000175000017500000003703415167130205021572 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
from pymbolic.primitives import Expression

from loki import Module, Subroutine
from loki.expression import (
    symbols as sym, ExpressionCallbackMapper, ExpressionRetriever
)
from loki.frontend import available_frontends, OMNI
from loki.ir import (
    nodes as ir, is_parent_of, is_child_of, FindNodes, FindScopes,
    FindVariables, ExpressionFinder
)
from loki.tools import OrderedSet


@pytest.mark.parametrize('frontend', available_frontends())
def test_find_nodes_greedy(frontend):
    """
    Check ``FindNodes(..., greedy=True)``. For these nested conditionals it
    must return only the outer one, whereas the default search returns both.
    """
    fcode = """
subroutine routine_find_nodes_greedy(n, m)
  integer, intent(in) :: n, m

  if (n > m) then
    if (n == 3) then
      print *,"Inner if"
    endif
    print *,"Outer if"
  endif
end subroutine routine_find_nodes_greedy
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)
    body = routine.body

    every_cond = FindNodes(ir.Conditional).visit(body)
    assert len(every_cond) == 2

    greedy_hits = FindNodes(ir.Conditional, greedy=True).visit(body)
    assert len(greedy_hits) == 1
    outermost = greedy_hits[0]
    assert outermost in every_cond
    assert str(outermost.condition) == 'n > m'


@pytest.mark.parametrize('frontend', available_frontends())
def test_find_scopes(frontend):
    """
    Test the FindScopes visitor.

    For a target node, :any:`FindScopes` is expected to return the chain of
    enclosing nodes. The chain runs from the outermost scope (the routine
    body) down to the target node itself.
    """
    fcode = """
subroutine routine_find_nodes_greedy(n, m)
  integer, intent(in) :: n, m

  if (n > m) then
    if (n == 3) then
      print *,"Inner if"
    endif
    print *,"Outer if"
  endif
end subroutine routine_find_nodes_greedy
""".strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    intrinsics = FindNodes(ir.Intrinsic).visit(routine.body)
    assert len(intrinsics) == 2
    inner = [i for i in intrinsics if 'Inner' in i.text][0]
    outer = [i for i in intrinsics if 'Outer' in i.text][0]

    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conditionals) == 2

    scopes = FindScopes(inner).visit(routine.body)
    assert len(scopes) == 1  # returns a list containing a list of nested nodes
    assert len(scopes[0]) == 4  # should have found 3 scopes and the node itself
    assert all(c in scopes[0] for c in conditionals)  # should have found all if
    assert routine.body is scopes[0][0]  # body section should be outermost scope
    assert str(scopes[0][1].condition) == 'n > m'  # outer if should come first
    assert inner is scopes[0][-1]  # node itself should be last in list

    scopes = FindScopes(outer).visit(routine.body)
    assert len(scopes) == 1  # returns a list containing a list of nested nodes
    assert len(scopes[0]) == 3  # should have found 2 scopes and the node itself
    # Every conditional is either in the scope chain or is the inner if.
    # Compare the stringified condition itself: stringifying the result of the
    # comparison would always yield a non-empty (truthy) string.
    assert all(c in scopes[0] or str(c.condition) == 'n == 3'
               for c in conditionals)
    # ... and the outer if is the only conditional in the scope chain
    assert [str(c.condition) for c in scopes[0] if isinstance(c, ir.Conditional)] == ['n > m']
    assert routine.body is scopes[0][0]  # body section should be outermost scope
    assert outer is scopes[0][-1]  # node itself should be last in list


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_finder(frontend):
    """
    With ``unique=False``, check that :any:`FindVariables` reports every
    variable occurrence, including repeats and array subscripts.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    found = FindVariables(unique=False).visit(routine.body)
    assert len(found) == 12
    assert all(isinstance(var, Expression) for var in found)

    expected = ['i'] * 6 + ['matrix(i, :)', 'scalar'] + ['vector(i)'] * 3 + ['x']
    assert sorted(map(str, found)) == expected


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_finder_unique(frontend):
    """
    With the default ``unique=True``, check that :any:`FindVariables` returns
    an :any:`OrderedSet` holding each distinct variable once.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    found = FindVariables().visit(routine.body)
    assert isinstance(found, OrderedSet)
    assert len(found) == 5
    assert all(isinstance(var, Expression) for var in found)

    assert sorted(map(str, found)) == ['i', 'matrix(i, :)', 'scalar', 'vector(i)', 'x']


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_finder_with_ir_node(frontend):
    """
    With ``with_ir_node=True`` and ``unique=False``, check that
    :any:`FindVariables` returns ``(node, variables)`` pairs. Each pair lists
    every variable occurrence of that IR node.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    found = FindVariables(unique=False, with_ir_node=True).visit(routine.body)
    assert len(found) == 3
    assert all(isinstance(entry, tuple) and len(entry) == 2 for entry in found)

    # The loop header contributes its variable and bound
    loop_vars = [names for node, names in found if isinstance(node, ir.Loop)]
    assert len(loop_vars) == 1
    assert sorted(map(str, loop_vars[0])) == ['i', 'x']

    # Each assignment contributes all of its variable occurrences
    stmt_vars = [names for node, names in found if isinstance(node, ir.Assignment)]
    assert len(stmt_vars) == 2

    assert sorted(map(str, stmt_vars[0])) == ['i', 'i', 'scalar', 'vector(i)', 'vector(i)']
    assert sorted(map(str, stmt_vars[1])) == ['i', 'i', 'i', 'matrix(i, :)', 'vector(i)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_finder_unique_with_ir_node(frontend):
    """
    With ``with_ir_node=True`` and the default ``unique=True``, check that
    :any:`FindVariables` returns ``(node, variables)`` pairs. Repeated
    variables are collapsed within each node.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i

  do i=1, x
     vector(i) = vector(i) + scalar
     matrix(i, :) = i * vector(i)
  end do
end subroutine routine_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    found = FindVariables(with_ir_node=True).visit(routine.body)
    assert len(found) == 3
    assert all(isinstance(entry, tuple) and len(entry) == 2 for entry in found)

    # The loop header contributes its variable and bound
    loop_vars = [names for node, names in found if isinstance(node, ir.Loop)]
    assert len(loop_vars) == 1
    assert sorted(map(str, loop_vars[0])) == ['i', 'x']

    # Each assignment contributes its distinct variables only
    stmt_vars = [names for node, names in found if isinstance(node, ir.Assignment)]
    assert len(stmt_vars) == 2

    assert sorted(map(str, stmt_vars[0])) == ['i', 'scalar', 'vector(i)']
    assert sorted(map(str, stmt_vars[1])) == ['i', 'matrix(i, :)', 'vector(i)']


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_callback_mapper(frontend):
    """
    Basic sanity check combining :any:`ExpressionFinder` with a custom
    :any:`ExpressionCallbackMapper`. It does not cover every angle.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i, j

  do i=1, x
    vector(i) = vector(i) + scalar
    do j=1, y
      if (j > i) then
        matrix(i, j) = real(i * j, kind=jprb) + 1.
      else
        matrix(i, j) = i * vector(j)
      end if
    end do
  end do
end subroutine routine_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Toy callback: keep arrays whose declared shape has two dimensions
    def is_matrix(expr, *args, **kwargs):  # pylint: disable=unused-argument
        if not isinstance(expr, sym.Array):
            return None
        shape = expr.type.shape
        return expr if shape and len(shape) == 2 else None

    class FindMatrices(ExpressionFinder):
        retriever = ExpressionCallbackMapper(
            callback=is_matrix,
            combine=lambda results: tuple(res for res in results if res is not None)
        )

    all_hits = FindMatrices(unique=False).visit(routine.body)
    assert len(all_hits) == 2

    unique_hits = FindMatrices().visit(routine.body)
    assert len(unique_hits) == 1
    assert str(unique_hits.pop()) == 'matrix(i, j)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_retriever_recurse_query(frontend):
    """
    Check :any:`ExpressionRetriever` with a custom ``recurse_query`` that stops
    recursion early, so literals in array subscripts and loop ranges are skipped.
    """
    fcode = """
subroutine routine_simple (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i, j

  do i=1, x
    vector(i) = vector(i) + scalar
    do j=1, y
      if (j > i) then
        matrix(i, j) = real(i * j + 2, kind=jprb) + 1.
      else
        matrix(i, j) = i * vector(j)
      end if
    end do
  end do
end subroutine routine_simple
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    def is_literal(expr):
        return isinstance(expr, (sym.IntLiteral, sym.FloatLiteral, sym.LogicLiteral))

    def outside_subscripts_and_ranges(expr, *args, **kwargs):  # pylint: disable=unused-argument
        return not isinstance(expr, (sym.ArraySubscript, sym.LoopRange))

    class FindLiteralsNotInSubscriptsOrRanges(ExpressionFinder):
        retriever = ExpressionRetriever(
            query=is_literal,
            recurse_query=outside_subscripts_and_ranges
        )

    found = FindLiteralsNotInSubscriptsOrRanges(unique=False).visit(routine.body)

    # OMNI substitutes jprb, which exposes the selected_real_kind arguments
    expected = ['1.', '13', '2', '300'] if frontend == OMNI else ['1.', '2']
    assert len(found) == len(expected)
    assert sorted(str(lit) for lit in found) == expected


@pytest.mark.parametrize('frontend', available_frontends())
def test_find_variables_associates(frontend):
    """
    Test correct discovery of variables in associates.

    Both the associate names defined in the ``associate`` statement and their
    uses inside the block must be reported by :any:`FindVariables`.
    """
    fcode = """
subroutine find_variables_associates (x, y, scalar, vector, matrix)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: x, y
  real(kind=jprb), intent(in) :: scalar
  real(kind=jprb), intent(inout) :: vector(x), matrix(x, y)
  integer :: i, j

  do i=1, x
    associate (v => vector(i), m => matrix(i, :))
      vector(i) = vector(i) + scalar
      do j=1, y
        if (j > i) then
          m(j) = real(i * j, kind=jprb) + 1.
        else
          matrix(i, j) = i * vector(j)
        end if
      end do
    end associate
  end do
end subroutine find_variables_associates
"""
    # Test the internals of the subroutine
    routine = Subroutine.from_source(fcode, frontend=frontend)

    variables = FindVariables(unique=False).visit(routine.body)
    # The conditional expression must be parenthesised. Otherwise it binds
    # looser than ``==``, and for non-OMNI frontends the statement reduces to
    # ``assert 28``, which can never fail. The expected count differs by one
    # between frontends, presumably because OMNI substitutes ``jprb``.
    assert len(variables) == (27 if frontend == OMNI else 28)
    # ``v`` is only defined in the associate statement
    assert len([v for v in variables if v.name == 'v']) == 1
    # ``m`` is defined in the associate statement and used once as ``m(j)``
    assert len([v for v in variables if v.name == 'm']) == 2


@pytest.mark.parametrize('frontend', available_frontends())
def test_is_parent_of(frontend):
    """
    Check that ``is_parent_of`` recognises nesting of IR nodes: the loop
    contains the conditional, which in turn contains only one of the two
    assignments in the routine.
    """
    fcode = """
subroutine test_is_parent_of
  implicit none
  integer :: a, j, n=10

  a = 0
  do j=1,n
    if (j > 3) then
      a = a + 1
    end if
  end do
end subroutine test_is_parent_of
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    body = routine.body

    loop = FindNodes(ir.Loop).visit(body)[0]
    conditional = FindNodes(ir.Conditional).visit(body)[0]
    assignments = FindNodes(ir.Assignment).visit(body)

    # Nesting is directional: loop -> conditional, never the reverse
    assert is_parent_of(loop, conditional)
    assert not is_parent_of(conditional, loop)

    for container in (loop, conditional):
        # ``a = 0`` sits outside both containers, ``a = a + 1`` inside both
        outcomes = {is_parent_of(container, assign) for assign in assignments}
        assert outcomes == {True, False}
        # A leaf assignment can never be the parent of a container
        assert not any(is_parent_of(assign, container) for assign in assignments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_is_child_of(frontend):
    """
    Check that ``is_child_of`` is the mirror of ``is_parent_of``: the
    conditional is a child of the loop, and only one assignment is a child
    of either container.
    """
    fcode = """
subroutine test_is_child_of
  implicit none
  integer :: a, j, n=10

  a = 0
  do j=1,n
    if (j > 3) then
      a = a + 1
    end if
  end do
end subroutine test_is_child_of
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)
    body = routine.body

    loop = FindNodes(ir.Loop).visit(body)[0]
    conditional = FindNodes(ir.Conditional).visit(body)[0]
    assignments = FindNodes(ir.Assignment).visit(body)

    # The loop is not nested in the conditional, but the conditional is in the loop
    assert not is_child_of(loop, conditional)
    assert is_child_of(conditional, loop)

    for container in (loop, conditional):
        # Exactly one of ``a = 0`` / ``a = a + 1`` lives inside ``container``
        outcomes = {is_child_of(assign, container) for assign in assignments}
        assert outcomes == {True, False}
        # Containers are never children of a leaf assignment
        assert not any(is_child_of(container, assign) for assign in assignments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_attach_scopes_associates(frontend, tmp_path):
    """
    Verify that symbols used inside nested ``associate`` blocks are attached
    to the correct scope: the routine for plain locals, and the outer
    associate for the associated derived-type member and its parent.
    """
    fcode = """
module attach_scopes_associates_mod
    implicit none

    type other_type
        integer :: foo
    end type other_type

    type some_type
        type(other_type) :: var
    end type some_type

contains

    subroutine attach_scopes_associates
        type(some_type) :: blah
        integer :: a

        associate(var=>blah%var)
            associate(bar=>5+3)
                a = var%foo
            end associate
        end associate
    end subroutine attach_scopes_associates
end module attach_scopes_associates_mod
    """.strip()

    module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
    routine = module['attach_scopes_associates']

    outer, inner = associates = FindNodes(ir.Associate).visit(routine.body) * 1 or (None, None)
    assert len(associates) == 2
    assignments = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assignments) == 1

    found = FindVariables().visit(assignments)
    assert len(found) == 3
    by_name = {str(var): var for var in found}
    assert len(by_name) == 3

    # The inner associate is nested in the outer one
    assert inner.parent is outer
    # ``a`` is a routine local; ``var%foo`` and its parent ``var`` resolve via the outer associate
    assert by_name['a'].scope is routine
    assert by_name['var%foo'].scope is outer
    assert by_name['var%foo'].parent.scope is outer
    assert by_name['var%foo'].parent is by_name['var']
loki-ecmwf-0.3.6/loki/ir/tests/test_ir_nodes.py0000664000175000017500000004033615167130205021674 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from dataclasses import FrozenInstanceError

import pytest
from pymbolic.primitives import Expression
from pydantic import ValidationError

from loki.expression import symbols as sym, parse_expr
from loki.function import Function
from loki.ir import nodes as ir
from loki.types import Scope


@pytest.fixture(name='scope')
def fixture_scope():
    """Provide a fresh, empty :any:`Scope` to attach symbols to."""
    empty_scope = Scope()
    return empty_scope

@pytest.fixture(name='one')
def fixture_one():
    """Provide the integer literal ``1``."""
    literal_one = sym.Literal(1)
    return literal_one

@pytest.fixture(name='i')
def fixture_i(scope):
    """Provide the scalar symbol ``i`` in the shared test scope."""
    scalar_i = sym.Scalar('i', scope=scope)
    return scalar_i

@pytest.fixture(name='n')
def fixture_n(scope):
    """Provide the scalar symbol ``n`` in the shared test scope."""
    scalar_n = sym.Scalar('n', scope=scope)
    return scalar_n

@pytest.fixture(name='a_i')
def fixture_a_i(scope, i):
    """Provide the array access ``a(i)``."""
    array_a_i = sym.Array('a', dimensions=(i,), scope=scope)
    return array_a_i

@pytest.fixture(name='a_n')
def fixture_a_n(scope, n):
    """Provide the array reference ``a(n)``."""
    array_a_n = sym.Array('a', dimensions=(n,), scope=scope)
    return array_a_n


def test_assignment(scope, a_i):
    """
    Check construction, immutability and input validation of :any:`Assignment`.
    """
    assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
    for operand in (assign.lhs, assign.rhs):
        assert isinstance(operand, Expression)
    assert assign.comment is None

    # Node objects are frozen: neither side may be re-assigned
    for field in ('lhs', 'rhs'):
        with pytest.raises(FrozenInstanceError):
            setattr(assign, field, sym.Scalar('b', scope=scope))

    # Plain strings are rejected for expressions, and expressions for comments
    invalid_kwargs = (
        {'lhs': 'a', 'rhs': sym.Literal(42.0)},
        {'lhs': a_i, 'rhs': '42.0 + 6.0'},
        {'lhs': a_i, 'rhs': sym.Literal(42.0), 'comment': a_i},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValidationError):
            ir.Assignment(**kwargs)


def test_loop(scope, one, i, n, a_i):
    """
    Check construction, immutability, body normalisation and input
    validation of :any:`Loop`.
    """
    assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
    bounds = sym.Range((one, n))

    loop = ir.Loop(variable=i, bounds=bounds, body=(assign,))
    assert isinstance(loop.variable, Expression)
    assert isinstance(loop.bounds, Expression)
    assert isinstance(loop.body, tuple)
    assert all(isinstance(node, ir.Node) for node in loop.body)
    assert loop.children == (i, bounds, (assign,))

    # Node objects are frozen: none of the fields may be re-assigned
    frozen_updates = (
        ('variable', sym.Scalar('j', scope=scope)),
        ('bounds', sym.Range((n, sym.Scalar('k', scope=scope)))),
        ('body', (assign, assign, assign)),
    )
    for field, value in frozen_updates:
        with pytest.raises(FrozenInstanceError):
            setattr(loop, field, value)

    # The body is flattened into a tuple and ``None`` entries are dropped
    keyword_cases = (
        (assign, (assign,)),
        (((assign,),), (assign,)),
        ((assign, (assign,), assign, None), (assign, assign, assign)),
    )
    for body, expected in keyword_cases:
        assert ir.Loop(variable=i, bounds=bounds, body=body).body == expected

    # The same normalisation applies to positional constructor arguments
    assert ir.Loop(i, bounds, assign).body == (assign,)
    assert ir.Loop(i, bounds, [(assign,), None, assign]).body == (assign, assign)

    # A non-node body and missing variable/bounds are rejected
    invalid_kwargs = (
        {'variable': i, 'bounds': bounds, 'body': n},
        {'variable': None, 'bounds': bounds, 'body': (assign,)},
        {'variable': i, 'bounds': None, 'body': (assign,)},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValidationError):
            ir.Loop(**kwargs)

    # TODO: Test pragmas, names and labels


def test_conditional(scope, n, a_i):
    """
    Test constructors of :any:`Conditional`.

    Covers field types, ``children``, frozen status and the normalisation of
    ``body``/``else_body`` into flat tuples (nested tuples are flattened and
    ``None`` entries are dropped), for both keyword and positional arguments.
    """
    assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
    condition = parse_expr('i >= 2', scope=scope)

    cond = ir.Conditional(
        condition=condition, body=(assign,assign,), else_body=(assign,)
    )
    assert isinstance(cond.condition, Expression)
    assert isinstance(cond.body, tuple) and len(cond.body) == 2
    assert all(isinstance(node, ir.Node) for node in cond.body)
    assert isinstance(cond.else_body, tuple) and len(cond.else_body) == 1
    assert all(isinstance(node, ir.Node) for node in cond.else_body)
    assert cond.children == ( condition, (assign, assign), (assign,) )

    # Ensure "frozen" status of node objects
    with pytest.raises(FrozenInstanceError):
        cond.condition = parse_expr('k == 0', scope=scope)
    with pytest.raises(FrozenInstanceError):
        cond.body = (assign, assign, assign)
    with pytest.raises(FrozenInstanceError):
        cond.else_body = (assign, assign, assign)

    # Test auto-casting of the body / else_body to tuple
    cond = ir.Conditional(condition=condition, body=assign)
    assert cond.body == (assign,) and not cond.else_body
    cond = ir.Conditional(condition=condition, body=( (assign,), ))
    assert cond.body == (assign,) and not cond.else_body
    cond = ir.Conditional(condition=condition, body=( assign, (assign,), assign, None))
    assert cond.body == (assign, assign, assign) and not cond.else_body

    cond = ir.Conditional(condition=condition, body=(), else_body=assign)
    assert not cond.body and cond.else_body == (assign,)
    cond = ir.Conditional(condition=condition, body=(), else_body=( (assign,), ))
    assert not cond.body and cond.else_body == (assign,)
    cond = ir.Conditional(
        condition=condition, body=(), else_body=( assign, (assign,), assign, None)
    )
    assert not cond.body and cond.else_body == (assign, assign, assign)

    # Test auto-casting with unnamed constructor args
    cond = ir.Conditional(condition)
    # Compare by equality: ``is ()`` relies on CPython interning the empty
    # tuple and triggers a ``SyntaxWarning`` ("is" with a literal) on Python >= 3.8
    assert cond.body == () and not cond.else_body
    cond = ir.Conditional(condition, assign)
    assert cond.body == (assign,) and not cond.else_body
    cond = ir.Conditional(condition, body=[assign, (assign,)], else_body=[assign, None, (assign,)])
    assert cond.body == (assign, assign) and cond.else_body == (assign, assign)

    # TODO: Test inline, name, has_elseif


def test_conditional_nested(i, a_i):
    """
    Build multi-branch ``if``/``else if`` chains by nesting :any:`Conditional`
    nodes and verify that ``else_bodies`` unrolls them, with and without a
    trailing ``else``.
    """
    def _build_chain(final_else=None):
        # Innermost branch tests ``i == 1`` and optionally carries a terminal else-body
        innermost = {
            'condition': sym.Comparison(i, '==', sym.IntLiteral(1)),
            'body': ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)),
        }
        if final_else is not None:
            innermost['else_body'] = final_else
        chain = ir.Conditional(**innermost)
        # Wrap twice, so that the outermost branch tests ``i == 3``
        for idx in range(2, 4):
            chain = ir.Conditional(
                condition=sym.Comparison(i, '==', sym.IntLiteral(idx)),
                body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))),
                else_body=chain, has_elseif=True
            )
        return chain

    def _check_else_bodies(chain, expected_rhs):
        else_bodies = chain.else_bodies
        assert len(else_bodies) == len(expected_rhs)
        assert all(isinstance(b, tuple) for b in else_bodies)
        for body, rhs in zip(else_bodies, expected_rhs):
            assert isinstance(body[0], ir.Assignment)
            assert body[0].lhs == 'a(i)' and body[0].rhs == rhs

    # With a final else: bodies for ``i == 2``, ``i == 1`` and the else-branch
    _check_else_bodies(
        _build_chain(ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))),
        ('2.0', '1.0', '42.0')
    )

    # Now try without the final else
    _check_else_bodies(_build_chain(), ('2.0', '1.0'))


def test_section(n, a_n, a_i):
    """
    Test constructors and behaviour of :any:`Section` nodes.

    Covers construction from IR nodes and from program units, frozen status
    of ``body``, flattening of nested/``None`` body entries, and the in-place
    ``prepend``/``insert``/``append`` helpers.
    """
    assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
    decl = ir.VariableDeclaration(symbols=(a_n,))
    func = Function(name='F', spec=(decl,), body=(assign,))

    # Test constructor for nodes and subroutine objects
    sec = ir.Section(body=(assign, assign))
    assert isinstance(sec.body, tuple) and len(sec.body) == 2
    assert all(isinstance(node, ir.Node) for node in sec.body)
    with pytest.raises(FrozenInstanceError):
        sec.body = (assign, assign)

    sec = ir.Section(body=(func, func))
    assert isinstance(sec.body, tuple) and len(sec.body) == 2
    assert all(isinstance(node, Scope) for node in sec.body)
    with pytest.raises(FrozenInstanceError):
        sec.body = (func, func)

    sec = ir.Section((assign, assign))
    assert sec.body == (assign, assign)

    # Test auto-casting of the body to tuple
    sec = ir.Section(body=assign)
    assert sec.body == (assign,)
    sec = ir.Section(body=( (assign,), ))
    assert sec.body == (assign,)
    sec = ir.Section(body=( assign, (assign,), assign, None))
    assert sec.body == (assign, assign, assign)
    sec = ir.Section((assign, (func,), assign, None))
    assert sec.body == (assign, func, assign)

    # Test auto-casting with unnamed constructor args
    sec = ir.Section()
    # Compare by equality: ``is ()`` relies on CPython interning the empty
    # tuple and triggers a ``SyntaxWarning`` ("is" with a literal) on Python >= 3.8
    assert sec.body == ()
    sec = ir.Section([(assign,), assign, None, assign])
    assert sec.body == (assign, assign, assign)

    # Test prepend/insert/append additions
    sec = ir.Section(body=func)
    assert sec.body == (func,)
    sec.prepend(assign)
    assert sec.body == (assign, func)
    sec.append((assign, assign))
    assert sec.body == (assign, func, assign, assign)
    sec.insert(pos=3, node=func)
    assert sec.body == (assign, func, assign, func, assign)


def test_callstatement(scope, one, i, n, a_i):
    """
    Test constructor of :any:`CallStatement` nodes.

    Covers field types, frozen status, normalisation of ``arguments`` and
    ``kwarguments`` into (nested) tuples, and rejection of malformed input.
    """

    cname = sym.ProcedureSymbol(name='test', scope=scope)
    call = ir.CallStatement(
        name=cname, arguments=(n, a_i), kwarguments=(('i', i), ('j', one))
    )
    assert isinstance(call.name, Expression)
    assert isinstance(call.arguments, tuple)
    assert all(isinstance(e, Expression) for e in call.arguments)
    assert isinstance(call.kwarguments, tuple)
    assert all(isinstance(e, tuple) for e in call.kwarguments)
    assert all(
        isinstance(k, str) and isinstance(v, Expression)
        for k, v in call.kwarguments
    )

    # Ensure "frozen" status of node objects
    with pytest.raises(FrozenInstanceError):
        call.name = sym.ProcedureSymbol('dave', scope=scope)
    with pytest.raises(FrozenInstanceError):
        call.arguments = (a_i, n, one)
    with pytest.raises(FrozenInstanceError):
        call.kwarguments = (('i', one), ('j', i))

    # Test auto-casting of arguments / kwarguments to tuples (``None`` -> empty)
    call = ir.CallStatement(name=cname, arguments=[a_i, one])
    assert call.arguments == (a_i, one) and not call.kwarguments
    call = ir.CallStatement(name=cname, arguments=None)
    assert not call.arguments and not call.kwarguments
    call = ir.CallStatement(name=cname, kwarguments=[('i', i), ('j', one)])
    assert not call.arguments and call.kwarguments == (('i', i), ('j', one))
    call = ir.CallStatement(name=cname, kwarguments=None)
    assert not call.arguments and not call.kwarguments

    # Test auto-casting with unnamed constructor args
    call = ir.CallStatement(cname, a_i)
    assert call.arguments == (a_i,) and not call.kwarguments
    call = ir.CallStatement(cname, [a_i, one], [('i', i), ('j', one)])
    assert call.arguments == (a_i, one) and call.kwarguments == (('i', i), ('j', one))

    # Test errors for wrong constructor usage
    with pytest.raises(ValidationError):
        ir.CallStatement(name='a', arguments=(sym.Literal(42.0),))
    with pytest.raises(ValidationError):
        ir.CallStatement(name=cname, arguments=('a',))
    # Keyword arguments must be (name, expression) pairs. This previously
    # constructed an ``Assignment`` by mistake, which always raised because
    # ``lhs``/``rhs`` were missing, so it never exercised ``CallStatement``.
    with pytest.raises(ValidationError):
        ir.CallStatement(
            name=cname, arguments=(sym.Literal(42.0),), kwarguments=('i', 'i')
        )

    # TODO: Test pragmas, active and chevron


def test_associate(scope, a_i):
    """
    Test constructors and scoping behaviour of :any:`Associate`.

    Checks the shape of ``associations`` and ``body``, the string- and
    expression-keyed lookup maps, and that ``rescope_symbols`` re-attaches
    only the associated name to the associate's scope.
    """
    b = sym.Scalar(name='b', scope=scope)
    b_a = sym.Array(name='a', parent=b, scope=scope)
    a = sym.Array(name='a', scope=scope)
    assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
    assign2 = ir.Assignment(lhs=a_i.clone(parent=b), rhs=sym.Literal(66.6))

    assoc = ir.Associate(associations=((b_a, a),), body=(assign, assign2), parent=scope)  # pylint: disable=unexpected-keyword-arg
    assert isinstance(assoc.associations, tuple)
    assert all(isinstance(pair, tuple) and len(pair) == 2 for pair in assoc.associations)
    assert isinstance(assoc.body, tuple)
    assert all(isinstance(node, ir.Node) for node in assoc.body)

    # TODO: Check constructor failures, auto-casting and frozen status

    # Both maps accept either a string or the symbol itself as key
    for key in ('B%a', b_a):
        assert key in assoc.association_map
        assert assoc.association_map[key] is a
    for key in ('a', a):
        assert key in assoc.inverse_map
        assert assoc.inverse_map[key] is b_a

    # Before rescoping both left-hand sides live in the outer scope
    assert assign.lhs.scope is scope and assign2.lhs.scope is scope
    assoc.rescope_symbols()
    # Afterwards only the plain ``a(i)`` belongs to the associate
    assert assign.lhs.scope is assoc
    assert assign2.lhs.scope is scope


def test_multiconditional(scope, a_i, i):
    """
    Test constructors and scoping behaviour of :any:`MultiConditional`.

    Checks field types, ``children``, frozen status and the normalisation of
    ``values``, ``bodies`` and ``else_body`` into (nested) tuples.
    """
    assign1 = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
    assign2 = ir.Assignment(lhs=a_i, rhs=sym.Literal(66.6))
    assign3 = ir.Assignment(lhs=a_i, rhs=sym.Literal(12.3))
    values = ((sym.Literal(1),), (sym.Literal(2),))
    multicond = ir.MultiConditional(
        expr=i, values=values, bodies=((assign1,), (assign2,)), else_body=(assign3,)
    )

    # ``values`` and ``bodies`` are tuples of tuples; ``else_body`` is a flat tuple
    assert isinstance(multicond.expr, Expression)
    assert isinstance(multicond.values, tuple)
    for val in multicond.values:
        assert isinstance(val, tuple)
        assert all(isinstance(v, Expression) for v in val)
    assert isinstance(multicond.bodies, tuple)
    for body in multicond.bodies:
        assert isinstance(body, tuple)
        assert all(isinstance(node, ir.Node) for node in body)
    assert isinstance(multicond.else_body, tuple)
    assert all(isinstance(node, ir.Node) for node in multicond.else_body)
    assert multicond.children == (
        i, ((sym.Literal(1),), (sym.Literal(2),)), ((assign1,), (assign2,)), (assign3,)
    )

    # Node objects are frozen: none of the fields may be re-assigned
    frozen_updates = (
        ('expr', parse_expr('k', scope=scope)),
        ('values', ((sym.Literal(3),), (sym.Literal(4),))),
        ('bodies', ((assign1,), (assign1,))),
        ('else_body', (assign1, assign2)),
    )
    for field, value in frozen_updates:
        with pytest.raises(FrozenInstanceError):
            setattr(multicond, field, value)

    # Auto-casting of values, bodies and else_body to (nested) tuple(s)
    casting_cases = (
        ({'values': sym.Literal(1), 'bodies': (), 'else_body': ()},
         'values', ((sym.Literal(1),),)),
        ({'values': ((sym.Literal(1),), sym.Literal(2),), 'bodies': (), 'else_body': ()},
         'values', ((sym.Literal(1),), (sym.Literal(2),))),
        ({'values': values, 'bodies': assign1, 'else_body': ()},
         'bodies', ((assign1,),)),
        ({'values': values, 'bodies': ((assign1,), assign2), 'else_body': ()},
         'bodies', ((assign1,), (assign2,))),
        ({'values': (), 'bodies': (), 'else_body': assign3},
         'else_body', (assign3,)),
        ({'values': (), 'bodies': (), 'else_body': ((assign3,), assign2)},
         'else_body', (assign3, assign2)),
    )
    for kwargs, field, expected in casting_cases:
        multicond = ir.MultiConditional(expr=i, **kwargs)
        assert getattr(multicond, field) == expected
loki-ecmwf-0.3.6/loki/ir/tests/test_control_flow.py0000664000175000017500000005156215167130205022604 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Subroutine
from loki.backend import fgen
from loki.jit_build import jit_compile, clean_test
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes


@pytest.mark.parametrize('frontend', available_frontends())
def test_loop_nest_fixed(tmp_path, frontend):
    """
    Test basic loops and reductions with fixed sizes.

    Basic loop nest:
        out1(i, j) = in1(i, j) + in2(i, j)

    Basic reduction:
        out2(j) = out2(j) + in1(i, j) * in2(i, j)
    """

    fcode = """
subroutine loop_nest_fixed(in1, in2, out1, out2)

  integer, parameter :: jprb = selected_real_kind(13,300)
  real(kind=jprb), dimension(3, 2), intent(in) :: in1, in2
  real(kind=jprb), intent(inout) :: out1(3, 2), out2(2)
  integer :: i, j

  do j=1, 2
     do i=1, 3
        out1(i, j) = in1(i, j) + in2(i, j)
     end do
  end do

  do j=1, 2
     out2(j) = 0.
     do i=1, 3
        out2(j) = out2(j) + in1(i, j) * in2(i, j)
     end do
  end do
end subroutine loop_nest_fixed
"""
    filepath = tmp_path/(f'control_flow_loop_nest_fixed_{frontend}.f90')
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(routine, filepath=filepath, objname='loop_nest_fixed')

    in1 = np.array([[1., 2.], [2., 3.], [3., 4.]], order='F')
    in2 = np.array([[2., 3.], [3., 4.], [4., 5.]], order='F')
    out1 = np.zeros((3, 2), order='F')
    out2 = np.zeros(2, order='F')

    function(in1, in2, out1, out2)
    assert (out1 == [[3, 5], [5, 7], [7, 9]]).all()
    # Column-wise dot products of in1 and in2, e.g. 1*2 + 2*3 + 3*4 = 20
    assert (out2 == [20, 38]).all()
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_loop_nest_variable(tmp_path, frontend):
    """
    Test basic loops and reductions with passed sizes.

    Basic loop nest:
        out1(i, j) = in1(i, j) + in2(i, j)

    Basic reduction:
        out2(j) = out2(j) + in1(i, j) * in2(i, j)
    """

    fcode = """
subroutine loop_nest_variable(dim1, dim2, in1, in2, out1, out2)
  integer, parameter :: jprb = selected_real_kind(13,300)
  integer, intent(in) :: dim1, dim2
  real(kind=jprb), dimension(dim1, dim2), intent(in) :: in1, in2
  real(kind=jprb), intent(inout) :: out1(dim1, dim2), out2(dim2)

  integer :: i, j

  do j=1, dim2
     do i=1, dim1
        out1(i, j) = in1(i, j) + in2(i, j)
     end do
  end do

  do j=1, dim2
     out2(j) = 0.
     do i=1, dim1
        out2(j) = out2(j) + in1(i, j) * in2(i, j)
     end do
  end do
end subroutine loop_nest_variable
"""
    filepath = tmp_path/(f'control_flow_loop_nest_variable_{frontend}.f90')
    routine = Subroutine.from_source(fcode, frontend=frontend)
    function = jit_compile(routine, filepath=filepath, objname='loop_nest_variable')

    in1 = np.array([[1., 2.], [2., 3.], [3., 4.]], order='F')
    in2 = np.array([[2., 3.], [3., 4.], [4., 5.]], order='F')
    out1 = np.zeros((3, 2), order='F')
    out2 = np.zeros(2, order='F')

    # Pass the extents derived from the input data (3 x 2)
    dim1, dim2 = in1.shape
    function(dim1, dim2, in1, in2, out1, out2)
    assert (out1 == [[3, 5], [5, 7], [7, 9]]).all()
    # Column-wise dot products of in1 and in2, e.g. 1*2 + 2*3 + 3*4 = 20
    assert (out2 == [20, 38]).all()
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_loop_scalar_logical_expr(tmp_path, frontend):
    """
    Compile and run a ``do while`` loop whose condition is a logical
    expression; the counter must stop at 5.
    """

    fcode = """
subroutine loop_scalar_logical_expr(outvar)
  integer, intent(out) :: outvar

  outvar = 0
  do while (outvar < 5)
    outvar = outvar + 1
  end do
end subroutine loop_scalar_logical_expr
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'control_flow_loop_scalar_logical_expr_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='loop_scalar_logical_expr')

    assert function() == 5
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_loop_unbounded(tmp_path, frontend):
    """
    Compile and run an unbounded ``do`` loop that is left via ``exit``
    once the counter exceeds 5.
    """

    fcode = """
subroutine loop_unbounded(out)
  integer, intent(out) :: out

  out = 1
  do
    out = out + 1
    if (out > 5) then
      exit
    endif
  enddo
end subroutine loop_unbounded
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'control_flow_loop_unbounded_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='loop_unbounded')

    assert function() == 6
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_loop_labeled_continue(tmp_path, frontend):
    """
    Compile and run a labeled ``do`` loop terminated by a ``continue`` statement.

    The loop is not reproduced 1:1, because fgen always emits an ``ENDDO``.
    The resulting loop still behaves the same, which the JIT run confirms.
    """

    fcode = """
subroutine loop_labeled_continue(out)
  integer, intent(out) :: out
  integer :: j

  out = 1
  do 101 j=1,10
    out = out + 1
101 continue
end subroutine loop_labeled_continue
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # OMNI does not retain the loop label, so only check it for the other frontends
    if frontend != OMNI:
        first_loop = FindNodes(ir.Loop).visit(routine.ir)[0]
        assert first_loop.loop_label == '101'

    filepath = tmp_path / f'control_flow_loop_labeled_continue_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='loop_labeled_continue')

    assert function() == 11
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_conditionals(tmp_path, frontend):
    """
    Compile and run single-line ``if`` statements that clamp each output;
    the outputs pass through unchanged when the conditions do not fire.
    """

    fcode = """
subroutine inline_conditionals(in1, in2, out1, out2)
  integer, intent(in) :: in1, in2
  integer, intent(out) :: out1, out2

  out1 = in1
  out2 = in2

  if (in1 < 0) out1 = 0
  if (in2 > 5) out2 = 5
end subroutine inline_conditionals
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'control_flow_inline_conditionals_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='inline_conditionals')

    # (inputs) -> (expected outputs): pass-through case, then clamped case
    cases = (((2, 2), (2, 2)), ((-2, 10), (0, 5)))
    for (in1, in2), (ref1, ref2) in cases:
        out1, out2 = function(in1, in2)
        assert out1 == ref1 and out2 == ref2
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_multi_body_conditionals(tmp_path, frontend):
    """
    Check the IR for ``if``/``else`` and ``if``/``else if`` chains (four
    :any:`Conditional` nodes, two of them flagged ``has_elseif`` for
    non-OMNI frontends), then compile and run the routine.
    """
    fcode = """
subroutine multi_body_conditionals(in1, out1, out2)
  integer, intent(in) :: in1
  integer, intent(out) :: out1, out2

  if (in1 > 5) then
    out1 = 5
  else
    out1 = 1
  end if

  if (in1 < 0) then
    out2 = 0
  else if (in1 > 5) then
    out2 = 6
    out2 = out2 - 1
  else if (3 < in1 .and. in1 <= 5) then
    out2 = 4
  else
    out2 = in1
  end if
end subroutine multi_body_conditionals
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conditionals) == 4
    if frontend != OMNI:
        assert sum(1 for cond in conditionals if cond.has_elseif) == 2

    filepath = tmp_path / f'control_flow_multi_body_conditionals_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='multi_body_conditionals')

    # in1 -> (out1, out2), exercising every branch of the else-if chain
    expected = {5: (1, 4), 2: (1, 2), -1: (1, 0), 10: (5, 5)}
    for in1, (ref1, ref2) in expected.items():
        out1, out2 = function(in1)
        assert out1 == ref1 and out2 == ref2
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_goto_stmt(tmp_path, frontend):
    """
    Compile and run a routine whose ``go to`` jumps to a labeled
    ``return``, so ``var`` keeps the value 3.
    """
    fcode = """
subroutine goto_stmt(var)
  implicit none
  integer, intent(out) :: var
  var = 3
  go to 1234
  var = 5
  1234 return
  var = 7
end subroutine goto_stmt
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'control_flow_goto_stmt_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='goto_stmt')

    assert function() == 3
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_select_case(tmp_path, frontend):
    """
    Compile and run a ``select case`` construct with single values, a range,
    a value list and a default branch.
    """
    fcode = """
subroutine select_case(cmd, out1)
  implicit none
  integer, intent(in) :: cmd
  integer, intent(out) :: out1

  select case (cmd)
    case (0)
      out1 = 0
    case (1:9)
      out1 = 1
    case (10, 11)
      out1 = 2
    case default
      out1 = -1
  end select
end subroutine select_case
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'control_flow_select_case_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='select_case')

    # cmd -> expected out1
    expected = {0: 0, 1: 1, 2: 1, 5: 1, 9: 1, 10: 2, 11: 2, 12: -1}
    for cmd, ref in expected.items():
        assert function(cmd) == ref
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_select_case_nested(tmp_path, frontend):
    """
    Compile and run nested ``select case`` constructs interleaved with
    comments, and check that all seven comments survive code generation.
    """
    fcode = """
subroutine select_case(cmd, out1)
  implicit none
  integer, intent(in) :: cmd
  integer, intent(out) :: out1

  out1 = -1000

  ! comment 1
  select case (cmd)
    ! comment 2
    case (0)
      out1 = 0
    ! comment 3
    case (1:9)
      out1 = 1
      select case (cmd)
        case (2:3)
          out1 = out1 + 100
        case (4:5)
          out1 = out1 + 200
      end select
    ! comment 4
    ! comment 5

    ! comment 6
    case (10, 11)
      out1 = 2
    ! comment 7
    case default
      out1 = -1
  end select
end subroutine select_case
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'control_flow_select_case_nested_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='select_case')

    # cmd -> expected out1; 2 and 5 hit the inner select
    expected = {0: 0, 1: 1, 2: 101, 5: 201, 9: 1, 10: 2, 11: 2, 12: -1}
    for cmd, ref in expected.items():
        assert function(cmd) == ref
    clean_test(filepath)

    generated = routine.to_fortran()
    assert generated.count('! comment') == 7


@pytest.mark.parametrize('frontend', available_frontends())
def test_cycle_stmt(tmp_path, frontend):
    """
    Compile and run a loop that uses an inline ``cycle`` to stop
    incrementing once the counter exceeds 5.
    """
    fcode = """
subroutine cycle_stmt(var)
  implicit none
  integer, intent(out) :: var
  integer :: i

  var = 0
  do i=1,10
    if (var > 5) cycle
    var = var + 1
  end do
end subroutine cycle_stmt
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)
    filepath = tmp_path / f'control_flow_cycle_stmt_{frontend}.f90'
    function = jit_compile(routine, filepath=filepath, objname='cycle_stmt')

    assert function() == 6
    clean_test(filepath)


@pytest.mark.parametrize('frontend', available_frontends())
def test_conditional_bodies(frontend):
    """Verify that conditional bodies and else-bodies are non-empty tuples of :class:`Node`"""
    fcode = """
subroutine conditional_body(nanana, zzzzz, trololo, tralala, xoxoxoxo, yyyyyy, kidia, kfdia)
integer, intent(inout) :: nanana, zzzzz, trololo, tralala, xoxoxoxo, yyyyyy, kidia, kfdia
integer :: jlon
if (nanana == 1) then
    zzzzz = 1
else
    zzzzz = 4
end if
if (trololo == 1) then
    tralala = 1
else if (trololo == 2) then
    tralala = 2
else if (trololo == 3) then
    tralala = 3
else
    tralala = 4
end if
if (xoxoxoxo == 1) then
    do jlon = kidia, kfdia
        yyyyyy = 1
    enddo
else
    do jlon = kidia, kfdia
        yyyyyy = 4
    enddo
end if
end subroutine conditional_body
    """.strip()

    routine = Subroutine.from_source(fcode, frontend=frontend)
    conditionals = FindNodes(ir.Conditional).visit(routine.ir)
    # One plain if/else, a three-level else-if chain, and one if/else with loops
    assert len(conditionals) == 5
    for cond in conditionals:
        for branch in (cond.body, cond.else_body):
            assert branch and isinstance(branch, tuple)
            assert all(isinstance(node, ir.Node) for node in branch)


@pytest.mark.parametrize('frontend', available_frontends())
def test_conditional_else_body_return(frontend):
    """
    Every branch of an ``IF/ELSEIF/ELSE`` chain ending in ``RETURN`` keeps
    the ``RETURN`` intrinsic as its last node, and the ``ELSEIF`` is
    represented as a nested :class:`Conditional` in the outer else-body.
    """
    fcode = """
FUNCTION FUNC(PX,KN)
IMPLICIT NONE
INTEGER,INTENT(INOUT) :: KN
REAL,INTENT(IN) :: PX
REAL :: FUNC
INTEGER :: J
REAL :: Z0, Z1, Z2
Z0= 1.0
Z1= PX
IF (KN == 0) THEN
  FUNC= Z0
  RETURN
ELSEIF (KN == 1) THEN
  FUNC= Z1
  RETURN
ELSE
  DO J=2,KN
    Z2= Z0+Z1
    Z0= Z1
    Z1= Z2
  ENDDO
  FUNC= Z2
  RETURN
ENDIF
END FUNCTION FUNC
    """.strip()

    def _ends_with_return(nodes):
        # The final node of a branch must be a RETURN intrinsic
        last = nodes[-1]
        return isinstance(last, ir.Intrinsic) and last.text.upper() == 'RETURN'

    routine = Subroutine.from_source(fcode, frontend=frontend)
    conditionals = FindNodes(ir.Conditional).visit(routine.body)
    assert len(conditionals) == 2
    outer, inner = conditionals

    assert _ends_with_return(outer.body)
    assert outer.else_body == (inner,)
    assert _ends_with_return(inner.body)
    assert _ends_with_return(inner.else_body)


@pytest.mark.parametrize('frontend', available_frontends(
        xfail=[(OMNI, 'Renames index variable to omnitmp000')]
))
def test_single_line_forall_stmt(tmp_path, frontend):
    """Single-line FORALL statement: IR structure, execution and code generation."""
    fcode = """
subroutine forall_stmt(n, a)
    implicit none
    integer, parameter :: jprb = selected_real_kind(13,300)
    integer, intent(in) :: n
    real(kind=jprb), dimension(n, n), intent(inout) :: a
    integer :: i

    ! Create a diagonal square matrix
    forall (i=1:n)  a(i, i) = 1
end subroutine forall_stmt
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Exactly one Forall node with the single bound i=1:n
    foralls = FindNodes(ir.Forall).visit(routine.ir)
    assert len(foralls) == 1
    forall = foralls[0]
    assert len(forall.named_bounds) == 1
    index_var, index_range = forall.named_bounds[0]
    assert index_var.name == 'i'
    assert index_range == '1:n'

    # Its body is the lone diagonal assignment a(i, i) = 1
    body_assigns = FindNodes(ir.Assignment).visit(forall)
    assert len(body_assigns) == 1, "Single-line FORALL statement must have only one assignment"
    diag_assign = body_assigns[0]
    assert diag_assign.lhs == "a(i, i)"
    assert diag_assign.rhs == '1'

    # Compile and run: only the diagonal is overwritten with 1
    src_file = tmp_path/f'single_line_forall_stmt_{frontend}.f90'
    forall_fn = jit_compile(routine, filepath=src_file, objname="forall_stmt")

    small = np.zeros((3, 3), order="F")
    forall_fn(3, small)
    assert (small == np.eye(3)).all()

    big = np.full((5, 5), 3.0, order="F")
    forall_fn(5, big)
    big_expected = np.full((5, 5), 3.0)
    np.fill_diagonal(big_expected, 1.0)
    assert (big == big_expected).all()

    # Regenerated Fortran keeps the statement on a single line
    expected_fcode = "FORALL(i = 1:n) a(i, i) = 1"
    assert fgen(forall) == expected_fcode
    assert expected_fcode in routine.to_fortran()


@pytest.mark.parametrize('frontend', available_frontends(
        xfail=[(OMNI, 'Renames index variable to omnitmp000')]
))
def test_single_line_forall_masked_stmt(tmp_path, frontend):
    """Masked single-line FORALL: bounds, mask, execution and code generation."""
    fcode = """
subroutine forall_masked_stmt(n, a, b)
    implicit none
    integer, parameter :: jprb = selected_real_kind(13,300)
    integer, intent(in) :: n
    real(kind=jprb), dimension(n, n), intent(inout) :: a, b
    integer :: i, j

    forall(i = 1:n, j = 1:n, a(i, j) .ne. 0.0) b(i, j) = 1.0 / a(i, j)
end subroutine forall_masked_stmt
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # One Forall with the bounds i=1:n and j=1:n, in that order
    foralls = FindNodes(ir.Forall).visit(routine.ir)
    assert len(foralls) == 1
    forall = foralls[0]
    assert len(forall.named_bounds) == 2
    for (bound_var, bound_range), expected_name in zip(forall.named_bounds, ("i", "j")):
        assert bound_var == expected_name
        assert bound_range == '1:n'

    # The mask and the reciprocal assignment
    assert forall.mask == 'a(i, j) != 0.0'
    body_assigns = FindNodes(ir.Assignment).visit(forall)
    assert len(body_assigns) == 1
    recip = body_assigns[0]
    assert recip.lhs.name == "b" and len(recip.lhs.dimensions) == 2
    assert recip.rhs == '1.0 / a(i, j)'

    # Compile and run: b receives 1/a wherever a is non-zero, else stays 0
    src_file = tmp_path / f'single_line_forall_masked_stmt_{frontend}.f90'
    masked_fn = jit_compile(routine, filepath=src_file, objname="forall_masked_stmt")
    a = np.array([[2.0, 0.0, 2.0],
                  [0.0, 4.0, 0.0],
                  [10.0, 10.0, 0.0]], order="F")
    b = np.zeros((3, 3), order="F")
    masked_fn(3, a, b)
    assert (b == [[0.5, 0.0, 0.5], [0, 0.25, 0], [0.1, 0.1, 0]]).all()

    # Regenerated Fortran uses the /= operator in the mask
    expected_fcode = "FORALL(i = 1:n, j = 1:n, a(i, j) /= 0.0) b(i, j) = 1.0 / a(i, j)"
    assert fgen(forall) == expected_fcode
    assert expected_fcode in routine.to_fortran()


@pytest.mark.parametrize('frontend', available_frontends(xfail=[
    (OMNI, 'Renames index variable to omnitmp000'),
]))
def test_multi_line_forall_construct(tmp_path, frontend):
    """Multi-line FORALL construct: bounds, body, execution and code generation."""
    fcode = """
subroutine forall_construct(n, c, d)
    implicit none
    integer, parameter :: jprb = selected_real_kind(13,300)
    integer, intent(in) :: n
    real(kind=jprb), dimension(n, n), intent(inout) :: c, d
    integer :: i, j

    forall(i = 3:n - 2, j = 3:n - 2)
        c(i, j) = c(i, j + 2) + c(i, j - 2) + c(i + 2, j) + c(i - 2, j)
        d(i, j) = c(i, j)
    end forall
end subroutine forall_construct
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # One Forall with the bounds i=3:n-2 and j=3:n-2, in that order
    foralls = FindNodes(ir.Forall).visit(routine.ir)
    assert len(foralls) == 1
    forall = foralls[0]
    assert len(forall.named_bounds) == 2
    for (bound_var, bound_range), expected_name in zip(forall.named_bounds, ("i", "j")):
        assert bound_var.name == expected_name
        assert bound_range == '3:n-2'

    # The body holds the stencil update followed by the copy into d
    body_assigns = FindNodes(ir.Assignment).visit(forall)
    assert len(body_assigns) == 2
    stencil_assign, copy_assign = body_assigns
    assert stencil_assign.lhs == 'c(i, j)'
    assert stencil_assign.rhs == 'c(i, j + 2) + c(i, j - 2) + c(i + 2, j) + c(i - 2, j)'
    assert copy_assign.lhs == 'd(i, j)'
    assert copy_assign.rhs == 'c(i, j)'

    # Compile and run: only the 2x2 interior block is updated to 4
    src_file = tmp_path / f'multi_line_forall_construct_{frontend}.f90'
    construct_fn = jit_compile(routine, filepath=src_file, objname="forall_construct")
    size = 6
    c = np.ones((size, size), order="F")
    d = np.zeros((size, size), order="F")
    construct_fn(size, c, d)

    c_expected = np.ones((size, size))
    c_expected[2:4, 2:4] = 4.0
    d_expected = np.zeros((size, size))
    d_expected[2:4, 2:4] = 4.0
    assert (c == c_expected).all()
    assert (d == d_expected).all()

    # Regenerated Fortran keeps the construct spread over lines 7-10
    expected_lines = (
        "FORALL(i = 3:n - 2, j = 3:n - 2)",
        "c(i, j) = c(i, j + 2) + c(i, j - 2) + c(i + 2, j) + c(i - 2, j)",
        "d(i, j) = c(i, j)",
        "END FORALL",
    )
    regenerated = routine.to_fortran().split("\n")
    for line_no, expected_line in enumerate(expected_lines, start=7):
        assert regenerated[line_no].strip() == expected_line


@pytest.mark.parametrize('frontend', available_frontends(
    xfail=[(OMNI, 'No support for Cray Pointers')]
))
def test_cray_pointers(frontend):
    """A Cray ``POINTER`` statement is kept as an intrinsic and regenerated verbatim."""
    fcode = """
SUBROUTINE SUBROUTINE_WITH_CRAY_POINTER (KLON,KLEV,POOL)
IMPLICIT NONE
INTEGER, INTENT(IN) :: KLON, KLEV
REAL, INTENT(INOUT) :: POOL(:)
REAL, DIMENSION(KLON,KLEV) :: ZQ
POINTER(IP_ZQ, ZQ)
IP_ZQ = LOC(POOL)
END SUBROUTINE
    """.strip()
    routine = Subroutine.from_source(fcode, frontend=frontend)

    spec_intrinsics = FindNodes(ir.Intrinsic).visit(routine.spec)
    assert len(spec_intrinsics) == 2
    implicit_stmt, pointer_stmt = spec_intrinsics
    assert 'IMPLICIT NONE' in implicit_stmt.text
    assert 'POINTER(IP_ZQ, ZQ)' in pointer_stmt.text
    assert 'POINTER(IP_ZQ, ZQ)' in routine.to_fortran()
loki-ecmwf-0.3.6/loki/ir/transformer.py0000664000175000017500000005572615167130205020244 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Visitor classes for transforming the IR
"""
from loki.frontend.source import Source
from loki.ir.nodes import Node, Conditional, ScopedNode
from loki.ir.visitor import Visitor
from loki.tools import flatten, is_iterable, as_tuple, replace_windowed


__all__ = [
    'Transformer', 'NestedTransformer', 'MaskedTransformer',
    'NestedMaskedTransformer'
]


def is_source_valid(source):
    """
    Determine the validity status of a given :any:`Source` object.

    Parameters
    ----------
    source : :any:`Source` or None
        The source object to check; anything that is falsy or not a
        :any:`Source` instance is treated as invalid.

    Returns
    -------
    bool
        The result of :meth:`Source.is_valid` for a :any:`Source`,
        otherwise `False`.
    """
    if not source or not isinstance(source, Source):
        return False
    return source.is_valid()


class Transformer(Visitor):
    r"""
    Visitor class to rebuild the tree and replace nodes according to a mapper.

    Given a control flow tree :math:`T` and a mapper from nodes in :math:`T`
    to a set of new nodes :math:`L, M : N \rightarrow L`, build a new control
    flow tree :math:`T'` where a node :math:`n \in N` is replaced with
    :math:`M(n)`.

    .. important::
       The mapping is applied before visiting any children of a node.

    *Removing nodes*: In the special case in which :math:`M(n)` is `None`,
    :math:`n` is dropped from :math:`T'`.

    *One to many mapping*: In the special case in which :math:`M(n)` is an
    iterable of nodes, all nodes in :math:`M(n)` are inserted into the tuple
    containing :math:`n`.

    .. warning::
       Applying a :class:`Transformer` to an IR tree rebuilds all nodes by
       default, which means individual nodes from the original IR are no longer
       found in the new tree. To update references to IR nodes, the attribute
       :any:`Transformer.rebuilt` provides a mapping from original to rebuilt
       nodes. Alternatively, with :data:`inplace` the mapping can be
       applied without rebuilding the tree, leaving existing references to
       individual IR nodes intact (as long as the mapping does not replace or
       remove them in the tree).

    Parameters
    ----------
    mapper : dict
        The mapping :math:`M : N \rightarrow L`.
    invalidate_source : bool, optional
        If set to `True`, this triggers invalidating the :data:`source`
        property of all parent nodes of a node :math:`n` if :math:`M(n)`
        does not carry a valid :data:`source`.
    inplace : bool, optional
        If set to `True`, all updates are performed on the existing :any:`Node`
        objects instead of rebuilding them. This modifies the original tree,
        but keeps existing references to its nodes valid.
    rebuild_scopes : bool, optional
        If set to `True`, this will also rebuild :class:`ScopedNode` in the IR.
        This requires updating :attr:`TypedSymbol.scope` properties, which is
        expensive and thus carried out only when explicitly requested.

    Attributes
    ----------
    rebuilt : dict
        After applying the :class:`Transformer` to an IR, this contains a
        mapping :math:`n \rightarrow n'` for every node of the original tree
        :math:`n \in T` to the rebuilt nodes in the new tree :math:`n' \in T'`.
    """

    def __init__(self, mapper=None, invalidate_source=True, inplace=False, rebuild_scopes=False):
        super().__init__()
        # Copy so that later changes to the caller's dict do not leak in
        self.mapper = mapper.copy() if mapper is not None else {}
        self.invalidate_source = invalidate_source
        self.rebuilt = {}
        self.inplace = inplace
        self.rebuild_scopes = rebuild_scopes

    def _rebuild(self, o, children, **args):
        """
        Utility method to rebuild the given node with the provided children.

        If :data:`invalidate_source` is `True` and :data:`o` carries a valid
        :data:`source`, that source is cloned and invalidated (including its
        children) whenever any of the :any:`Node` entries in :data:`children`
        does not carry a valid :data:`source` itself.

        Parameters
        ----------
        o : :any:`Node`
            The node to rebuild.
        children : tuple
            The (already visited) children to rebuild :data:`o` with.
        **args :
            Constructor arguments overriding those in ``o.args_frozen``.

        Returns
        -------
        :any:`Node`
            :data:`o` itself (updated) if :data:`inplace` is set, otherwise
            a newly rebuilt node.
        """
        args_frozen = o.args_frozen
        args_frozen.update(args)
        if self.invalidate_source and 'source' in args_frozen:
            # If any child node has been invalidated, mark this node as invalid too
            if is_source_valid(args_frozen.get('source')):
                # Inspect the child's *source*: handing the node itself to
                # is_source_valid would always report it as invalid and
                # thus invalidate every parent unconditionally
                if any(isinstance(c, Node) and not is_source_valid(c.source) for c in flatten(children)):
                    args_frozen['source'] = args_frozen['source'].clone()
                    args_frozen['source'].invalidate(children=True)

        if self.inplace:
            # Updated nodes in place, if requested
            o._update(*children, **args_frozen)
            return o

        # Rebuild updated nodes by default
        return o._rebuild(*children, **args_frozen)

    def visit_object(self, o, **kwargs):
        """Return the object unchanged."""
        return o

    def _inject_tuple_mapping(self, o):
        """
        Utility method for one-to-many mappings to insert iterables for
        the replaced node into a tuple.

        Mapper keys that are themselves iterable are substituted as windows
        via :any:`replace_windowed`; single-node keys mapped to an iterable
        are replaced by all entries of that iterable.
        """
        def _inject_handle(nodes, i, old, new):
            """Utility to replace `old` in `nodes[i:]` by `new`"""
            j = nodes.index(old, i)
            new = tuple(new)
            nodes = nodes[:j] + new + nodes[j+1:]
            return nodes, j + len(new)

        for k, handle in self.mapper.items():
            if is_iterable(k):
                o = replace_windowed(o, k, subs=handle)
            if k in o and is_iterable(handle):
                # Replace k by the iterable that is provided by handle
                o, i = _inject_handle(o, 0, k, handle)
                while k in o[i:]:
                    # Repeat in case there are multiple occurrences of k in the tuple,
                    # but only look in the tail of the original tuple to avoid running
                    # into infinite recursion if k is included in the handle
                    o, i = _inject_handle(o, i, k, handle)
        return o

    def visit_tuple(self, o, **kwargs):
        """
        Visit all elements in a tuple, injecting any one-to-many mappings.
        """
        # First inject tuples that match at least a sub-set of current nodes
        o = self._inject_tuple_mapping(o)

        # Then recurse over the new nodes
        visited = tuple(self.visit(i, **kwargs) for i in o)

        # Strip empty sublists/subtuples or None entries
        return tuple(i for i in visited if i is not None and as_tuple(i))

    visit_list = visit_tuple

    def visit_Node(self, o, **kwargs):
        """
        Handler for :any:`Node` objects.

        It replaces :data:`o` by :data:`mapper[o]`, if it is in the mapper,
        otherwise visits all children before rebuilding the node.
        """
        if o in self.mapper:
            handle = self.mapper[o]
            if handle is None:
                # None -> drop /o/
                return None

            # For one-to-many mappings making sure this is not replaced again
            # as it has been inserted by visit_tuple already
            if not is_iterable(handle) or o not in handle:
                return handle._rebuild(**handle.args)

        rebuilt = tuple(self.visit(i, **kwargs) for i in o.children)
        return self._rebuild(o, rebuilt)

    def visit_ScopedNode(self, o, **kwargs):
        """
        Handler for :class:`ScopedNode` objects.

        It replaces :data:`o` by :data:`mapper[o]`, if it is in the mapper,
        otherwise its behaviour differs slightly from the default
        :meth:`visit_Node` as it rebuilds the node first, then visits all
        children and then updates in-place the rebuilt node.
        This is to make sure upwards-pointing references to this scope
        (such as :attr:`ScopedNode.parent` properties) can be updated correctly.

        Additionally, it passes down the currently active scope in :attr:`kwargs`
        when recursing to children.
        """
        if o in self.mapper:
            handle = self.mapper[o]
            if handle is None:
                # None -> drop /o/
                return None

            # For one-to-many mappings making sure this is not replaced again
            # as it has been inserted by visit_tuple already
            if not is_iterable(handle) or o not in handle:
                return handle._rebuild(**handle.args)

        # Rebuild the node (and update parent pointer if necessary)
        if self.rebuild_scopes:
            if 'scope' in kwargs:
                o = self._rebuild(o, o.children, parent=kwargs['scope'])
            else:
                o = self._rebuild(o, o.children)
        elif 'scope' in kwargs and kwargs['scope'] is not o.parent:
            o._update(parent=kwargs['scope'])

        # Recurse to children, passing down the scope
        kwargs['scope'] = o
        rebuilt = tuple(self.visit(i, **kwargs) for i in o.children)

        # Update in-place the node with rebuilt children
        o._update(*rebuilt)
        return o

    def visit(self, o, *args, **kwargs):
        """
        Apply this :class:`Transformer` to an IR tree.

        Parameters
        ----------
        o : :any:`Node`
            The node to visit.
        *args :
            Optional arguments to pass to the visit methods.
        **kwargs :
            Optional keyword arguments to pass to the visit methods.

        Returns
        -------
        :any:`Node` or tuple
            The rebuilt control flow tree.
        """
        obj = super().visit(o, *args, **kwargs)
        # Record the original -> rebuilt correspondence for callers
        if isinstance(o, Node) and obj is not o:
            self.rebuilt[o] = obj
        return obj


class NestedTransformer(Transformer):
    """
    A :class:`Transformer` that applies replacements in a depth-first fashion.

    In contrast to :class:`Transformer`, the children of a node are visited
    *before* the :data:`mapper` entry for that node is applied, and the
    replacement is then rebuilt with the visited children.
    """

    def _rebuild_without_source(self, o, children, **args):
        """
        Rebuild :data:`o` with :data:`children` and its :data:`source` reset to `None`.

        Used when a one-to-many mapping is spliced into the first child of
        :data:`o` while :data:`invalidate_source` is enabled.
        """
        args_frozen = o.args_frozen
        args_frozen.update(args)
        args_frozen['source'] = None
        return o._rebuild(*children, **args_frozen)

    def visit_tuple(self, o, **kwargs):
        """
        Visit all elements in a tuple, injecting any one-to-many mappings.
        """

        # Recurse to children first !
        visited = tuple(self.visit(i, **kwargs) for i in o)

        # Inject any matching sub-set of nodes into current tuple
        visited = self._inject_tuple_mapping(visited)

        # Strip empty sublists/subtuples or None entries
        return tuple(i for i in visited if i is not None and as_tuple(i))

    visit_list = visit_tuple

    def visit_Node(self, o, **kwargs):
        """
        Handler for :any:`Node` objects.

        It visits all children before applying the :data:`mapper`. If
        :data:`o` is mapped to an iterable, the entries of that iterable are
        prepended to the (visited) first child of :data:`o`.

        Raises
        ------
        ValueError
            If :data:`o` is mapped to an iterable but has no children.
        """
        # Get the handle to bail out early if we drop the node
        handle = self.mapper.get(o, o)
        if handle is None:
            # None -> drop /o/
            return None

        # Recurse to children
        rebuilt = [self.visit(i, **kwargs) for i in o.children]

        # Rebuild the node with rebuilt children
        if is_iterable(handle):
            if not o.children:
                raise ValueError('Cannot apply a one-to-many mapping to a node without children')
            extended = [tuple(handle) + rebuilt[0]] + rebuilt[1:]
            if self.invalidate_source:
                return self._rebuild_without_source(o, extended)
            return o._rebuild(*extended, **o.args_frozen)
        return self._rebuild(handle, rebuilt)

    def visit_ScopedNode(self, o, **kwargs):
        """
        Handler for :class:`ScopedNode` objects.

        Its behaviour differs slightly from the default :meth:`visit_Node` as
        it rebuilds the node first, then visits all
        children and then updates in-place the rebuilt node.
        This is to make sure upwards-pointing references to this scope
        (such as :attr:`ScopedNode.parent` properties) can be updated correctly.

        Additionally, it passes down the currently active scope in :attr:`kwargs`
        when recursing to children.

        Raises
        ------
        ValueError
            If :data:`o` is mapped to an iterable but has no children.
        """
        # Get the handle to bail out early if we drop the node
        handle = self.mapper.get(o, o)
        if handle is None:
            # None -> drop /o/
            return None

        # Rebuild the handle (and update parent pointer if necessary).
        # Iterable handles (one-to-many mappings) are not nodes themselves;
        # they are spliced into ``o`` below instead of being rebuilt here.
        if self.rebuild_scopes and not is_iterable(handle):
            if 'scope' in kwargs and isinstance(handle, ScopedNode):
                handle = self._rebuild(handle, handle.children, parent=kwargs['scope'])
            else:
                handle = self._rebuild(handle, handle.children)
        elif 'scope' in kwargs and isinstance(handle, ScopedNode) and kwargs['scope'] is not handle.parent:
            handle._update(parent=kwargs['scope'])

        # Rebuild children
        if is_iterable(handle):
            kwargs['scope'] = o
        elif isinstance(handle, ScopedNode):
            kwargs['scope'] = handle
        rebuilt = [self.visit(i, **kwargs) for i in o.children]

        # Update the node with rebuilt children
        if is_iterable(handle):
            if not o.children:
                raise ValueError('Cannot apply a one-to-many mapping to a node without children')
            extended = [tuple(handle) + rebuilt[0]] + rebuilt[1:]
            if self.invalidate_source:
                o._update(*extended, source=None)
            else:
                o._update(*extended)
            return o
        handle._update(*rebuilt)
        return handle


class MaskedTransformer(Transformer):
    """
    An enriched :class:`Transformer` that can selectively include or exclude
    parts of the tree.

    For that :class:`MaskedTransformer` is selectively switched on and
    off while traversing the tree. Nodes are only included in the new tree
    while it is "switched on".
    The transformer is switched on or off when it encounters nodes from
    :data:`start` or :data:`stop`, respectively. This can be used, e.g., to
    extract everything between two nodes, or to create a copy of the entire
    tree but without all nodes between two nodes.
    Multiple such ranges can be defined by providing more than one
    :data:`start` and :data:`stop` node, respectively.

    The sets :data:`start` and :data:`stop` are to be understood in a Pythonic
    way, i.e., :data:`start` nodes will be included in the result and
    :data:`stop` excluded.

    .. important::
       When recursing down a tree, any :any:`InternalNode` are only included
       in the tree if the :class:`MaskedTransformer` was switched on before
       visiting that :any:`InternalNode`. Importantly, this means the node is
       also not included if the transformer is switched on while traversing
       the internal node's body. In such a case, only the body nodes that are
       included are retained.

    Optionally as a variant, switching on can also be delayed until all nodes
    from :data:`start` have been encountered by setting
    :data:`require_all_start` to `True`.

    Optionally, the transformer can be switched off for good with
    :data:`greedy_stop`. If enabled, encountering a node from :data:`stop`
    clears :data:`start`, so the :class:`MaskedTransformer` can never be
    switched on again; the remaining tree is still traversed, but no further
    nodes are included.

    .. note::
       Enabling :data:`require_all_start` and :data:`greedy_stop` at the same
       time can be useful when you require the minimum number of nodes
       in-between multiple start and end nodes without knowing in which order
       they appear.

    .. note::
       :any:`MaskedTransformer` rebuilds also :class:`ScopedNode` by default
       (i.e., it calls the parent constructor with ``rebuild_scopes=True``).

    Parameters
    ----------
    start : (iterable of) :any:`Node`, optional
        Encountering a node from :data:`start` during traversal switches the
        :class:`MaskedTransformer` on and includes that node and all
        subsequently traversed nodes in the produced tree.
    stop : (iterable of) :any:`Node`, optional
        Encountering a node from :data:`stop` during traversal switches the
        :class:`MaskedTransformer` off and excludes that node and all
        subsequently traversed nodes from the produced tree.
    active : bool, optional
        Switch the :class:`MaskedTransformer` on at the beginning of the
        traversal. By default, it is switched on only after encountering a node
        from :data:`start`.
    require_all_start : bool, optional
        Switch the :class:`MaskedTransformer` on only after encountering `all`
        nodes from :data:`start`. By default, it is switched on after
        encountering `any` node from :data:`start`.
    greedy_stop : bool, optional
        Permanently switch off the transformer as soon as any node from
        :data:`stop` is encountered. By default, nodes are only excluded
        from the new tree until a node from :data:`start` is encountered.
    **kwargs : optional
        Keyword arguments that are passed to the parent class constructor.
    """

    def __init__(self, start=None, stop=None, active=False,
                 require_all_start=False, greedy_stop=False, **kwargs):
        # Rebuild scoped nodes unless the caller explicitly says otherwise
        kwargs.setdefault('rebuild_scopes', True)
        super().__init__(**kwargs)

        # Stored as mutable sets: with require_all_start, encountered start
        # nodes are removed from self.start during traversal (see visit)
        self.start = set(as_tuple(start))
        self.stop = set(as_tuple(stop))
        self.active = active
        self.require_all_start = require_all_start
        self.greedy_stop = greedy_stop

    def visit(self, o, *args, **kwargs):
        """
        Update the on/off state for :data:`o` before dispatching to its handler.

        The state is updated for every visited object (not only nodes), in
        traversal order, so the order of the statements below matters.
        """
        # Vertical active status update
        if self.require_all_start:
            if o in self.start:
                # to record encountered nodes we remove them from the set of
                # start nodes and only if it is then empty we set active=True
                self.start.remove(o)
                self.active = self.active or not self.start
            else:
                self.active = self.active and o not in self.stop
        else:
            self.active = (self.active and o not in self.stop) or o in self.start
        if self.greedy_stop and o in self.stop:
            # to make sure that we don't include any following nodes we clear start
            self.start.clear()
            self.active = False
        return super().visit(o, *args, **kwargs)

    def visit_object(self, o, **kwargs):
        """
        Return non-node objects (e.g. expressions) only if the enclosing IR
        node was active when it was visited, otherwise `None`.

        NOTE(review): reads ``kwargs['parent_active']`` unconditionally, so it
        presumably must only be reached from within a node handler that sets it.
        """
        if kwargs['parent_active']:
            # this is not an IR node but usually an expression tree or similar
            # we need to retain this only if the "parent" IR node is active
            return o
        return None

    def visit_Node(self, o, **kwargs):
        """
        Handler for :any:`Node` objects.

        Mapped nodes are handled by :meth:`Transformer.visit_Node`. Otherwise,
        if the transformer is active when :data:`o` is visited, :data:`o` is
        rebuilt with its visited children; if not, :data:`o` itself is dropped
        and only the non-`None` results from its children are returned (as a
        tuple, or `None` if there are none).
        """
        if o in self.mapper:
            return super().visit_Node(o, **kwargs)

        # pass to children if this node is active
        # (the status is captured before recursing, since children may toggle it)
        kwargs['parent_active'] = self.active
        rebuilt = tuple(self.visit(i, **kwargs) for i in o.children)
        if kwargs['parent_active']:
            return self._rebuild(o, rebuilt)
        return tuple(i for i in rebuilt if i is not None) or None

    def visit_ScopedNode(self, o, **kwargs):
        """
        Handler for :class:`ScopedNode` objects.

        Like :meth:`visit_Node`, but the node is rebuilt (or re-parented)
        before visiting its children and then updated in place, and it is
        passed down to the children as the current ``scope``.
        """
        if o in self.mapper:
            return super().visit_ScopedNode(o, **kwargs)

        # Rebuild the node (and update parent pointer if necessary)
        if self.rebuild_scopes:
            if 'scope' in kwargs:
                o = self._rebuild(o, o.children, parent=kwargs['scope'])
            else:
                o = self._rebuild(o, o.children)
        elif 'scope' in kwargs and kwargs['scope'] is not o.parent:
            o._update(parent=kwargs['scope'])

        # Recurse to children, passing down the scope and if this node is active
        kwargs['scope'] = o
        kwargs['parent_active'] = self.active
        rebuilt = tuple(self.visit(i, **kwargs) for i in o.children)

        # Update rebuilt node
        if kwargs['parent_active']:
            o._update(*rebuilt)
            return o
        return tuple(i for i in rebuilt if i is not None) or None


class NestedMaskedTransformer(MaskedTransformer):
    """
    A :class:`MaskedTransformer` that retains parents for children that
    are included in the produced tree.

    In contrast to :class:`MaskedTransformer`, any encountered
    :any:`InternalNode` are included in the new tree as long as any of its
    children are included.
    """

    # Handler for leaf nodes

    def visit_object(self, o, **kwargs):
        """
        Return the object unchanged.

        Note that we need to keep them here regardless of the transformer
        being active because this handler takes care of properties for
        inactive parents that may still be retained if other children switch
        on the transformer.
        """
        return o

    def visit_LeafNode(self, o, **kwargs):
        """
        Handler for :any:`LeafNode` that are included in the tree if the
        :class:`NestedMaskedTransformer` is active.
        """
        if o in self.mapper:
            return super().visit_Node(o, **kwargs)
        if not self.active:
            # because any section/scope nodes are treated separately we can
            # simply drop inactive nodes
            return None

        rebuilt = tuple(self.visit(i, **kwargs) for i in o.children)
        return self._rebuild(o, rebuilt)

    # Handler for block nodes

    def visit_InternalNode(self, o, **kwargs):
        """
        Handler for :any:`InternalNode` that are included in the tree as long
        as any :attr:`body` node is included.
        """
        if o in self.mapper:
            return super().visit_Node(o, **kwargs)

        rebuilt = [self.visit(i, **kwargs) for i in o.children]
        body_index = o._traversable.index('body')

        if rebuilt[body_index]:
            rebuilt[body_index] = as_tuple(flatten(rebuilt[body_index]))

        # check if body still exists, otherwise delete this node
        if not rebuilt[body_index]:
            return None
        return self._rebuild(o, rebuilt)

    def visit_Conditional(self, o, **kwargs):
        """
        Handler for :any:`Conditional` to account for the :attr:`else_body`.

        .. note::
           This removes the :any:`Conditional` if :attr:`body` is empty. In
           that case, :attr:`else_body` is returned (which can be empty, too).
        """
        if o in self.mapper:
            # Apply the mapping directly, as the other handlers do: going through
            # ``super().visit`` would re-run the start/stop bookkeeping for ``o``
            # and dispatch straight back into this handler (infinite recursion)
            return super().visit_Node(o, **kwargs)

        condition = self.visit(o.condition, **kwargs)
        body = as_tuple(flatten(as_tuple(self.visit(o.body, **kwargs))))
        else_body = as_tuple(flatten(as_tuple(self.visit(o.else_body, **kwargs))))

        if not body:
            return else_body

        # Only keep the ELSE IF form if the else-body still starts with a conditional
        has_elseif = o.has_elseif and bool(else_body) and isinstance(else_body[0], Conditional)
        return self._rebuild(o, tuple((condition,) + (body,) + (else_body,)), has_elseif=has_elseif)

    def visit_MultiConditional(self, o, **kwargs):
        """
        Handler for :any:`MultiConditional` to account for all bodies.

        .. note::
           This removes the :any:`MultiConditional` if all of the
           :attr:`bodies` are empty. In that case, :attr:`else_body` is
           returned (which can be empty, too).
        """
        if o in self.mapper:
            # See visit_Conditional: ``super().visit`` would recurse infinitely
            return super().visit_Node(o, **kwargs)

        # need to make (value, body) pairs to track vanishing bodies
        expr = self.visit(o.expr, **kwargs)
        branches = tuple((self.visit(c, **kwargs), self.visit(b, **kwargs))
                         for c, b in zip(o.values, o.bodies))
        branches = tuple((c, b) for c, b in branches if flatten(as_tuple(b)))
        else_body = self.visit(o.else_body, **kwargs)

        # retain whatever is in the else body if all other branches are gone
        if not branches:
            return else_body

        # rebuild conditional with remaining branches
        values, bodies = zip(*branches)
        return self._rebuild(o, tuple((expr,) + (values,) + (bodies,) + (else_body,)))

    visit_TypeConditional = visit_MultiConditional
loki-ecmwf-0.3.6/loki/ir/visitor.py0000664000175000017500000001374615167130205017375 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Visitor base classes for traversing the IR
"""

import inspect

# Names exported by ``from loki.ir.visitor import *``
__all__ = ['GenericVisitor', 'Visitor']


class GenericVisitor:
    """
    A generic visitor class to traverse the IR tree.

    To define handlers, subclasses should define :data:`visit_Foo`
    methods for each class :data:`Foo` they want to handle.
    If a specific method for a class :data:`Foo` is not found, the MRO
    of the class is walked in order until a matching method is found.
    The method signature is:

    .. code-block::

       def visit_Foo(self, o, [*args, **kwargs]):
           pass

    The handler is responsible for visiting the children (if any) of
    the node :data:`o`.  :data:`*args` and :data:`**kwargs` may be
    used to pass information up and down the call stack.  You can also
    pass named keyword arguments, e.g.:

    .. code-block::

        def visit_Foo(self, o, parent=None, *args, **kwargs):
            pass

    Handlers are matched by class *name*. A handler found by walking the
    MRO is cached per class object. As a result, distinct classes that
    merely share a name each resolve their own handler.
    """

    def __init__(self):
        handlers = {}
        # visit methods are spelt visit_Foo.
        prefix = "visit_"
        # Inspect the methods on this instance to find out which
        # handlers are defined.
        for (name, meth) in inspect.getmembers(self, predicate=inspect.ismethod):
            if not name.startswith(prefix):
                continue
            # Check the argument specification
            # Valid options are:
            #    visit_Foo(self, o, [*args, **kwargs])
            # getfullargspec still reports the bound ``self``, hence the ``< 2`` check
            argspec = inspect.getfullargspec(meth)
            if len(argspec.args) < 2:
                raise RuntimeError("Visit method signature must be "
                                   "visit_Foo(self, o, [*args, **kwargs])")
            handlers[name[len(prefix):]] = meth
        self._handlers = handlers
        # Cache of handlers resolved via the MRO. It is keyed by the class
        # object itself rather than its name, so unrelated classes that share
        # a ``__name__`` (e.g. from different modules) cannot pick up each
        # other's cached handler.
        self._mro_handlers = {}

    default_args = {}
    """
    Dict of default keyword arguments for the visitor. These are not used by
    default in :meth:`visit`, however, a caller may pass them explicitly to
    :meth:`visit` by accessing :attr:`default_args`. For example:

    .. code-block:: python

       v = FooVisitor()
       v.visit(node, **v.default_args)
    """

    @classmethod
    def default_retval(cls):
        """
        Default return value for handler methods.

        This method returns an object to use to populate return values.
        If your visitor combines values in a tree-walk, it may be useful to
        provide an object to combine the results into. :meth:`default_retval`
        may be defined by the visitor to be called to provide an empty object
        of appropriate type.

        Returns
        -------
        None
        """
        return None

    def lookup_method(self, instance):
        """
        Look up a handler method for a visitee.

        Parameters
        ----------
        instance :
            The object to look up a handler method for.

        Returns
        -------
        callable
            The bound ``visit_*`` handler named after the class of
            :data:`instance`. Failing that, the handler of the closest base
            class in its MRO.

        Raises
        ------
        RuntimeError
            If no class in the MRO has a matching handler.
        """
        cls = instance.__class__
        try:
            # Do we have a method handler defined for this type name?
            return self._handlers[cls.__name__]
        except KeyError:
            pass

        # Has this exact class already been resolved via its MRO?
        entry = self._mro_handlers.get(cls)
        if entry:
            return entry

        # No: walk the MRO. ``__mro__`` is used instead of ``cls.mro()``
        # because it also works when the visitee is itself a class, i.e.
        # when ``cls`` is a metaclass such as ``type``.
        for klass in cls.__mro__[1:]:
            entry = self._handlers.get(klass.__name__)
            if entry:
                # Save it for this class for faster lookup next time
                self._mro_handlers[cls] = entry
                return entry
        raise RuntimeError(f'No handler found for class {cls.__name__}')

    def visit(self, o, *args, **kwargs):
        """
        Apply this :class:`Visitor` to an IR tree.

        Parameters
        ----------
        o : :any:`Node`
            The node to visit.
        *args :
            Optional arguments to pass to the visit methods.
        **kwargs :
            Optional keyword arguments to pass to the visit methods.
        """
        meth = self.lookup_method(o)
        return meth(o, *args, **kwargs)

    def visit_object(self, o, **kwargs):  # pylint: disable=unused-argument
        """
        Fallback method for objects that do not match any handler.

        Parameters
        ----------
        o :
            The object to visit.
        **kwargs :
            Optional keyword arguments passed to the visit methods.

        Returns
        -------
        :py:meth:`GenericVisitor.default_retval`
            The default return value.
        """
        return self.default_retval()


class Visitor(GenericVisitor):
    """
    Basic visitor for traversing Loki's control flow tree.

    Builds on :class:`GenericVisitor` with handlers that recurse into
    tuples, lists and the children of every :any:`Node`. It also provides
    helpers to reuse or rebuild nodes from their operands.
    """

    def visit_tuple(self, o, **kwargs):
        """
        Visit every element of a tuple (or list) in order and collect the
        results into a tuple.
        """
        results = [self.visit(item, **kwargs) for item in o]
        return tuple(results)

    visit_list = visit_tuple

    def visit_Node(self, o, **kwargs):
        """
        Recurse into the children of a :any:`Node`.
        """
        children = o.children
        return self.visit(children, **kwargs)

    @staticmethod
    def reuse(o, *args, **kwargs):  # pylint: disable=unused-argument
        """Return the node itself, without visiting any of its children."""
        return o

    def maybe_rebuild(self, o, *args, **kwargs):
        """
        Visit the operands of :data:`o` and rebuild it only if at least one
        operand came back as a different object.
        """
        ops, okwargs = o.operands()
        visited = [self.visit(op, *args, **kwargs) for op in ops]
        if not any(old is not new for old, new in zip(ops, visited)):
            return o
        return o._rebuild(*visited, **okwargs)

    def always_rebuild(self, o, *args, **kwargs):
        """Visit the operands of :data:`o` and unconditionally rebuild it."""
        ops, okwargs = o.operands()
        visited = [self.visit(op, *args, **kwargs) for op in ops]
        return o._rebuild(*visited, **okwargs)
loki-ecmwf-0.3.6/loki/ir/pragma_utils.py0000664000175000017500000007067215167130205020366 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import re
from collections import defaultdict
from contextlib import contextmanager
from codetiming import Timer

#from loki.backend.fgen import fgen
from loki.ir.nodes import VariableDeclaration, Pragma, PragmaRegion
from loki.ir.find import FindNodes
from loki.ir.transformer import Transformer
from loki.ir.visitor import Visitor
from loki.tools.util import as_tuple, replace_windowed
from loki.logging import debug, warning


# Public names exported by ``from ... import *``. Helpers such as
# ``get_matching_region_pragmas`` and the ``PragmaRegion(At|De)tacher``
# transformers are not listed here.
__all__ = [
    'is_loki_pragma', 'get_pragma_parameters', 'process_dimension_pragmas',
    'attach_pragmas', 'detach_pragmas',
    'pragmas_attached', 'attach_pragma_regions', 'detach_pragma_regions',
    'pragma_regions_attached', 'PragmaAttacher', 'PragmaDetacher',
    'get_pragma_command_and_parameters', 'SubstitutePragmaStrings'
]


def is_loki_pragma(pragma, starts_with=None):
    """
    Test whether the given pragma(s) include a ``loki`` pragma and,
    optionally, one whose content begins with a specific string.

    Parameters
    ----------
    pragma : :any:`Pragma` or `list`/`tuple` of :any:`Pragma` or `None`
        the pragma or list of pragmas to check.
    starts_with : str, optional
        the keyword the pragma content must start with (compared
        case-sensitively).

    Returns
    -------
    bool
        `True` if at least one pragma has the keyword ``loki`` (compared
        case-insensitively) and, when :data:`starts_with` is given, at least
        one of those pragmas has content starting with it.
    """
    loki_pragmas = [p for p in as_tuple(pragma) if p.keyword.lower() == 'loki']
    if not loki_pragmas:
        return False
    if starts_with is None:
        return True
    return any(p.content and p.content.startswith(starts_with) for p in loki_pragmas)


class PragmaParameters:
    """
    Utility class to parse strings for parameters in the form
    ``<param>[(<arguments>)]`` and return them as a map
    ``{<param>: <arguments> or None}``.
    """

    _pattern_opening_parenthesis = re.compile(r'\(')
    _pattern_closing_parenthesis = re.compile(r'\)')
    _pattern_quoted_string = re.compile(r'(?:\'.*?\')|(?:".*?")')

    @classmethod
    def find(cls, string):
        """
        Find parameters in the form ``<param>[(<arguments>)]`` and
        return them as a map ``{<param>: <arguments> or None}``.

        Parameter names are separated by any whitespace. Quoted strings are
        removed before parsing. A name that occurs more than once maps to a
        list of its values, in order of appearance.

        .. note::
            This allows nested parenthesis by matching pairs of
            parentheses starting at the end by pushing and popping
            from a stack.
        """
        string = cls._pattern_quoted_string.sub('', string)
        if not string.strip():
            # Early bail-out on empty strings
            return {}

        p_open = [match.start() for match in cls._pattern_opening_parenthesis.finditer(string)]
        p_close = [match.start() for match in cls._pattern_closing_parenthesis.finditer(string)]
        assert len(p_open) == len(p_close)

        def _match_spans(open_, close_):
            # We match pairs of parentheses starting at the end by pushing and popping from a stack.
            # Whenever the stack runs out, we have fully resolved a set of (nested) parenthesis and
            # record the corresponding span
            if not close_:
                return []
            spans = []
            stack = [close_.pop()]
            while open_:
                if not close_ or open_[-1] > close_[-1]:
                    assert stack
                    start = open_.pop()
                    end = stack.pop()
                    if not stack:
                        spans.append((start, end))
                else:
                    stack.append(close_.pop())
            assert not (stack or open_ or close_)
            return spans

        # Spans are recorded from the end of the string backwards; reverse them
        # to process parameters left-to-right
        spans = _match_spans(p_open, p_close)[::-1]

        # Build the list of parameters from the matched spans
        parameters = defaultdict(list)
        prev_end = 0
        for start, end in spans:
            # Words between the previous closing and this opening parenthesis:
            # all but the last are value-less flags, the last one owns the value.
            # ``or ['']`` keeps a bare "(...)" mapped to the empty key.
            keys = string[prev_end:start].split() or ['']
            for key in keys[:-1]:
                parameters[key].append(None)
            parameters[keys[-1]].append(string[start+1:end])
            prev_end = end + 1

        # Tail handling (including strings without any matched spans)
        for key in string[prev_end:].split():
            parameters[key].append(None)
        return {k: v if len(v) > 1 else v[0] for k, v in parameters.items()}

def get_pragma_command_and_parameters(pragma, only_loki_pragmas=True):
    """
    Parse a pragma in the form

    ``!$loki [end] command [param] [param_with_val(val)]``

    and return ``command, {param: None, param_with_val: val}``

    A leading ``end`` or ``exit`` token is merged with the token that
    follows it, e.g. ``!$loki end foo`` yields the command ``'end-foo'``.

    Parameters
    ----------
    pragma : :any:`Pragma` or `list`/`tuple` of :any:`Pragma`
        the pragma to parse and return parameters
    only_loki_pragmas : bool, optional
        restrict parsing to ``loki`` pragmas only (default: `True`)

    Returns
    -------
    tuple(str, dict) :
        tuple being (command, parameters as dictionary), or ``(None, None)``
        if no command could be extracted
    """
    pragma_parameters = list(get_pragma_parameters(pragma, only_loki_pragmas=only_loki_pragmas).items())
    if not pragma_parameters:
        return None, None
    if pragma_parameters[0][0] in ['end', 'exit']:
        if len(pragma_parameters) < 2:
            debug(f'get_pragma_command_and_parameters: Failed to match end-command in pragma {pragma}')
            return None, None
        # Merge "end"/"exit" with the following token into a single command name
        pragma_parameters = [
            (f'{pragma_parameters[0][0]}-{pragma_parameters[1][0]}', None),
            *pragma_parameters[2:]
        ]
    return pragma_parameters[0][0], dict(pragma_parameters[1:])

def get_pragma_parameters(pragma, starts_with=None, only_loki_pragmas=True):
    """
    Parse the pragma content for parameters in the form
    ``<param>[(<arguments>)]`` and return them as a map
    ``{<param>: <arguments> or None}``.

    Optionally, look only at the pragma with the given keyword at the beginning.
    That prefix is matched case-insensitively and stripped before parsing.
    Line-continuation markers (``&``) are removed from the content.

    If multiple pragmas are given as a tuple/list, the parameters of all of
    them are collected. A parameter name found in more than one pragma maps
    to a list of the values collected for it, in order, instead of a single
    value. A pragma that itself repeats a name contributes its own list of
    values.

    Parameters
    ----------
    pragma : :any:`Pragma` or `list`/`tuple` of :any:`Pragma` or `None`
        the pragma or list of pragmas to check.
    starts_with : str, optional
        the keyword the pragma content should start with.
    only_loki_pragmas : bool, optional
        restrict parameter extraction to ``loki`` pragmas only.

    Returns
    -------
    dict :
        Mapping of parameters ``{<param>: <arguments> or None}`` with the
        values being a list when multiple entries have the same key
    """
    parameters = defaultdict(list)
    for p in as_tuple(pragma):
        if only_loki_pragmas and p.keyword.lower() != 'loki':
            continue
        # Remove any line-continuation markers
        content = (p.content or '').replace('&', '')
        if starts_with is not None:
            if not content.lower().startswith(starts_with.lower()):
                continue
            content = content[len(starts_with):]
        # ``find`` is a classmethod, so no instance is needed
        for key, value in PragmaParameters.find(content).items():
            parameters[key].append(value)
    return {k: v if len(v) > 1 else v[0] for k, v in parameters.items()}

def process_dimension_pragmas(ir, scope=None):
    """
    Process any ``!$loki dimension`` pragmas to override deferred dimensions

    Note that this assumes :any:`attach_pragmas` has been run on :data:`ir` to
    attach any pragmas to the :any:`VariableDeclaration` nodes.

    For every declaration carrying such a pragma, the comma-separated entries
    of ``dimension(...)`` are parsed as expressions. The result is stored as
    the ``shape`` of each declared symbol's type in the symbol table of that
    symbol's scope.

    Parameters
    ----------
    ir : :any:`Node`
        Root node of the (section of the) internal representation to process
    scope : optional
        Scope passed to the expression parser when parsing the dimension
        expressions

    Returns
    -------
    The same :data:`ir` object; symbol table entries are updated in place.
    """
    from loki.expression.parser import parse_expr  # pylint: disable=import-outside-toplevel

    for decl in FindNodes(VariableDeclaration).visit(ir):
        if not is_loki_pragma(decl.pragma, starts_with='dimension') or not decl.symbols:
            continue
        # The dimension override belongs to the whole declaration, so extract
        # the dimension strings once rather than once per declared symbol
        dims = get_pragma_parameters(decl.pragma)['dimension']
        dims = [d.strip() for d in dims.split(',')]
        for v in decl.symbols:
            # Parse each dimension for this symbol
            shape = tuple(parse_expr(d, scope=scope) for d in dims)
            # update symbol table
            v.scope.symbol_attrs[v.name] = v.type.clone(shape=shape)
    return ir


class PragmaAttacher(Visitor):
    """
    Utility visitor that finds pragmas preceding (or optionally also
    trailing) nodes of given types and attaches them to these nodes as
    ``pragma`` property.

    Note that this operates by updating (instead of rebuilding) the relevant
    nodes, thus only nodes to which pragmas are attached get modified and
    the tree as a whole is not modified if no pragmas are found. This means
    existing node references should remain valid.

    Trailing pragmas are attached as ``pragma_post`` only to nodes of the
    given type that have a ``pragma_post`` attribute. Otherwise they stay in
    the IR as standalone :any:`Pragma` nodes.

    .. note::
        When using :data:`attach_pragma_post` and two nodes qualifying according to
        :data:`node_type` are separated only by :any:`Pragma` nodes in between, it
        is not possible to decide to which node these pragmas belong. In such cases,
        they are attached to the second node, as the ``pragma`` property takes
        precedence. Such situations can only be resolved by full knowledge about
        the pragma language specification (_way_ out of scope) or modifying the
        original source, e.g. by inserting a comment between the relevant pragmas.

    Parameters
    ----------
    node_type :
        the IR node type (or a list of them) to attach pragmas to.
    attach_pragma_post : bool, optional
        look for pragmas after the node, too, and attach as ``pragma_post`` if applicable.

    """

    def __init__(self, node_type, attach_pragma_post=True):
        super().__init__()
        self.node_type = as_tuple(node_type)
        self.attach_pragma_post = attach_pragma_post

    def _takes_pragma_post(self, updated):
        """ Check whether the last node in ``updated`` may receive ``pragma_post`` """
        return (
            self.attach_pragma_post and bool(updated) and
            isinstance(updated[-1], self.node_type) and
            hasattr(updated[-1], 'pragma_post')
        )

    def visit_tuple(self, o, **kwargs):
        """
        Visit a tuple/list of nodes, collecting runs of :any:`Pragma` nodes.

        Each run is attached to the following node of matching type (as
        ``pragma``) or to the preceding one (as ``pragma_post``). Pragmas
        that cannot be attached are kept in place.
        """
        pragmas = []
        updated = []
        for i in o:
            if isinstance(i, Pragma):
                # Collect pragmas, anticipating a possible node to attach to
                pragmas += [i]
            else:
                # Recurse first
                i = self.visit(i, **kwargs)
                if pragmas:
                    if isinstance(i, self.node_type):
                        # Found a node of given type: attach pragmas
                        i._update(pragma=as_tuple(pragmas))
                    elif self._takes_pragma_post(updated):
                        # Encountered a different node but have some pragmas: attach to last
                        # node as pragma_post if type matches
                        updated[-1]._update(pragma_post=as_tuple(pragmas))
                    else:
                        # Not attaching pragmas anywhere: re-insert into list
                        updated += pragmas
                    pragmas = []
                updated += [i]
        if pragmas and self._takes_pragma_post(updated):
            # Take care of leftover pragmas, subject to the same checks as above
            updated[-1]._update(pragma_post=as_tuple(pragmas))
            pragmas = []
        return as_tuple(updated + pragmas)

    visit_list = visit_tuple

    def visit_Node(self, o, **kwargs):
        children = tuple(self.visit(i, **kwargs) for i in o.children)
        # Modify the node in-place instead of rebuilding it to leave existing references
        # to IR nodes intact
        o._update(*children)
        return o

    def visit_object(self, o, **kwargs):
        # Any other objects (e.g., expression trees) are to be left untouched
        return o


class PragmaDetacher(Visitor):
    """
    Utility visitor that moves pragmas attached to nodes of given types back
    into the enclosing tuple/list of the IR. The ``pragma`` is placed before
    the node and, optionally, the ``pragma_post`` after it.

    Nodes are updated in place rather than rebuilt. Only nodes carrying
    attached pragmas are modified, and the tree as a whole is untouched if
    there are none, so existing references to IR nodes remain valid.

    Parameters
    ----------
    node_type :
        the IR node type (or a list of them) to detach pragmas from.
    detach_pragma_post : bool, optional
        detach ``pragma_post`` properties, if applicable.
    """

    def __init__(self, node_type, detach_pragma_post=False):
        super().__init__()
        self.node_type = as_tuple(node_type)
        self.detach_pragma_post = detach_pragma_post

    def visit_tuple(self, o, **kwargs):
        """
        Visit each element and splice any detached pragmas around it.
        """
        flat = []
        for item in o:
            item = self.visit(item, **kwargs)
            is_target = isinstance(item, self.node_type)
            if is_target and getattr(item, 'pragma', None):
                # Leading pragmas are re-inserted in front of the node,
                # which is then updated in place to drop them
                flat.extend(as_tuple(item.pragma))
                item._update(pragma=None)
            flat.append(item)
            if self.detach_pragma_post and is_target and getattr(item, 'pragma_post', None):
                # Trailing pragmas are re-inserted right after the node
                flat.extend(as_tuple(item.pragma_post))
                item._update(pragma_post=None)
        return tuple(flat)

    visit_list = visit_tuple

    def visit_Node(self, o, **kwargs):
        # Update in place instead of rebuilding, so references held elsewhere stay valid
        o._update(*[self.visit(child, **kwargs) for child in o.children])
        return o

    def visit_object(self, o, **kwargs):
        # Anything else (e.g. expression trees) is returned untouched
        return o


def attach_pragmas(ir, node_type, attach_pragma_post=True):
    """
    Find pragmas and merge them onto the given node type(s).

    Applicable to every IR node with a ``pragma`` property
    (:any:`VariableDeclaration`, :any:`Loop`, :any:`WhileLoop`,
    :any:`CallStatement`). Attaching pragmas that follow a node as
    ``pragma_post`` (relevant only for :any:`Loop` and :any:`WhileLoop`)
    can be switched off with :data:`attach_pragma_post`.

    .. note::
        Pragmas are not discovered by :any:`FindNodes` while attached to IR nodes.

    Implemented with :any:`PragmaAttacher`, so the IR is updated rather
    than rebuilt and existing references should remain valid.

    Parameters
    ----------
    ir : :any:`Node`
        the root of (a section of the) intermediate representation in which
        pragmas are to be attached.
    node_type : list
        the (list of) :any:`Node` types pragmas should be attached to.
    attach_pragma_post : bool, optional
        process ``pragma_post`` attachments.
    """
    attacher = PragmaAttacher(node_type, attach_pragma_post=attach_pragma_post)
    return attacher.visit(ir)


def detach_pragmas(ir, node_type, detach_pragma_post=True):
    """
    Revert the inlining of pragmas, e.g. as done by :any:`attach_pragmas`.

    This can be done for all IR nodes that have a ``pragma`` property
    (:any:`VariableDeclaration`, :any:`Loop`, :any:`WhileLoop`,
    :any:`CallStatement`).
    Optionally, detaching of pragmas after nodes (for nodes with a
    ``pragma_post`` property) can be disabled by setting
    :data:`detach_pragma_post` to `False` (relevant only for :any:`Loop`
    and :any:`WhileLoop`).

    This is implemented using :any:`PragmaDetacher`. Therefore, the IR
    is not rebuilt but updated and existing references should remain valid.

    Parameters
    ----------
    ir : :any:`Node`
        the root node of the (section of the) intermediate representation
        in which pragmas are to be detached.
    node_type :
        the (list of) :any:`Node` types that pragmas should be detached from.
    detach_pragma_post: bool, optional
        process ``pragma_post`` attachments.

    Returns
    -------
    The updated IR. This is a new tuple if :data:`ir` is a tuple/list, and
    otherwise the same node updated in place.
    """
    return PragmaDetacher(node_type, detach_pragma_post=detach_pragma_post).visit(ir)


@contextmanager
def pragmas_attached(module_or_routine, node_type, attach_pragma_post=True):
    """
    Context manager that attaches pragmas to the nodes of given type(s)
    that they precede, inside the ``spec`` and ``body`` of a module or
    routine.

    Applicable to every IR node with a ``pragma`` property
    (:any:`VariableDeclaration`, :any:`ProcedureDeclaration`, :any:`Loop`,
    :any:`WhileLoop`, :any:`CallStatement`). Within the context, attached
    pragmas are not standalone IR nodes anymore. They are reachable only
    through the ``pragma`` property of the node they belong to.

    Pragmas that follow a node are attached as ``pragma_post``, which can
    be disabled by setting :data:`attach_pragma_post` to `False` (for
    :any:`Loop` and :any:`WhileLoop`).

    .. note::
        Pragmas are not discovered by :any:`FindNodes` while attached to IR nodes.

    On leaving the context, all pragmas of nodes of the given type are
    detached again, whether or not they were already attached when the
    context was entered.

    .. note::
        Pragma attachment is only done for the object itself (i.e. its spec and
        body), not for any contained subroutines.

    Both directions are implemented with :any:`PragmaAttacher` and
    :any:`PragmaDetacher`, which update rather than rebuild the IR. Node
    references therefore remain valid on entering the context and beyond
    exiting it.

    Example:

    .. code-block:: python

        loop_of_interest = None
        with pragmas_attached(routine, Loop):
            for loop in FindNodes(Loop).visit(routine.body):
                if is_loki_pragma(loop.pragma, starts_with='foobar'):
                    loop_of_interest = loop
                    break
        # Do something with that loop
        loop_body = loop_of_interest.body
        # Note that loop_body.pragma == None!

    Parameters
    ----------
    module_or_routine : :any:`Module` or :any:`Subroutine`
        the program unit in which pragmas are to be inlined.
    node_type :
        the (list of) :any:`Node` types, that pragmas should be
        attached to.
    attach_pragma_post : bool, optional
        process ``pragma_post`` attachments.
    """
    sections = ('spec', 'body')
    for section in sections:
        if hasattr(module_or_routine, section):
            ir = getattr(module_or_routine, section)
            setattr(module_or_routine, section,
                    attach_pragmas(ir, node_type, attach_pragma_post=attach_pragma_post))
    try:
        yield module_or_routine
    finally:
        # Mirror the entry step, reusing the same ``pragma_post`` choice
        for section in sections:
            if hasattr(module_or_routine, section):
                ir = getattr(module_or_routine, section)
                setattr(module_or_routine, section,
                        detach_pragmas(ir, node_type, detach_pragma_post=attach_pragma_post))


def get_matching_region_pragmas(pragmas):
    """
    Given a list of :any:`Pragma` objects return a list of matching pairs
    that define a pragma region.

    Matching pragma pairs are assumed to be of the form
    ``!$<keyword> <marker>`` and ``!$<keyword> end <marker>``. Keywords and
    content are compared case-insensitively, and content is split into words
    on any whitespace.

    Parameters
    ----------
    pragmas : list of :any:`Pragma`
        The pragmas to pair up, in source order.

    Returns
    -------
    list of tuple
        ``(start, end)`` pragma pairs. For properly nested regions, inner
        pairs are listed before the pairs enclosing them.
    """

    def _tokens(p):
        """ Lower-cased, whitespace-separated words of the pragma content """
        return (p.content or '').lower().split()

    def _matches_starting_pragma(start, p):
        """ Check whether ``p`` is an ``end`` pragma closing ``start`` """
        stok = _tokens(start)
        ptok = _tokens(p)
        if 'end' not in ptok:
            return False
        if start.keyword.lower() != p.keyword.lower():
            return False
        idx = ptok.index('end')
        # The word after ``end`` must equal the word at the same position in
        # the starting pragma. Guard against ``end`` being the last word, or
        # the starting pragma being too short.
        if idx + 1 >= len(ptok) or idx >= len(stok):
            return False
        return ptok[idx+1] == stok[idx]

    matches = []
    stack = []
    for i, p in enumerate(pragmas):
        if 'end' not in _tokens(p):
            # A potential region start: stack it only if some later pragma closes it
            if any(_matches_starting_pragma(p, p2) for p2 in pragmas[i:]):
                stack.append(p)
        elif stack and _matches_starting_pragma(stack[-1], p):
            # An end pragma closing the innermost open region: record the pair
            matches.append((stack.pop(), p))

    return matches


class PragmaRegionAttacher(Transformer):
    """
    Utility transformer that inserts :any:`PragmaRegion` objects to
    mark code section between matching :any:`Pragma` pairs.

    Matching pragma pairs are assumed to be of the form
    ``!$<keyword> <marker>`` and ``!$<keyword> end <marker>``.

    The matching of pragma pairs only happens if the matching pragmas
    are stored within the same tuple, or in other words at the same
    depth of the IR tree. Ending a pragma region in a different
    nesting depth, eg. inside a loop body, will result in a warning
    and no region object being inserted into the IR tree.

    Parameters
    ----------
    pragma_pairs : tuple of tuple of :any:`Pragma`, optional
        Tuple of 2-tuples of matching pragma pairs. `None` (the default)
        means no pairs.
    """

    def __init__(self, pragma_pairs=None, **kwargs):
        # Normalise to a tuple: the ``None`` default then means "no pairs"
        # instead of failing when iterated. A one-shot iterator is also not
        # exhausted after the first visited tuple.
        self.pragma_pairs = tuple(pragma_pairs) if pragma_pairs is not None else ()

        super().__init__(**kwargs)

    def visit_tuple(self, o, **kwargs):
        """ Replace pragma-body-end in tuples """
        for start, stop in self.pragma_pairs:
            if start in o:
                # If a pair does not live in the same tuple we have a problem.
                if stop not in o:
                    warning(f'[Loki::IR] Cannot find matching end for pragma {start} at same IR level!')
                    continue

                # Create the PragmaRegion node and replace in tuple
                idx_start = o.index(start)
                idx_stop = o.index(stop)
                region = PragmaRegion(
                    body=o[idx_start+1:idx_stop], pragma=start, pragma_post=stop
                )
                o = o[:idx_start] + (region,) + o[idx_stop+1:]

        # Then recurse over the new nodes
        visited = tuple(self.visit(i, **kwargs) for i in o)

        # Strip empty sublists/subtuples or None entries
        return tuple(i for i in visited if i is not None and as_tuple(i))

    visit_list = visit_tuple


@Timer(logger=debug, text=lambda s: f'[Loki::IR] Executed attach_pragma_regions in {s:.2f}s')
def attach_pragma_regions(ir, keyword=None):
    """
    Wrap every matching pair of region pragmas in a :any:`PragmaRegion` node.

    Matching pragma pairs are assumed to be of the form
    ``!$<keyword> <marker>`` and ``!$<keyword> end <marker>``.

    The delimiting :any:`Pragma` nodes become the ``pragma`` and
    ``pragma_post`` attributes of the region object. Insertion happens
    in place, without rebuilding any IR nodes.

    Parameters
    ----------
    ir : :any:`Node`
        The IR root node of the tree in which pragma regions are to be formed
    keyword : str, optional
        Only create pragma regions for pragmas with the given keyword
        (compared case-insensitively)
    """
    candidates = FindNodes(Pragma).visit(ir)
    if keyword:
        wanted = keyword.lower()
        candidates = [p for p in candidates if p.keyword.lower() == wanted]

    attacher = PragmaRegionAttacher(pragma_pairs=get_matching_region_pragmas(candidates), inplace=True)
    return attacher.visit(ir)


class PragmaRegionDetacher(Transformer):
    """
    Remove any :any:`PragmaRegion` node objects and insert the tuple
    of ``(r.pragma, r.body, r.pragma_post)`` in the enclosing tuple.

    Regions are unpacked in :meth:`visit_tuple`, i.e. only where a region is
    an element of a tuple/list in the IR.
    """

    def visit_tuple(self, o, **kwargs):
        """
        Unpack :any:`PragmaRegion` objects and insert in current tuple.

        The result has ``None`` entries and empty sub-tuples/lists stripped.
        """

        # We unpack regions here to avoid creating nested tuples, or
        # forcing general tuple-flattening, which can affect other
        # nodes types.
        regions = tuple(n for n in o if isinstance(n, PragmaRegion))
        for r in regions:
            # The tuple concatenation relies on visiting ``r.body`` returning a
            # tuple, which holds when the body is a tuple/list (see visit_tuple)
            handle = (r.pragma,) + self.visit(r.body, **kwargs) + (r.pragma_post,)
            o = replace_windowed(o, r, subs=handle)

        # Then recurse over the new nodes.
        # NOTE(review): nodes unpacked from region bodies above were already
        # visited via ``self.visit(r.body)`` and are visited a second time here.
        # With nested regions this repeats at every nesting level; confirm
        # this is intended.
        visited = tuple(self.visit(i, **kwargs) for i in o)

        # Strip empty sublists/subtuples or None entries
        return tuple(i for i in visited if i is not None and as_tuple(i))

    visit_list = visit_tuple


@Timer(logger=debug, text=lambda s: f'[Loki::IR] Executed detach_pragma_regions in {s:.2f}s')
def detach_pragma_regions(ir):
    """
    Replace every :any:`PragmaRegion` node object ``r`` with the sequence
    ``(r.pragma, r.body, r.pragma_post)`` in its enclosing tuple.

    Replacements happen in place; no IR nodes are rebuilt.
    """
    detacher = PragmaRegionDetacher(inplace=True)
    return detacher.visit(ir)


@contextmanager
def pragma_regions_attached(module_or_routine, keyword=None):
    """
    Create a context in which :any:`PragmaRegion` node objects are
    inserted into the IR to define code regions marked by matching
    pairs of pragmas.

    Matching pragma pairs are assumed to be of the form
    ``!$KEYWORD MARKER`` and ``!$KEYWORD end MARKER``, for example
    ``!$loki foobar`` and ``!$loki end foobar``.

    In the resulting context ``FindNodes(PragmaRegion).visit(ir)`` can
    be used to select code regions marked by pragma pairs as node
    objects.

    The defining :any:`Pragma` nodes are accessible via the ``pragma``
    and ``pragma_post`` attributes of the region object. Importantly,
    Pragmas are not discovered by :any:`FindNodes` while attached
    to IR nodes.

    When leaving the context all :any:`PragmaRegion` objects are replaced
    with a tuple of ``(r.pragma, r.body, r.pragma_post)``, where ``r``
    is the :any:`PragmaRegion` node object. This also happens if an
    exception is raised inside the context, or while the regions are
    being attached.

    Throughout the setup and teardown of the context IR nodes are only
    updated, never rebuilt, meaning node mappings from inside the
    context are valid outside of it.

    Example:

    .. code-block:: python

        with pragma_regions_attached(routine):
            for region in FindNodes(PragmaRegion).visit(routine.body):
                if is_loki_pragma(region.pragma, starts_with='foobar'):
                    ...  # transform the code in region.body

    Parameters
    ----------
    module_or_routine : :any:`Module` or :any:`Subroutine`
        The container in whose ``spec`` and/or ``body`` the
        :any:`PragmaRegion` objects are to be inserted.
    keyword : str, optional
        Limit pragma attachment to pragmas with the given keyword
    """
    try:
        # Attach inside the ``try`` so that a failure while attaching to
        # ``body`` still detaches any regions already attached to ``spec``
        if hasattr(module_or_routine, 'spec'):
            module_or_routine.spec = attach_pragma_regions(module_or_routine.spec, keyword=keyword)
        if hasattr(module_or_routine, 'body'):
            module_or_routine.body = attach_pragma_regions(module_or_routine.body, keyword=keyword)

        yield module_or_routine
    finally:
        if hasattr(module_or_routine, 'spec'):
            module_or_routine.spec = detach_pragma_regions(module_or_routine.spec)
        if hasattr(module_or_routine, 'body'):
            module_or_routine.body = detach_pragma_regions(module_or_routine.body)


class SubstitutePragmaStrings(Transformer):
    """
    A :any:`Transformer` that updates the content of a :any:`Pragma`
    using the provided string map. The string search and replace is
    based on literal matching and does not support regex patterns.

    Matching is case-insensitive. Before the substitutions are applied,
    all ``&`` characters are removed from the pragma content, so that
    search strings can match text that was split across continuation
    lines.

    Parameters
    ----------
    str_map : dict
        Mapping of search strings to replacement strings. Both keys and
        values are treated as literal text.
    """

    def __init__(self, str_map):
        super().__init__(inplace=True)

        # Remove continuation markers for substitutions across lines.
        # NOTE: the pattern text '&' deliberately differs from
        # ``re.escape('&')`` so that a user-supplied '&' key cannot collide
        # with (and overwrite) this entry via the ``re.compile`` cache.
        self.str_map = {re.compile('&', flags=re.IGNORECASE): ''}

        # ``re.escape`` escapes every regex metacharacter (including '\\',
        # '^', '$', '{', '}' and '|'), guaranteeing a literal search
        for key, value in str_map.items():
            self.str_map[re.compile(re.escape(key), flags=re.IGNORECASE)] = value

    def visit_Pragma(self, o, **kwargs):
        """
        Update the content of a pragma using the given str map.
        """
        #pylint: disable=unused-argument

        _content = o.content
        for pattern, replacement in self.str_map.items():
            # A callable replacement inserts ``replacement`` verbatim;
            # passing the string directly would make ``re.sub`` interpret
            # backslashes in it as escapes or group references
            _content = pattern.sub(lambda _match, _repl=replacement: _repl, _content)

        o._update(content=_content)
        return o
loki-ecmwf-0.3.6/loki/ir/ir_graph.py0000664000175000017500000003046015167130205017461 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
GraphCollector classes for IR
"""

from itertools import chain
from codetiming import Timer

try:
    from graphviz import Digraph, nohtml

    HAVE_IR_GRAPH = True
    """Indicate whether the graphviz package is available."""
except ImportError:
    HAVE_IR_GRAPH = False

from loki.tools import JoinableStringList, is_iterable, as_tuple
from loki.ir.visitor import Visitor


__all__ = ["HAVE_IR_GRAPH", "GraphCollector", "ir_graph"]


class GraphCollector(Visitor):
    """
    Convert a given IR tree to a node and edge list via the visit mechanism.

    This serves as base class for backends and provides a number of helpful
    routines that ease implementing automatic recursion and line wrapping.
    It is adapted from the Stringifier in "pprint.py". It doubles as a means
    to produce a human readable graph representation of the IR, which is
    useful for debugging purposes and first visualization.

    All ``visit_*`` handlers return a list of ``(node_info, edge_info)``
    tuples. ``node_info`` holds keyword arguments for
    :meth:`graphviz.Digraph.node`, and ``edge_info`` holds keyword arguments
    for :meth:`graphviz.Digraph.edge`. ``edge_info`` is an empty dict if the
    node has no parent.

    Parameters
    ----------
    show_comments : bool, optional, default: False
        Whether to show comments in the output
    show_expressions : bool, optional, default: False
        Whether to further expand expressions in the output
    linewidth : int, optional, default: 40
        The line width limit after which to break a line.
    symgen : callable, optional, default: str
        A function handle that accepts a :any:`pymbolic.primitives.Expression`
        and produces a string representation for that.
    """

    def __init__(
        self, show_comments=False, show_expressions=False, linewidth=40, symgen=str
    ):
        super().__init__()
        self.linewidth = linewidth
        self._symgen = symgen
        # Counter and ``str(id(obj)) -> graph node name`` map, so that each
        # visited object is assigned one graph node name
        self._id = 0
        self._id_map = {}
        self.show_comments = show_comments
        self.show_expressions = show_expressions

    @property
    def symgen(self):
        """
        Formatter for expressions.
        """
        return self._symgen

    def join_items(self, items, sep=", ", separable=True):
        """
        Concatenate a list of items into :any:`JoinableStringList`.

        The return value can be passed to :meth:`format_line` or
        :meth:`format_node` or converted to a string with `str`, using
        the :any:`JoinableStringList` as an argument.
        Upon expansion, lines will be wrapped automatically to stay within
        the linewidth limit.

        Parameters
        ----------
        items : list
            The list of strings to be joined.
        sep : str, optional
            The separator to be inserted between items.
        separable : bool, optional
            Allow line breaks between individual :data:`items`.

        Returns
        -------
        :any:`JoinableStringList`
        """
        return JoinableStringList(
            items,
            sep=sep,
            width=self.linewidth,
            cont="\n",
            separable=separable,
        )

    def format_node(self, name, *items):
        """
        Default format for a node.

        Creates a string that encloses ``name`` in angle brackets. If any
        ``items`` are given, they follow ``name`` as a comma-separated list.
        All double quotes are removed from the result.
        """
        if items:
            content = self.format_line("<", name, " ", self.join_items(items), ">")
        else:
            content = self.format_line("<", name, ">")

        # disregard all quotes to ensure nice graphviz label behaviour
        return content.replace('"', "")

    def format_line(self, *items, comment=None, no_wrap=False):
        """
        Format a line by concatenating all items while observing the
        allowed line width limit.

        Note that the provided comment will simply be extended to the line and no line
        width limit will be enforced for that.

        Parameters
        ----------
        items : list
            The items to be put on that line.
        comment : str
            An optional inline comment to be put at the end of the line.
        no_wrap : bool
            Disable line wrapping.

        Returns
        -------
        str
            The string of the current line, potentially including line breaks if
            required to observe the line width limit.
        """

        if no_wrap:
            # Simply concatenate items and extend the comment
            line = "".join(str(item) for item in items)
        else:
            # Use join_items to concatenate items
            line = str(self.join_items(items, sep=""))
        if comment:
            return line + comment
        return line

    def visit_all(self, item, *args, **kwargs):
        """
        Convenience function to call :meth:`visit` for all given arguments.

        If only a single argument is given that is iterable,
        :meth:`visit` is called on all of its elements instead.
        ``None`` entries are skipped.

        Returns
        -------
        list
            The concatenation of the lists returned by the individual
            :meth:`visit` calls. The result is always a list (never a lazy
            iterator), so it can be consumed more than once.
        """
        if is_iterable(item) and not args:
            targets = item
        else:
            targets = [item, *args]
        return list(
            chain.from_iterable(
                as_tuple(self.visit(i, **kwargs) for i in targets if i is not None)
            )
        )

    def __add_node(self, node, **kwargs):
        """
        Adds a node to the graphical representation of the IR. Utilizes the
        formatting provided by :meth:`format_node`.

        If the node provides ``live_symbols``, ``defines_symbols`` and
        ``uses_symbols`` (dataflow information), these are appended to the
        label. Otherwise they are silently omitted.

        Parameters
        ----------
        node: :any: `Node` object
        kwargs["shape"]: str, optional (default: "oval")
        kwargs["label"]: str, optional (default: format_node(repr(node)))
        kwargs["parent"]: :any: `Node` object, optional (default: None)
            If not available no edge is drawn.

        Returns
        -------
        list[tuple[dict[str, str], dict[str, str]]]
            A single-element list holding the node information and the edge
            information (an empty dict if no parent was given).
        """
        label = kwargs.get("label", "")

        if label == "":
            label = self.format_node(repr(node))

        try:
            live_symbols = "live: [" + ", ".join(
                str(symbol) for symbol in node.live_symbols
            )
            defines_symbols = "defines: [" + ", ".join(
                str(symbol) for symbol in node.defines_symbols
            )
            uses_symbols = "uses: [" + ", ".join(
                str(symbol) for symbol in node.uses_symbols
            )
            label = self.format_line(
                label,
                "\n",
                live_symbols,
                "], ",
                defines_symbols,
                "], ",
                uses_symbols,
                "]",
            )
        except (RuntimeError, KeyError, AttributeError) as _:
            # Best effort: dataflow information is optional
            pass

        shape = kwargs.get("shape", "oval")

        node_key = str(id(node))
        if node_key not in self._id_map:
            self._id_map[node_key] = str(self._id)
            self._id += 1

        node_info = {
            "name": str(self._id_map[node_key]),
            "label": nohtml(str(label)),
            "shape": str(shape),
        }

        parent = kwargs.get("parent")
        edge_info = {}
        if parent:
            # The parent was registered before its children were visited
            parent_id = self._id_map[str(id(parent))]
            child_id = self._id_map[str(id(node))]
            edge_info = {"tail_name": str(parent_id), "head_name": str(child_id)}

        return [(node_info, edge_info)]

    # Handler for outer objects
    def visit_Module(self, o, **kwargs):
        """
        Add a :any:`Module`, mark parent node and visit all "spec" and "contains" nodes.

        Returns
        -------
        list[tuple[dict[str, str], dict[str, str]]]
            An extended list of tuples of a node and potentially a edge information
        """
        node_edge_info = self.__add_node(o, **kwargs)
        kwargs["parent"] = o

        node_edge_info.extend(self.visit(o.spec, **kwargs))
        node_edge_info.extend(self.visit_all(o.contains, **kwargs))

        return node_edge_info

    def visit_Subroutine(self, o, **kwargs):
        """
        Add a :any:`Subroutine`, mark parent node and visit all "docstring", "spec", "body", "contains" nodes.

        Returns
        -------
        list[tuple[dict[str, str], dict[str, str]]]
            An extended list of tuples of a node and potentially a edge information
        """
        node_edge_info = self.__add_node(o, **kwargs)
        kwargs["parent"] = o

        node_edge_info.extend(self.visit(o.docstring, **kwargs))
        node_edge_info.extend(self.visit(o.spec, **kwargs))
        node_edge_info.extend(self.visit(o.body, **kwargs))
        node_edge_info.extend(self.visit_all(o.contains, **kwargs))

        return node_edge_info

    # Handler for AST base nodes
    def visit_Comment(self, o, **kwargs):
        """
        Enables turning off comments.

        Returns
        -------
        list[tuple[dict[str, str], dict[str, str]]]
            An extended list of tuples of a node and potentially a edge information, or list of nothing.
        """
        if self.show_comments:
            return self.visit_Node(o, **kwargs)
        return []

    visit_CommentBlock = visit_Comment

    def visit_Node(self, o, **kwargs):
        """
        Add a :any:`Node`, mark parent and visit all children.

        Returns
        -------
        list[tuple[dict[str, str], dict[str, str]]]
            An extended list of tuples of a node and potentially a edge information
        """
        node_edge_info = self.__add_node(o, **kwargs)
        kwargs["parent"] = o

        node_edge_info.extend(self.visit_all(o.children, **kwargs))
        return node_edge_info

    def visit_Expression(self, o, **kwargs):
        """
        Dispatch routine to add nodes utilizing expression tree stringifier,
        mark parent and stop.

        Returns
        -------
        list[tuple[dict[str, str], dict[str, str]]]
            An extended list of tuples of a node and potentially a edge information or list of nothing.
        """
        if self.show_expressions:
            content = self.symgen(o)
            parent = kwargs.get("parent")
            return self.__add_node(o, label=content, parent=parent, shape="box")
        return []

    def visit_tuple(self, o, **kwargs):
        """
        Recurse for each item in the tuple.
        """
        return self.visit_all(o, **kwargs)

    visit_list = visit_tuple

    def visit_Conditional(self, o, **kwargs):
        """
        Add a :any:`Conditional`, mark parent and visit first body then else body.

        Returns
        -------
        list[tuple[dict[str, str], dict[str, str]]]
            An extended list of tuples of a node and potentially a edge information
        """
        parent = kwargs.get("parent")
        label = self.symgen(o.condition)
        node_edge_info = self.__add_node(o, label=label, parent=parent, shape="diamond")
        kwargs["parent"] = o
        node_edge_info.extend(self.visit_all(o.body, **kwargs))

        if o.else_body:
            node_edge_info.extend(self.visit_all(o.else_body, **kwargs))
        return node_edge_info


def ir_graph(ir, show_comments=False, show_expressions=False, linewidth=40, symgen=str):
    """
    Build a graphviz visualisation of the given IR using :class:`GraphCollector`.

    Parameters
    ----------
    ir : :any:`Node`
        The IR node starting from which to produce the tree
    show_comments : bool, optional, default: False
        Whether to show comments in the output
    show_expressions : bool, optional, default: False
        Whether to further expand expressions in the output
    linewidth : int, optional, default: 40
        Line width after which node labels are wrapped
    symgen : callable, optional, default: str
        Formatter used to turn expressions into label strings

    Returns
    -------
    :class:`graphviz.Digraph`
        A left-to-right directed graph of the IR.

    Raises
    ------
    ImportError
        If the ``graphviz`` package is not available.
    """

    if not HAVE_IR_GRAPH:
        raise ImportError("ir_graph is not available.")

    timing_message = "[Loki::Graph Visualization] Created graph visualization in {:.2f}s"

    with Timer(text=timing_message):
        collector = GraphCollector(
            show_comments=show_comments,
            show_expressions=show_expressions,
            linewidth=linewidth,
            symgen=symgen,
        )
        entries = list(filter(lambda entry: entry is not None, collector.visit(ir)))

        graph = Digraph()
        graph.attr(rankdir="LR")
        for node_kwargs, edge_kwargs in entries:
            if node_kwargs:
                graph.node(**node_kwargs)
            if edge_kwargs:
                graph.edge(**edge_kwargs)
        return graph
loki-ecmwf-0.3.6/loki/ir/find.py0000664000175000017500000001710115167130205016603 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Visitor classes that allow searching the IR
"""
from itertools import groupby

from loki.ir.visitor import Visitor
from loki.tools import flatten

__all__ = [
    'FindNodes', 'SequenceFinder', 'PatternFinder', 'is_parent_of',
    'is_child_of', 'FindScopes'
]


class FindNodes(Visitor):
    """
    Collect all :any:`Node` instances of an IR tree that satisfy a criterion.

    Parameters
    ----------
    match :
        Node type(s) or node instance to look for.
    mode : optional
        Selects the matching rule. Accepted values are:

        * ``'type'`` (default) : Collect all instances of type :data:`match`.
        * ``'scope'`` : Return the :any:`InternalNode` in which the object
          :data:`match` appears.
    greedy : bool, optional
        Do not recurse for children of a matched node.

    Returns
    -------
    list
        All nodes in the traversed IR that match the criteria.
    """

    @classmethod
    def default_retval(cls):
        """
        Provide a fresh, empty list as the default result.

        Returns
        -------
        list
        """
        return []

    # Mapping of available :data:`mode` selectors to match rules
    rules = {
        'type': lambda match, node: isinstance(node, match),
        'scope': lambda match, node: match in flatten(node.children)
    }

    def __init__(self, match, mode='type', greedy=False):
        super().__init__()
        self.match = match
        self.rule = self.rules[mode]
        self.greedy = greedy

    def visit_object(self, o, **kwargs):
        return kwargs.get('ret') or self.default_retval()

    def visit_tuple(self, o, **kwargs):
        """
        Traverse every element of the sequence, threading the result list
        through each visit, and return the accumulated result.
        """
        found = kwargs.pop('ret', self.default_retval())
        for element in o:
            found = self.visit(element, ret=found, **kwargs)
        return found or self.default_retval()

    visit_list = visit_tuple

    def visit_Node(self, o, **kwargs):
        """
        Record :data:`o` if it satisfies the rule, then descend into its
        children unless the search is greedy and :data:`o` matched.
        """
        found = kwargs.pop('ret', self.default_retval())
        is_match = self.rule(self.match, o)
        if is_match:
            found.append(o)
        if not (is_match and self.greedy):
            for child in o.children:
                found = self.visit(child, ret=found, **kwargs)
        return found or self.default_retval()

    def visit_TypeDef(self, o, **kwargs):
        """
        Handler for :any:`TypeDef` nodes that checks the node itself but
        never descends into its body: finding e.g. declarations from inside
        a type definition would be surprising when searching a containing
        :any:`Subroutine` or :any:`Module`.
        """
        found = kwargs.pop('ret', self.default_retval())
        if self.rule(self.match, o):
            found.append(o)
        # The TypeDef body (its children) is intentionally not traversed
        return found or self.default_retval()


def is_child_of(node, other):
    """
    Check whether :data:`node` appears somewhere in the IR below :data:`other`.

    Note that this traverses the subtree of :data:`other` and can be
    expensive for large subtrees.

    Returns
    -------
    bool
        `True` if :data:`node` is contained in the IR below :data:`other`,
        otherwise `False`.
    """
    enclosing_scopes = FindNodes(node, mode='scope', greedy=True).visit(other)
    return bool(enclosing_scopes)


def is_parent_of(node, other):
    """
    Check whether :data:`other` appears somewhere in the IR below :data:`node`.

    Note that this traverses the subtree of :data:`node` and can be
    expensive for large subtrees.

    Returns
    -------
    bool
        `True` if :data:`other` is contained in the IR below :data:`node`,
        otherwise `False`.
    """
    enclosing_scopes = FindNodes(other, mode='scope', greedy=True).visit(node)
    return bool(enclosing_scopes)


class FindScopes(FindNodes):
    """
    Find all parent nodes for node :data:`match`.

    The result is a list with one entry per occurrence of :data:`match`
    (matched by identity). Each entry is the list of nodes on the path from
    the outermost visited node down to, and including, :data:`match`.

    Parameters
    ----------
    match : :any:`Node`
        The node for which the parent nodes are to be found.
    greedy : bool, optional
        Do not descend into the children of :data:`match` once it has
        been found (default: `True`).
    """
    def __init__(self, match, greedy=True):
        super().__init__(match=match, greedy=greedy)
        self.rule = lambda match, o: match is o

    def visit_Node(self, o, **kwargs):
        """
        Add the node to the list of ancestors that is passed down to the
        children and, if :data:`o` is :data:`match`, return the list of
        ancestors.
        """
        ret = kwargs.pop('ret', self.default_retval())
        ancestors = kwargs.pop('ancestors', []) + [o]

        if self.rule(self.match, o):
            ret.append(ancestors)
            if self.greedy:
                return ret

        for i in o.children:
            ret = self.visit(i, ret=ret, ancestors=ancestors, **kwargs)
        return ret or self.default_retval()

    def visit_TypeDef(self, o, **kwargs):
        """
        Handler for :any:`TypeDef` nodes. Like the :any:`FindNodes` handler,
        it does not traverse the type definition's body. Unlike that handler,
        a match records the ancestor path rather than the bare node, so the
        result format is the same as for :meth:`visit_Node`.
        """
        ret = kwargs.pop('ret', self.default_retval())
        ancestors = kwargs.pop('ancestors', []) + [o]

        if self.rule(self.match, o):
            ret.append(ancestors)

        # Do not traverse children (i.e., TypeDef's body)
        return ret or self.default_retval()


class SequenceFinder(Visitor):
    """
    Locate runs of consecutive nodes of one type inside the lists/tuples
    of a given tree.

    Parameters
    ----------
    node_type :
        The node type to look for.
    """

    def __init__(self, node_type):
        super().__init__()
        self.node_type = node_type

    @classmethod
    def default_retval(cls):
        """
        Provide a fresh, empty list as the default result.

        Returns
        -------
        list
        """
        return []

    def visit_tuple(self, o, **kwargs):
        """
        Collect runs found in nested children first, then the runs of more
        than one consecutive :data:`node_type` instance at this level.
        """
        found = []
        for child in o:
            nested = self.visit(child)
            if nested:
                found.extend(nested)

        # groupby yields each maximal run of elements sharing the same type
        runs = (tuple(members) for kind, members in groupby(o, type) if kind is self.node_type)
        found.extend(run for run in runs if len(run) > 1)
        return found

    visit_list = visit_tuple


class PatternFinder(Visitor):
    """
    Find a pattern of nodes given as tuple/list of types within a given tree.

    Parameters
    ----------
    pattern : sequence of types
        The type pattern to look for. An empty pattern matches nothing.
    """

    def __init__(self, pattern):
        super().__init__()
        self.pattern = pattern

    @classmethod
    def default_retval(cls):
        """
        Default return value is an empty list.

        Returns
        -------
        list
        """
        return []

    @staticmethod
    def match_indices(pattern, sequence):
        """
        Return the start indices of all (possibly overlapping) occurrences
        of :data:`pattern` as a contiguous run in :data:`sequence`.

        An empty :data:`pattern` yields no matches (instead of raising
        an :class:`IndexError`).
        """
        target = tuple(pattern)
        width = len(target)
        if width == 0:
            return []
        return [
            i for i, elem in enumerate(sequence)
            # Cheap first-element check before slicing
            if elem == target[0] and tuple(sequence[i:i+width]) == target
        ]

    def visit_tuple(self, o, **kwargs):
        """
        Visit all children and look for sequences of nodes with types matching
        the pattern.
        """
        matches = []
        for c in o:
            # First recurse...
            submatches = self.visit(c)
            if submatches is not None and len(submatches) > 0:
                matches += submatches
        # ... then match the pattern against the types at this level
        types = list(map(type, o))
        width = len(self.pattern)
        for i in self.match_indices(self.pattern, types):
            matches.append(o[i:i+width])
        return matches

    visit_list = visit_tuple
loki-ecmwf-0.3.6/loki/ir/expr_visitors.py0000664000175000017500000004510415167130205020607 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Visitor classes for traversing and transforming all expression trees in
:doc:`internal_representation`.
"""
from pymbolic.primitives import Expression

from loki.ir.nodes import Node
from loki.ir.visitor import Visitor
from loki.ir.transformer import Transformer
from loki.tools import flatten, as_tuple, dict_override, OrderedSet
from loki.expression.mappers import (
    SubstituteExpressionsMapper, ExpressionRetriever,
    AttachScopesMapper, LokiIdentityMapper
)
from loki.expression.symbols import (
    Array, Scalar, InlineCall, TypedSymbol, FloatLiteral, IntLiteral,
    LogicLiteral, StringLiteral, IntrinsicLiteral, DeferredTypeSymbol
)

__all__ = [
    'ExpressionFinder', 'FindExpressions', 'FindVariables',
    'FindTypedSymbols', 'FindInlineCalls', 'FindLiterals',
    'FindRealLiterals', 'ExpressionTransformer',
    'SubstituteExpressions', 'SubstituteStringExpressions',
    'AttachScopes'
]


class ExpressionFinder(Visitor):
    """
    Base class visitor to collect specific sub-expressions,
    eg. functions or symbols, from all nodes in an IR tree.

    Specialized ``FindXXX`` subclasses (e.g. :any:`FindVariables`,
    :any:`FindInlineCalls`) are provided alongside this class. They
    find some of the most common sub-expression types, eg. symbols,
    functions and variables, by overriding :attr:`retriever`.

    Attributes
    ----------
    retriever : :class:`pymbolic.mapper.Mapper`
        An implementation of an expression mapper, e.g., :any:`ExpressionRetriever`,
        that is used to search expression trees. Note that it needs to provide a
        ``retrieve`` method to initiate the traversal and retrieve the list of expressions.
        The base class retriever matches nothing.

    Parameters
    ----------
    unique : bool, optional, default: True
        If `True` the visitor will return a `OrderedSet` of unique sub-expression
        instead of a tuple of possibly repeated instances.
    with_ir_node : bool, optional, default: False
        If `True` the visitor will return a tuple of pairs of the form
        ``(node, expressions)``. Each pair holds an IR node and the
        sub-expressions found directly in that node.
    """
    # pylint: disable=unused-argument

    # Base retriever matches nothing; subclasses override this attribute
    retriever = ExpressionRetriever(lambda _: False)

    def __init__(self, unique=True, with_ir_node=False):
        super().__init__()
        self.unique = unique
        self.with_ir_node = with_ir_node

    def find_uniques(self, variables):
        """
        Reduces the number of matched sub-expressions to a set of unique sub-expressions,
        if self.unique is `True`.

        Currently, two sub-expressions are considered NOT to be unique if they have the same
        - :attr:`name`
        - :attr:`parent.name` (or `None`)
        - :attr:`dimensions` (for :any:`Array`)

        All other expressions are compared by their string representation.
        When several expressions share a key, the last one encountered is
        kept, at the position where that key first occurred.
        If self.unique is `False`, :data:`variables` is returned unchanged.
        """
        def dict_key(var):
            assert isinstance(var, Expression)
            if isinstance(var, (Scalar, Array)):
                return (var.name,
                        var.parent.name if hasattr(var, 'parent') and var.parent else None,
                        var.dimensions if isinstance(var, Array) else None)
            return str(var)

        if self.unique:
            # Later duplicates overwrite earlier values, but dict insertion
            # order (first occurrence of each key) is preserved
            var_dict = {dict_key(var): var for var in variables}
            return OrderedSet(var_dict.values())
        return variables

    @classmethod
    def retrieve(cls, expr):
        """
        Internal retrieval function used on expressions.
        """
        return cls.retriever.retrieve(expr)

    def _return(self, node, expressions):
        """
        Create the return value from the found expressions.

        Parameters
        ----------
        node :
            The IR node (or, when called from :meth:`visit_tuple`, the tuple)
            whose results are being combined.
        expressions :
            Possibly nested results of visiting the children of :data:`node`.

        Returns
        -------
        An empty tuple if nothing was found. If :attr:`with_ir_node` is set,
        a tuple of ``(node, expressions)`` pairs. Otherwise the flattened
        expressions, reduced by :meth:`find_uniques`.
        """
        if not expressions:
            return ()
        if self.with_ir_node:
            # A direct call to flatten() would destroy our tuples, thus we need to
            # sort through the list and single out existing tuple-value pairs and
            # plain expressions before finding uniques
            # NOTE(review): pairs whose first element is a tuple (produced by
            # visit_tuple) are not leaves and are presumably flattened into plain
            # expressions here; confirm against loki.tools.flatten
            def is_leaf(el):
                return isinstance(el, tuple) and len(el) == 2 and isinstance(el[0], Node)
            newlist = flatten(expressions, is_leaf=is_leaf)
            tuple_list = [el for el in newlist if is_leaf(el)]
            exprs = [el for el in newlist if not is_leaf(el)]
            if exprs:
                # Plain expressions found directly in ``node`` are grouped under it
                tuple_list += [(node, self.find_uniques(exprs))]
            return as_tuple(tuple_list)

        # Flatten the (possibly nested) list
        return self.find_uniques(as_tuple(flatten(expressions)))

    # Class attribute set to the builtin, so ``default_retval()`` returns ``()``
    default_retval = tuple

    def visit_tuple(self, o, **kwargs):
        expressions = [self.visit(c, **kwargs) for c in o]
        return self._return(o, as_tuple(expressions))

    visit_list = visit_tuple

    def visit_Expression(self, o, **kwargs):
        # Leaf of the IR traversal: search the expression tree with the retriever
        return as_tuple(self.retrieve(o))

    def visit_Node(self, o, **kwargs):
        expressions = [self.visit(c, **kwargs) for c in flatten(o.children)]
        return self._return(o, as_tuple(expressions))

    def visit_TypeDef(self, o, **kwargs):
        """
        Custom handler for :any:`TypeDef` nodes that does not traverse the
        body (reason being that discovering variables used or declared
        inside the type definition would be unexpected if called on a
        containing :any:`Subroutine` or :any:`Module`)
        """
        return self._return(o, ())

    def visit_VariableDeclaration(self, o, **kwargs):
        # In addition to the children, search the initial values stored on
        # each declared symbol's type (``v.type.initial``)
        expressions = as_tuple(super().visit(o.children, **kwargs))
        for v in o.symbols:
            if v.type.initial is not None:
                expressions += as_tuple(self.retrieve(v.type.initial))
        return self._return(o, expressions)


class FindExpressions(ExpressionFinder):
    """
    Visitor that gathers every expression tree node, i.e. every
    :class:`pymbolic.primitives.Expression`, appearing in an IR tree.

    See :any:`ExpressionFinder`
    """
    retriever = ExpressionRetriever(lambda expr: isinstance(expr, Expression))


class FindTypedSymbols(ExpressionFinder):
    """
    Visitor that gathers every :any:`TypedSymbol` appearing in an IR tree.

    See :any:`ExpressionFinder`
    """
    retriever = ExpressionRetriever(lambda expr: isinstance(expr, TypedSymbol))


class FindVariables(ExpressionFinder):
    """
    Visitor that gathers every variable appearing in an IR tree.

    Variables here are the expression nodes :any:`Scalar`, :any:`Array`
    and :any:`DeferredTypeSymbol`.

    See :class:`ExpressionFinder` for further details
    """
    retriever = ExpressionRetriever(
        lambda expr: isinstance(expr, (Scalar, Array, DeferredTypeSymbol))
    )


class FindInlineCalls(ExpressionFinder):
    """
    Visitor that gathers every :any:`InlineCall` appearing in an IR tree.

    See :class:`ExpressionFinder`
    """
    retriever = ExpressionRetriever(lambda expr: isinstance(expr, InlineCall))


class FindLiterals(ExpressionFinder):
    """
    Visitor that gathers every literal appearing in an IR tree, namely
    :any:`FloatLiteral`, :any:`IntLiteral`, :any:`LogicLiteral`,
    :any:`StringLiteral` and :any:`IntrinsicLiteral` instances.

    See :class:`ExpressionFinder`
    """
    retriever = ExpressionRetriever(
        lambda expr: isinstance(
            expr, (FloatLiteral, IntLiteral, LogicLiteral, StringLiteral, IntrinsicLiteral)
        )
    )


class FindRealLiterals(ExpressionFinder):
    """
    Visitor that gathers every real (floating-point) literal, i.e. every
    :any:`FloatLiteral`, appearing in an IR tree.

    See :class:`ExpressionFinder`
    """
    retriever = ExpressionRetriever(lambda expr: isinstance(expr, FloatLiteral))


class ExpressionTransformer(Transformer):
    """
    The :any:`Transformer` base class for manipulating expressions.

    This transformer uses the class attribute :data:`expr_mapper` to
    map an existing expression sub-tree to a new one. By default, it
    uses the :any:`LokiIdentityMapper` to replicate the existing tree.

    While traversing, the ``source`` property of the closest enclosing IR
    node is passed down. Whenever the mapper returns an expression that
    compares unequal to the original one, that source object is
    invalidated by calling its ``invalidate()`` method.

    Attributes
    ----------
    expr_mapper : :class:`pymbolic.mapper.Mapper`
        An implementation of an expression mapper, e.g.,
        :any:`SubstituteExpressionsMapper`, that is used to map an
        expression tree to a new one.

    Parameters
    ----------
    inplace : bool, optional
        If set to `True`, all updates are performed on the existing :any:`Node`
        objects instead of rebuilding them. Note that this means the original
        tree is modified.
    """
    expr_mapper = LokiIdentityMapper()

    def visit(self, o, *args, **kwargs):
        """
        Dispatch to the node-specific handler, passing down the ``source``
        of ``o`` (or, if ``o`` has no ``source`` attribute, the one of the
        enclosing node) so that it may be invalidated on change.
        """
        source = o.source if hasattr(o, 'source') else kwargs.get('source')
        with dict_override(kwargs, {'source': source}):
            obj = super().visit(o, *args, **kwargs)
        return obj

    def visit_Expression(self, o, **kwargs):
        """
        Call the associated mapper for the given expression node

        Declaration attributes are only recursed into when the caller
        requested it via ``recurse_to_declaration_attributes``.
        """
        if kwargs.get('recurse_to_declaration_attributes'):
            new = self.expr_mapper(o, recurse_to_declaration_attributes=True)
        else:
            new = self.expr_mapper(o)

        # Invalidate the enclosing `Source` object if we've changed the expression
        source = kwargs.get('source')
        if source and o != new:
            source.invalidate()
        return new


class SubstituteExpressions(ExpressionTransformer):
    """
    A dedicated visitor to perform expression substitution in all IR nodes

    It applies :any:`SubstituteExpressionsMapper` with the provided :data:`expr_map`
    to every expression in the traversed IR tree.

    .. note::
       No recursion is performed on substituted expression nodes, they are taken
       as-is from the map. Otherwise substitutions that involve the original node
       would result in infinite recursion - for example a replacement that wraps
       a variable in an inline call:  ``my_var -> wrapped_in_call(my_var)``.

       When there is a need to recursively apply the mapping, the mapping needs to
       be applied to itself first. A potential use-case is renaming of variables,
       which may appear as the name of an array subscript as well as in the ``dimensions``
       attribute of the same expression: ``SOME_ARR(SOME_ARR > SOME_VAL)``.
       The mapping can be applied to itself using the utility function
       :any:`recursive_expression_map_update`.

    Parameters
    ----------
    expr_map : dict
        Expression mapping to apply to the expression tree.
    invalidate_source : bool, optional
        By default the :attr:`source` property of nodes is discarded
        when rebuilding the node, setting this to `False` allows to
        retain that information
    **kwargs : optional
        Further keyword arguments passed on to :any:`ExpressionTransformer`.
    """
    # pylint: disable=unused-argument

    def __init__(self, expr_map, invalidate_source=True, **kwargs):
        super().__init__(invalidate_source=invalidate_source, **kwargs)

        # Override the static default with a substitution mapper from ``expr_map``
        self.expr_mapper = SubstituteExpressionsMapper(expr_map)

    def visit_Import(self, o, **kwargs):
        """
        For :any:`Import` we set ``recurse_to_declaration_attributes=True``
        to make sure properties in the symbol table are updated during
        dispatch to the expression mapper.
        """
        kwargs['recurse_to_declaration_attributes'] = True
        return super().visit_Node(o, **kwargs)

    def visit_VariableDeclaration(self, o, **kwargs):
        """
        For :any:`VariableDeclaration` or :any:`ProcedureDeclaration`
        we set ``recurse_to_declaration_attributes=True`` to make sure
        properties in the symbol table are updated during dispatch to
        the expression mapper.

        If source invalidation is being requested, we also check the
        associated type (on first symbol) to track changes there.
        Declarations without symbols, and rebuilt nodes without a
        ``source``, are skipped for this type check.
        """
        kwargs['recurse_to_declaration_attributes'] = True

        # Type tracking requires invalidation to be requested and at least
        # one declared symbol whose type can be inspected
        track_type = bool(self.invalidate_source and o.symbols)

        # Store a copy of the old type, as it will be in-place updated
        old_type = o.symbols[0].type.clone() if track_type else None
        new = super().visit_Node(o, **kwargs)

        # Only invalidate if the (possibly rebuilt) node still carries a source
        new_source = getattr(new, 'source', None)
        if track_type and o.source and new_source is not None:
            if o.symbols and old_type != o.symbols[0].type:
                new_source.invalidate()

        return new

    visit_ProcedureDeclaration = visit_VariableDeclaration


class SubstituteStringExpressions(SubstituteExpressions):
    """
    Extension to :any:`SubstituteExpressions` that allows symbol
    substitution of pure string mappings via :any:`parse_expr`.

    In addition to the input string mapping this requires a :any:`Scope`
    (eg. :any:`Subroutine` or :any:`Module`) to parse the respective strings.

    Parameters
    ----------
    str_map : dict
        String-to-string mapping of expressions to apply to the expression tree.
        Both keys and values are parsed with :any:`parse_expr` in ``scope``.
    scope : :any:`Scope`
        The scope to which symbol names inside the expression belong
    invalidate_source : bool, optional
        By default the :attr:`source` property of nodes is discarded
        when rebuilding the node, setting this to `False` allows to
        retain that information
    **kwargs : optional
        Further keyword arguments (e.g. ``inplace``) that are passed on
        to :any:`SubstituteExpressions`.
    """
    def __init__(self, str_map, scope, invalidate_source=True, **kwargs):
        from loki.expression.parser import parse_expr  # pylint: disable=import-outside-toplevel,cyclic-import
        expr_map = {
            parse_expr(k, scope=scope): parse_expr(v, scope=scope)
            for k, v in str_map.items()
        }
        super().__init__(expr_map=expr_map, invalidate_source=invalidate_source, **kwargs)


class AttachScopes(Visitor):
    """
    Scoping visitor that traverses the control flow tree and uses
    :any:`AttachScopesMapper` to update all :any:`TypedSymbol` expression
    tree nodes with a pointer to their corresponding scope.

    Parameters
    ----------
    fail : bool, optional
        If set to True, this lets the visitor fail if it encounters a node
        without a declaration or an entry in any of the symbol tables
        (default: False).
    """

    def __init__(self, fail=False):
        super().__init__()
        self.fail = fail
        self.expr_mapper = AttachScopesMapper(fail=fail)

    @staticmethod
    def _update(o, children, **args):
        """
        Utility routine to update the IR node in-place with the given
        children and (optionally) overridden non-traversable arguments
        """
        args_frozen = o.args_frozen
        args_frozen.update(args)
        o._update(*children, **args_frozen)
        return o

    def visit_object(self, o, **kwargs):
        """Return any foreign object unchanged."""
        return o

    def visit(self, o, *args, **kwargs):
        """
        Default visitor method that dispatches the node-specific handler

        The ``scope`` keyword argument defaults to `None`, so that all
        handlers can rely on its presence.
        """
        kwargs.setdefault('scope', None)
        return super().visit(o, *args, **kwargs)

    def visit_Expression(self, o, **kwargs):
        """
        Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes
        """
        if kwargs.get('recurse_to_declaration_attributes'):
            return self.expr_mapper(o, scope=kwargs['scope'], recurse_to_declaration_attributes=True)
        return self.expr_mapper(o, scope=kwargs['scope'])

    def visit_list(self, o, **kwargs):
        """
        Visit each entry in a list and return as a tuple
        """
        return tuple(self.visit(c, **kwargs) for c in o)

    visit_tuple = visit_list

    def visit_Node(self, o, **kwargs):
        """
        Generic handler for IR :any:`Node` objects

        Recurses to children and updates the node
        """
        children = tuple(self.visit(i, **kwargs) for i in o.children)
        return self._update(o, children)

    def visit_Import(self, o, **kwargs):
        """
        For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`)
        we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol
        table are updated during dispatch to the expression mapper
        """
        kwargs['recurse_to_declaration_attributes'] = True
        return self.visit_Node(o, **kwargs)

    visit_VariableDeclaration = visit_Import
    visit_ProcedureDeclaration = visit_Import

    def visit_Scope(self, o, **kwargs):
        """
        Generic handler for :any:`Scope` objects

        Makes sure that declared variables and imported symbols have an
        entry in that node's symbol table before recursing to children with
        this node as new scope. If the traversal provides an enclosing scope
        that differs from the node's current parent, the node is re-attached
        to it. Without an enclosing scope, the current parent is kept.
        """
        # First, make sure declared variables and imported symbols have an
        # entry in the scope's table
        self._update_symbol_table_with_decls_and_imports(o)

        # Attach parent scope if it is new before passing self down to children.
        # :meth:`visit` always defines ``kwargs['scope']`` (as `None` at the root
        # of a traversal), so a missing enclosing scope must be detected
        # explicitly to fall back to the node's current parent instead of
        # detaching it.
        parent_scope = kwargs.get('scope')
        if parent_scope is None:
            parent_scope = o.parent
        if o.parent is not parent_scope and o is not parent_scope:
            o._reset_parent(parent=parent_scope)

        # Then recurse to all children
        kwargs['scope'] = o
        children = tuple(self.visit(i, **kwargs) for i in o.children)
        return self._update(o, children, symbol_attrs=o.symbol_attrs, rescope_symbols=False)

    @staticmethod
    def _update_symbol_table_with_decls_and_imports(o):
        """
        Utility function to insert default entries for symbols declared or
        imported in a node (existing entries are left untouched)
        """
        for v in getattr(o, 'variables', ()):
            o.symbol_attrs.setdefault(v.name, v.type)
        for s in getattr(o, 'imported_symbols', ()):
            o.symbol_attrs.setdefault(s.name, s.type)

    def visit_Subroutine(self, o, **kwargs):
        """
        Handler for :any:`Subroutine` nodes

        Makes sure that declared variables and imported symbols have an
        entry in the routine's symbol table before recursing to spec, body,
        and member routines with this routine as new scope.
        """
        # First, make sure declared variables and imported symbols have an
        # entry in the scope's table
        self._update_symbol_table_with_decls_and_imports(o)

        # Then recurse to all children
        kwargs['scope'] = o
        o.spec = self.visit(o.spec, **kwargs)
        o.body = self.visit(o.body, **kwargs)
        o._members = self.visit(o.members, **kwargs)
        return o

    def visit_Module(self, o, **kwargs):
        """
        Handler for :any:`Module` nodes

        Makes sure that declared variables and imported symbols have an
        entry in the module's symbol table before recursing to spec,
        and member routines with this module as new scope.
        """
        # First, make sure declared variables and imported symbols have an
        # entry in the scope's table
        self._update_symbol_table_with_decls_and_imports(o)

        # Then recurse to all children
        kwargs['scope'] = o
        o.spec = self.visit(o.spec, **kwargs)
        o.contains = self.visit(o.contains, **kwargs)
        return o
loki-ecmwf-0.3.6/loki/ir/nodes.py0000664000175000017500000020735015167130205017002 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# pylint: disable=too-many-lines
"""
Control flow node classes for
:ref:`internal_representation:Control flow tree`
"""

from abc import abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from itertools import chain
from typing import Any, Tuple, Union, Optional

from pymbolic.primitives import Expression

from pydantic.dataclasses import dataclass as dataclass_validated
from pydantic import field_validator

from loki.expression import (
    symbols as sym, Variable, parse_expr, AttachScopesMapper,
    ExpressionDimensionsMapper
)
from loki.frontend.source import Source
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
from loki.types import DataType, BasicType, DerivedType, SymbolAttributes, Scope

__all__ = [
    # --- Abstract base classes ---
    'Node',
    'InternalNode',
    'LeafNode',
    'ScopedNode',
    # --- Internal node classes (nodes with a body) ---
    'Section',
    'Associate',
    'Loop',
    'WhileLoop',
    'Conditional',
    'PragmaRegion',
    'Interface',
    # --- Leaf node classes ---
    'Assignment',
    'ConditionalAssignment',
    'CallStatement',
    'Allocation',
    'Deallocation',
    'Nullify',
    'Comment',
    'CommentBlock',
    'Pragma',
    'PreprocessorDirective',
    'Import',
    'VariableDeclaration',
    'ProcedureDeclaration',
    'DataDeclaration',
    'StatementFunction',
    'TypeDef',
    'MultiConditional',
    'TypeConditional',
    'Forall',
    'MaskedStatement',
    'Intrinsic',
    'Enumeration',
    'RawSource',
]

# Pydantic configuration used for all validated node dataclasses. Node fields
# hold non-pydantic types (e.g. expressions, scopes, source objects), so
# arbitrary types have to be permitted explicitly.
dataclass_validation_config = {'arbitrary_types_allowed': True}

# Decorator producing a pydantic-validated dataclass with the configuration
# above, so that node fields are type-checked on construction
dataclass_strict = partial(dataclass_validated, config=dataclass_validation_config)


def _sanitize_tuple(t):
    """
    Normalise ``t`` into a flat, single-level tuple without ``None`` entries.

    Used, e.g., by :any:`InternalNode` to validate its ``body``.
    """
    flat_items = flatten(as_tuple(t))
    return tuple(item for item in flat_items if item is not None)


# Abstract base classes

@dataclass_strict(frozen=True)
class Node:
    """
    Base class for all node types in Loki's internal representation.

    Provides the common functionality shared by all node types; specifically,
    this comprises functionality to update or rebuild a node, and source
    metadata.

    Attributes
    ----------
    _traversable : list of str
        The traversable fields of the Node; that is, fields walked over by
        a :any:`Visitor`. All arguments in :py:meth:`__init__` whose
        names appear in this list are treated as traversable fields.

    Parameters
    ----------
    source : :any:`Source` or str, optional
        the information about the original source for the Node.
    label : str, optional
        the label assigned to the statement in the original source
        corresponding to the Node.

    """

    source: Optional[Union[Source, str]] = None
    label: Optional[str] = None

    _traversable = []

    def __post_init__(self):
        # Create private placeholders for dataflow analysis fields that
        # do not show up in the dataclass field definitions, as these
        # are entirely transient.
        self._update(_live_symbols=None, _defines_symbols=None, _uses_symbols=None)

    @property
    def children(self):
        """
        The traversable children of the node.
        """
        return tuple(getattr(self, i) for i in self._traversable)

    def _rebuild(self, *args, **kwargs):
        """
        Rebuild the node.

        Constructs an identical copy of the node from when it was first
        created. Optionally, some or all of the arguments for it can
        be overwritten. The copy is created via the class constructor,
        i.e., it is validated like a newly created node.

        Parameters
        ----------
        *args : optional
            The traversable arguments used to create the node. By default,
            ``args`` are used.
        **kwargs : optional
            The non-traversable arguments used to create the node, By
            default, ``args_frozen`` are used.
        """
        handle = self.args
        # Positional arguments fill traversable fields not given as keywords
        argnames = [i for i in self._traversable if i not in kwargs]
        handle.update(OrderedDict(zip(argnames, args)))
        handle.update(kwargs)
        return type(self)(**handle)

    clone = _rebuild

    def _update(self, *args, **kwargs):
        """
        In-place update that modifies (re-initializes) the node
        without rebuilding it. Use with care!

        This writes directly to the instance ``__dict__``, thereby
        bypassing the frozen dataclass protection and field validation.

        Parameters
        ----------
        *args : optional
            The traversable arguments used to create the node. By default,
            ``args`` are used.
        **kwargs : optional
            The non-traversable arguments used to create the node, By
            default, ``args_frozen`` are used.

        """
        argnames = [i for i in self._traversable if i not in kwargs]
        kwargs.update(zip(argnames, args))
        self.__dict__.update(kwargs)

    @property
    def args(self):
        """
        Arguments used to construct the Node.
        """
        return {k: v for k, v in self.__dict__.items() if k in self.__dataclass_fields__}  # pylint: disable=no-member

    @property
    def args_frozen(self):
        """
        Arguments used to construct the Node that cannot be traversed.
        """
        return {k: v for k, v in self.args.items() if k not in self._traversable}

    def __repr__(self):
        raise NotImplementedError

    def view(self):
        """
        Pretty-print the node hierarchy under this node.
        """
        # pylint: disable=import-outside-toplevel,cyclic-import
        from loki.backend.pprint import pprint
        pprint(self)

    def ir_graph(self, show_comments=False, show_expressions=False, linewidth=40, symgen=str):
        """
        Get the IR graph to visualize the node hierarchy under this node.
        """
        # pylint: disable=import-outside-toplevel,cyclic-import
        from loki.ir.ir_graph import ir_graph

        return ir_graph(self, show_comments, show_expressions, linewidth, symgen)

    @property
    def live_symbols(self):
        """
        The live symbols at this node, i.e., variables that have been
        defined (potentially) prior to this point in the control flow
        graph.

        This property is attached to the Node by
        :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or
        when using the
        :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached`
        context manager.

        Raises
        ------
        RuntimeError
            If no dataflow analysis information is attached to the node.
        """
        # Use ``.get``: the placeholder is missing if the node was restored
        # without running ``__post_init__`` (e.g. via ``__setstate__``)
        live_symbols = self.__dict__.get('_live_symbols')
        if live_symbols is None:
            raise RuntimeError('Need to run dataflow analysis on the IR first.')
        return live_symbols

    @property
    def defines_symbols(self):
        """
        The symbols (potentially) defined by this node.

        This property is attached to the Node by
        :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or
        when using the
        :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached`
        context manager.

        Raises
        ------
        RuntimeError
            If no dataflow analysis information is attached to the node.
        """
        defines_symbols = self.__dict__.get('_defines_symbols')
        if defines_symbols is None:
            raise RuntimeError('Need to run dataflow analysis on the IR first.')
        return defines_symbols

    @property
    def uses_symbols(self):
        """
        The symbols used by this node before defining them.

        This property is attached to the Node by
        :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or
        when using the
        :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached`
        context manager.

        Raises
        ------
        RuntimeError
            If no dataflow analysis information is attached to the node.
        """
        uses_symbols = self.__dict__.get('_uses_symbols')
        if uses_symbols is None:
            raise RuntimeError('Need to run dataflow analysis on the IR first.')
        return uses_symbols


@dataclass_strict(frozen=True)
class _InternalNode:
    """ Field type definitions for the :any:`InternalNode` node type. """

    # Child nodes (or nested scopes) making up the node's body
    body: Tuple[Union[Node, Scope], ...] = ()


@dataclass_strict(frozen=True)
class InternalNode(Node, _InternalNode):
    """
    Base class for control flow nodes that own a traversable ``body``.

    Parameters
    ----------
    body : tuple
        The nodes that make up the body. On construction, nested sequences
        are flattened and ``None`` entries are dropped.
    """

    _traversable = ['body']

    @field_validator('body', mode='before')
    @classmethod
    def ensure_tuple(cls, value):
        """ Normalise ``body`` into a flat tuple without ``None`` entries. """
        return _sanitize_tuple(value)

    def __repr__(self):
        raise NotImplementedError


@dataclass_strict(frozen=True)
class LeafNode(Node):
    """
    Base class for control flow nodes that have no ``body``.
    """

    def __repr__(self):
        # Concrete leaf node classes provide their own representation
        raise NotImplementedError


# Mix-ins

class ScopedNode(Scope):
    """
    Mix-in to attach a scope to an IR :any:`Node`

    Additionally, this specializes the node's :meth:`_update` and
    :meth:`_rebuild` methods to make sure that an existing symbol table
    is carried over correctly.
    """

    @property
    def args(self):
        """
        Arguments used to construct the :any:`ScopedNode`, excluding
        the symbol table.
        """
        keys = tuple(k for k in self.__dataclass_fields__.keys() if k not in ('symbol_attrs', ))  # pylint: disable=no-member
        return {k: v for k, v in self.__dict__.items() if k in keys}

    def _update(self, *args, **kwargs):
        if 'symbol_attrs' not in kwargs:
            # Retain the symbol table (unless given explicitly)
            kwargs['symbol_attrs'] = self.symbol_attrs
        super()._update(*args, **kwargs)  # pylint: disable=no-member

    def _rebuild(self, *args, **kwargs):
        # Retain the symbol table (unless given explicitly)
        symbol_attrs = kwargs.pop('symbol_attrs', self.symbol_attrs)
        rescope_symbols = kwargs.pop('rescope_symbols', False)

        # Ensure 'parent' is always explicitly set
        kwargs.setdefault('parent', None)

        new_obj = super()._rebuild(*args, **kwargs)  # pylint: disable=no-member
        new_obj.symbol_attrs.update(symbol_attrs)

        if rescope_symbols:
            new_obj.rescope_symbols()
        return new_obj

    def __getstate__(self):
        # Constructor arguments plus the symbol table, which ``args`` excludes
        s = self.args
        s['symbol_attrs'] = self.symbol_attrs
        return s

    def __setstate__(self, s):
        symbol_attrs = s.pop('symbol_attrs', None)
        self._update(**s, symbol_attrs=symbol_attrs, rescope_symbols=True)

    @property
    @abstractmethod
    def variables(self):
        """
        Return the variables defined in this :any:`ScopedNode`.
        """

    @property
    def variable_map(self):
        """
        Map of variable names to :any:`Variable` objects
        """
        return CaseInsensitiveDict((v.name, v) for v in self.variables)

    def get_symbol(self, name):
        """
        Returns the symbol for a given name as defined in its declaration.

        The returned symbol might include dimension symbols if it was
        declared as an array.

        Parameters
        ----------
        name : str
            Base name of the symbol to be retrieved

        Returns
        -------
        The declared symbol, or `None` if no scope declaring ``name``
        is found.
        """
        scope = self.get_symbol_scope(name)
        if scope is None:
            return None
        return scope.variable_map.get(name)

    def Variable(self, **kwargs):
        """
        Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes.

        This invokes the :any:`Variable` with this node as the scope.

        Parameters
        ----------
        name : str
            The name of the variable.
        type : optional
            The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
        parent : :any:`Scalar` or :any:`Array`, optional
            The derived type variable this variable belongs to.
        dimensions : :any:`ArraySubscript`, optional
            The array subscript expression.
        """
        kwargs['scope'] = self
        return Variable(**kwargs)

    def parse_expr(self, expr_str, strict=False, evaluate=False, context=None):
        """
        Uses :meth:`parse_expr` to convert expression(s) represented
        in a string to Loki expression(s)/IR.

        Parameters
        ----------
        expr_str : str
            The expression as a string
        strict : bool, optional
            Whether to raise exception for unknown variables/symbols when
            evaluating an expression (default: `False`)
        evaluate : bool, optional
            Whether to evaluate the expression or not (default: `False`)
        context : dict, optional
            Symbol context, defining variables/symbols/procedures to help/support
            evaluating an expression

        Returns
        -------
        :any:`Expression`
            The expression tree corresponding to the expression
        """
        return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context)


# Intermediate node types


@dataclass_strict(frozen=True)
class _SectionBase:
    """ Field type definitions for the :any:`Section` node type (none beyond the base). """


@dataclass_strict(frozen=True)
class Section(InternalNode, _SectionBase):
    """
    Internal representation of a single code region.

    The body can be modified in-place via :meth:`append`, :meth:`insert`
    and :meth:`prepend`. These methods normalise the given node(s) in the
    same way as the ``body`` validator does on construction, i.e. nested
    sequences are flattened and ``None`` entries are dropped.
    """

    def append(self, node):
        """
        Append the given node(s) to the end of the section's body (in-place).

        Parameters
        ----------
        node : :any:`Node` or tuple of :any:`Node`
            The node(s) to append to the section.
        """
        # ``_update`` bypasses validation, so sanitize explicitly
        self._update(body=self.body + _sanitize_tuple(node))

    def insert(self, pos, node):
        """
        Insert the given node(s) into the section's body at a specific
        position (in-place).

        Parameters
        ----------
        pos : int
            The position at which the node(s) should be inserted. Any existing
            nodes at this or after this position are shifted back.
        node : :any:`Node` or tuple of :any:`Node`
            The node(s) to insert into the section.
        """
        new_nodes = _sanitize_tuple(node)
        self._update(body=self.body[:pos] + new_nodes + self.body[pos:])  # pylint: disable=unsubscriptable-object

    def prepend(self, node):
        """
        Insert the given node(s) at the beginning of the section's body (in-place).

        Parameters
        ----------
        node : :any:`Node` or tuple of :any:`Node`
            The node(s) to insert into the section.
        """
        self._update(body=_sanitize_tuple(node) + self.body)

    def __repr__(self):
        if self.label is not None:
            return f'Section:: {self.label}'
        return 'Section::'


@dataclass_strict(frozen=True)
class _AssociateBase:
    """ Field type definitions for the :any:`Associate` node type. """

    # Pairs of (associated expression, local name)
    associations: Tuple[Tuple[Expression, Expression], ...]


@dataclass_strict(frozen=True)
class Associate(ScopedNode, Section, _AssociateBase):  # pylint: disable=too-many-ancestors
    """
    Internal representation of a code region in which names are associated
    with expressions or variables.

    Parameters
    ----------
    body : tuple
        The associate's body.
    associations : tuple of (expression, name) pairs
        Each pair holds an associated expression or variable and the name
        under which it is available inside the associate's body.
    parent : :any:`Scope`, optional
        The parent scope in which the associate appears
    symbol_attrs : :any:`SymbolTable`, optional
        An existing symbol table to use
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['body', 'associations']

    def __post_init__(self, parent=None):
        super(ScopedNode, self).__post_init__(parent=parent)
        super(Section, self).__post_init__()

        assert self.associations is None or isinstance(self.associations, tuple)

    @property
    def association_map(self):
        """
        A :any:`CaseInsensitiveDict` mapping each associated expression
        to its local name.
        """
        return CaseInsensitiveDict((k, v) for k, v in self.associations)

    @property
    def inverse_map(self):
        """
        A :any:`CaseInsensitiveDict` mapping each local name to its
        associated expression.
        """
        return CaseInsensitiveDict((v, k) for k, v in self.associations)

    @property
    def variables(self):
        """ The local names defined by this associate. """
        return tuple(v for _, v in self.associations)

    def _derive_local_symbol_types(self, parent_scope):
        """ Derive the types of locally defined symbols from their associations. """

        def _expression_shape(expr):
            # The shape of a scalar expression is reported as ``(1,)``; drop it
            shape = ExpressionDimensionsMapper()(expr)
            if shape == (sym.IntLiteral(1),):
                return None
            return shape

        rescoped_associations = ()
        for expr, name in self.associations:
            # Put symbols in associated expression into the right scope
            expr = AttachScopesMapper()(expr, scope=parent_scope)

            # Determine type of new names
            if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)):
                # Use the type of the associated variable
                _type = expr.type.clone(parent=None)
                if isinstance(expr, sym.Array) and expr.dimensions is not None:
                    _type = _type.clone(shape=_expression_shape(expr))
            else:
                # TODO: Handle data type and shape of complex expressions
                _type = SymbolAttributes(BasicType.DEFERRED, shape=_expression_shape(expr))
            name = name.clone(scope=self, type=_type)
            rescoped_associations += ((expr, name),)

        self._update(associations=rescoped_associations)

    def __repr__(self):
        if self.associations:
            # Each pair is rendered as ``expr=name``
            associations = ', '.join(f'{str(expr)}={str(name)}'
                                     for expr, name in self.associations)
            return f'Associate:: {associations}'
        return 'Associate::'


@dataclass_strict(frozen=True)
class _LoopBase:
    """ Field type definitions for the :any:`Loop` node type. """

    # Loop control: induction variable, iteration range and body
    variable: Expression
    bounds: Expression
    body: Tuple[Node, ...]
    # Pragmas optionally attached in front of / after the loop
    pragma: Optional[Tuple[Node, ...]] = None
    pragma_post: Optional[Tuple[Node, ...]] = None
    # Fortran loop label, construct name and whether ``END DO`` was present
    loop_label: Optional[Any] = None
    name: Optional[str] = None
    has_end_do: Optional[bool] = True


@dataclass_strict(frozen=True)
class Loop(InternalNode, _LoopBase):
    """
    Internal representation of a loop with induction variable and range.

    Parameters
    ----------
    variable : :any:`Scalar`
        The induction variable of the loop.
    bounds : :any:`LoopRange`
        The range of the loop, defining the iteration space.
    body : tuple
        The loop body.
    pragma : tuple of :any:`Pragma`, optional
        Pragma(s) that appear in front of the loop. By default :any:`Pragma`
        nodes appear as standalone nodes in the IR before the :any:`Loop` node.
        Only a bespoke context created by :py:func:`pragmas_attached`
        attaches them for convenience.
    pragma_post : tuple of :any:`Pragma`, optional
        Pragma(s) that appear after the loop. The same applies as for `pragma`.
    loop_label : str, optional
        The Fortran label for that loop. Importantly, this is an intrinsic
        Fortran feature and different from the statement label that can be
        attached to other nodes.
    name : str, optional
        The Fortran construct name for that loop.
    has_end_do : bool, optional
        In Fortran, loop blocks can be closed off by a ``CONTINUE`` statement
        (which we retain as an :any:`Intrinsic` node) and therefore ``END DO``
        can be omitted. For string reproducibility this parameter can be set
        `False` to indicate that this loop did not have an ``END DO``
        statement in the original source.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['variable', 'bounds', 'body']

    def __post_init__(self):
        super().__post_init__()
        assert self.variable is not None

    def __repr__(self):
        # ``loop_label`` is typed as ``Any``, so convert to str before joining
        label = ', '.join(str(l) for l in (self.name, self.loop_label) if l is not None)
        if label:
            label = ' ' + label
        control = f'{str(self.variable)}={str(self.bounds)}'
        return f'Loop::{label} {control}'


@dataclass_strict(frozen=True)
class _WhileLoopBase():
    """
    Type definitions for :any:`WhileLoop` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Loop condition; typed Optional, so it may be None (the node's __repr__ handles that case)
    condition: Optional[Expression]
    # The nodes forming the loop body
    body: Tuple[Node, ...]
    # Pragma(s) in front of the loop, only set when pragmas are attached
    pragma: Optional[Node] = None
    # Pragma(s) after the loop, only set when pragmas are attached
    pragma_post: Optional[Node] = None
    # Fortran loop label (not the generic statement label)
    loop_label: Optional[Any] = None
    # Fortran construct name of the loop
    name: Optional[str] = None
    # False when the original source closed the loop without ``END DO``
    has_end_do: Optional[bool] = True


@dataclass_strict(frozen=True)
class WhileLoop(InternalNode, _WhileLoopBase):
    """
    IR node for a while loop.

    Unlike a ``DO`` loop in Fortran or a ``for`` loop in C, a while loop has
    no induction variable and no explicit iteration range; it is controlled
    by a condition only.

    Parameters
    ----------
    condition : :any:`pymbolic.primitives.Expression`
        Condition that is evaluated before each execution of the loop body.
    body : tuple
        Nodes forming the loop body.
    pragma : tuple of :any:`Pragma`, optional
        Pragma(s) preceding the loop. Normally :any:`Pragma` nodes live as
        standalone nodes in the IR right before the :any:`WhileLoop`; they are
        only attached here inside the context created by
        :py:func:`pragmas_attached`.
    pragma_post : tuple of :any:`Pragma`, optional
        Pragma(s) following the loop, with the same attachment rules as `pragma`.
    loop_label : str, optional
        Fortran loop label. This is an intrinsic Fortran loop feature and is
        not the same as the generic statement label other nodes may carry.
    name : str, optional
        Fortran construct name of the loop.
    has_end_do : bool, optional
        A Fortran loop may be terminated by a ``CONTINUE`` statement (kept as
        an :any:`Intrinsic` node) instead of ``END DO``. Set this to `False`
        to record that the original source had no ``END DO``, so the source
        string can be reproduced faithfully.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['condition', 'body']

    def __repr__(self):
        # Combine construct name and loop label, skipping whichever is unset
        joined = ', '.join(part for part in (self.name, self.loop_label) if part is not None)
        prefix = f' {joined}' if joined else ''
        # A falsy condition is rendered as an empty control string
        control = '' if not self.condition else str(self.condition)
        return f'WhileLoop::{prefix} {control}'


@dataclass_strict(frozen=True)
class _ConditionalBase():
    """
    Type definitions for :any:`Conditional` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # The branch condition
    condition: Expression
    # Nodes executed when the condition holds
    body: Tuple[Node, ...]
    # Nodes of the else branch (empty if there is none)
    else_body: Optional[Tuple[Node, ...]] = ()
    # Marks a single-statement inline ``IF`` without else branch
    inline: bool = False
    # Marks that else_body holds exactly one nested Conditional forming an ELSE IF
    has_elseif: bool = False
    # Fortran construct name of the conditional
    name: Optional[str] = None


@dataclass_strict(frozen=True)
class Conditional(InternalNode, _ConditionalBase):
    """
    IR node for a conditional branching construct.

    Parameters
    ----------
    condition : :any:`pymbolic.primitives.Expression`
        Condition that decides whether the body is executed.
    body : tuple
        Nodes executed when the condition holds.
    else_body : tuple
        Nodes of the else branch. May be empty.
    inline : bool, optional
        Marks the conditional as inline, i.e., its body is a single statement
        that directly follows the ``IF`` on the same statement, and there is no
        ``else_body``.
    has_elseif : bool, optional
        Marks that the original source had an ``ELSE IF`` branch. Loki
        represents such branches as a chain of nested :any:`Conditional`
        nodes; this flag lets backends restore the original appearance.
        When set, ``else_body`` must consist of exactly one :any:`Conditional`.
    name : str, optional
        Fortran construct name of the conditional.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['condition', 'body', 'else_body']

    @field_validator('body', 'else_body', mode='before')
    @classmethod
    def ensure_tuple(cls, value):
        return _sanitize_tuple(value)

    def __post_init__(self):
        super().__post_init__()
        assert self.condition is not None

        if not self.has_elseif:
            return

        # An ELSE IF branch is stored as a single nested Conditional
        assert len(self.else_body) == 1
        assert isinstance(self.else_body[0], Conditional)  # pylint: disable=unsubscriptable-object

    def __repr__(self):
        return f'Conditional:: {self.name}' if self.name else 'Conditional::'

    @property
    def else_bodies(self):
        """
        Tuple of all node tuples that make up the ``ELSEIF``/``ELSE`` parts
        of this conditional chain, in source order.
        """
        if not self.has_elseif:
            return (self.else_body,) if self.else_body else ()
        # Unroll the nested ELSE IF chain recursively
        nested = self.else_body[0]
        return (nested.body, *nested.else_bodies)


@dataclass_strict(frozen=True)
class _PragmaRegionBase():
    """
    Type definitions for :any:`PragmaRegion` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Nodes between the opening and the closing pragma
    body: Tuple[Node, ...]
    # Opening pragma of the region; defaults to None, hence annotated Optional
    pragma: Optional[Node] = None
    # Closing pragma of the region; defaults to None, hence annotated Optional
    pragma_post: Optional[Node] = None


@dataclass_strict(frozen=True)
class PragmaRegion(InternalNode, _PragmaRegionBase):
    """
    IR node for a block of code enclosed by a matching pair of pragmas.

    The two pragmas are generally assumed to form an opening and a matching
    ``end`` pragma (presumably of the form ``!$<keyword> <region>`` and
    ``!$<keyword> end <region>``).

    This node type is only inserted into the IR inside the context created by
    :py:func:`pragma_regions_attached`.

    Parameters
    ----------
    body : tuple
        Nodes located between the opening and the closing pragma.
    pragma : :any:`Pragma`
        Opening pragma of the region.
    pragma_post : :any:`Pragma`
        Closing pragma of the region.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['body']

    def append(self, node):
        """Add ``node`` (a single node or a sequence) at the end of the body."""
        appended = self.body + as_tuple(node)
        self._update(body=appended)

    def insert(self, pos, node):
        '''Insert at given position'''
        head = self.body[:pos]  # pylint: disable=unsubscriptable-object
        tail = self.body[pos:]  # pylint: disable=unsubscriptable-object
        self._update(body=head + as_tuple(node) + tail)

    def prepend(self, node):
        """Add ``node`` (a single node or a sequence) at the start of the body."""
        prepended = as_tuple(node) + self.body
        self._update(body=prepended)

    def __repr__(self):
        return 'PragmaRegion::'


@dataclass_strict(frozen=True)
class _InterfaceBase():
    """
    Type definitions for :any:`Interface` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Contents of the interface block (procedure specifications or statements)
    body: Tuple[Any, ...]
    # Marks an ABSTRACT INTERFACE; mutually exclusive with spec
    abstract: bool = False
    # Generic name, operator, assignment or I/O specification; may be a plain str
    spec: Optional[Union[Expression, str]] = None


@dataclass_strict(frozen=True)
class Interface(InternalNode, _InterfaceBase):
    """
    Internal representation of a Fortran interface block.

    Parameters
    ----------
    body : tuple
        The body of the interface block, containing function and subroutine
        specifications or procedure statements
    abstract : bool, optional
        Flag to indicate that this is an abstract interface
    spec : :any:`pymbolic.primitives.Expression` or str, optional
        A generic name, operator, assignment, or I/O specification
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['body']

    def __post_init__(self):
        super().__post_init__()
        # An abstract interface cannot carry a generic specification
        assert not (self.abstract and self.spec)

    @property
    def symbols(self):
        """
        The list of symbol names declared by this interface

        If :attr:`spec` is set, it is included as the first entry.
        """
        symbols = as_tuple(flatten(
            getattr(node, 'procedure_symbol', getattr(node, 'symbols', ()))
            for node in self.body  # pylint: disable=not-an-iterable
        ))
        if self.spec:
            return (self.spec,) + symbols
        return symbols

    @property
    def symbol_map(self):
        """
        Map symbol name to symbol declared by this interface

        :attr:`spec` may be a plain `str` (as permitted by its type), which
        has no ``name`` attribute; for such entries the string itself is
        used as the (case-insensitive) key.
        """
        return CaseInsensitiveDict(
            (str(getattr(s, 'name', s)).lower(), s) for s in self.symbols
        )

    def __contains__(self, name):
        return name in self.symbol_map

    def __repr__(self):
        symbols = ', '.join(str(var) for var in self.symbols)
        if self.abstract:
            return f'Abstract Interface:: {symbols}'
        if self.spec:
            return f'Interface {self.spec}:: {symbols}'
        return f'Interface:: {symbols}'

# Leaf node types

@dataclass_strict(frozen=True)
class _AssignmentBase():
    """
    Type definitions for :any:`Assignment` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Left-hand side (assignment target)
    lhs: Expression
    # Right-hand side expression
    rhs: Expression
    # True for a pointer assignment (``=>``)
    ptr: bool = False
    # Inline comment following the statement in the original source
    comment: Optional[Node] = None


@dataclass_strict(frozen=True)
class Assignment(LeafNode, _AssignmentBase):
    """
    IR node for an assignment of an expression to a variable.

    Parameters
    ----------
    lhs : :any:`pymbolic.primitives.Expression`
        Target of the assignment.
    rhs : :any:`pymbolic.primitives.Expression`
        Expression whose value is assigned.
    ptr : bool, optional
        Set for a pointer assignment (``=>``). Defaults to ``False``.
    comment : :py:class:`Comment`, optional
        Inline comment that follows the right-hand side in the original source.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['lhs', 'rhs']

    def __post_init__(self):
        super().__post_init__()
        # Both sides of the assignment are mandatory
        assert self.lhs is not None
        assert self.rhs is not None

    def __repr__(self):
        target, value = str(self.lhs), str(self.rhs)
        return f'Assignment:: {target} = {value}'


@dataclass_strict(frozen=True)
class _ConditionalAssignmentBase():
    """
    Type definitions for :any:`ConditionalAssignment` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Assignment target
    lhs: Optional[Expression] = None
    # Condition of the ternary operator
    condition: Optional[Expression] = None
    # Value assigned when the condition holds
    rhs: Optional[Expression] = None
    # Value assigned when the condition does not hold
    else_rhs: Optional[Expression] = None


@dataclass_strict(frozen=True)
class ConditionalAssignment(LeafNode, _ConditionalAssignmentBase):
    """
    IR node for an inline conditional assignment via a ternary operator.

    Fortran has no equivalent construct. In C it reads:

    .. code-block:: C

        lhs = condition ? rhs : else_rhs;

    Parameters
    ----------
    lhs : :any:`pymbolic.primitives.Expression`
        Target of the assignment.
    condition : :any:`pymbolic.primitives.Expression`
        Condition of the ternary operator.
    rhs : :any:`pymbolic.primitives.Expression`
        Value assigned when the condition holds.
    else_rhs : :any:`pymbolic.primitives.Expression`
        Value assigned when the condition does not hold.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['condition', 'lhs', 'rhs', 'else_rhs']

    def __repr__(self):
        target, cond = self.lhs, self.condition
        then_val, else_val = self.rhs, self.else_rhs
        return f'CondAssign:: {target} = {cond} ? {then_val} : {else_val}'


@dataclass_strict(frozen=True)
class _CallStatementBase():
    """
    Type definitions for :any:`CallStatement` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Symbol of the called subroutine
    name: Expression
    # Positional arguments
    arguments: Optional[Tuple[Expression, ...]] = ()
    # Keyword arguments as (name, value) pairs
    kwarguments: Optional[Tuple[Tuple[str, Expression], ...]] = ()
    # Pragma(s) in front of the call, only set when pragmas are attached
    pragma: Optional[Tuple[Node, ...]] = None
    # Explicit marker that the call is inactive for call-tree processing
    not_active: Optional[bool] = None
    # CUDA Fortran kernel launch configuration (between 2 and 4 expressions)
    chevron: Optional[Tuple[Expression, ...]] = None


@dataclass_strict(frozen=True)
class CallStatement(LeafNode, _CallStatementBase):
    """
    Internal representation of a subroutine call.

    Parameters
    ----------
    name : :any:`pymbolic.primitives.Expression`
        The name of the subroutine to call.
    arguments : tuple of :any:`pymbolic.primitives.Expression`
        The list of positional arguments.
    kwarguments : tuple of tuple
        The list of keyword arguments, provided as pairs of `(name, value)`.
    pragma : tuple of :any:`Pragma`, optional
        Pragma(s) that appear in front of the statement. By default
        :any:`Pragma` nodes appear as standalone nodes in the IR before.
        Only a bespoke context created by :py:func:`pragmas_attached`
        attaches them for convenience.
    not_active : bool, optional
        Flag to indicate that this call has explicitly been marked as inactive for
        the purpose of processing call trees (Default: `None`)
    chevron : tuple of :any:`pymbolic.primitives.Expression`
        Launch configuration for CUDA Fortran Kernels.
        See [CUDA Fortran programming guide](https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/).
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['name', 'arguments', 'kwarguments']

    @field_validator('arguments', mode='before')
    @classmethod
    def ensure_tuple(cls, value):
        return _sanitize_tuple(value)

    @field_validator('kwarguments', mode='before')
    @classmethod
    def ensure_nested_tuple(cls, value):
        return tuple(_sanitize_tuple(pair) for pair in as_tuple(value))

    def __post_init__(self):
        super().__post_init__()
        assert isinstance(self.arguments, tuple)
        assert all(isinstance(arg, Expression) for arg in as_tuple(self.arguments))

        if self.kwarguments is not None:
            assert isinstance(self.kwarguments, tuple)
            assert all(
                isinstance(a, tuple) and len(a) == 2 and isinstance(a[1], Expression)
                for a in self.kwarguments  # pylint: disable=not-an-iterable
            )

        if self.chevron is not None:
            assert isinstance(self.chevron, tuple)
            assert all(isinstance(a, Expression) for a in self.chevron)  # pylint: disable=not-an-iterable
            assert 2 <= len(self.chevron) <= 4

    def __repr__(self):
        return f'Call:: {self.name}'

    @property
    def procedure_type(self):
        """
        The :any:`ProcedureType` of the :any:`Subroutine` object of the called routine

        For a :class:`CallStatement` node called ``call``, this is shorthand for ``call.name.type.dtype``.

        If the procedure type object has been linked up with the corresponding
        :any:`Subroutine` object, then it is available via ``call.procedure_type.procedure``.

        Returns
        -------
        :any:`ProcedureType` or :any:`BasicType.DEFERRED`
            The type of the called procedure. If the symbol type of the called routine
            has not been identified correctly, this may yield :any:`BasicType.DEFERRED`.
        """
        return self.name.type.dtype

    @property
    def routine(self):
        """
        The :any:`Subroutine` object of the called routine

        Shorthand for ``call.name.type.dtype.procedure``

        Returns
        -------
        :any:`Subroutine` or :any:`BasicType.DEFERRED`
            If the :any:`ProcedureType` object of the :any:`ProcedureSymbol`
            in :attr:`name` is linked up to the target routine, this returns
            the corresponding :any:`Subroutine` object. If the procedure type
            itself is :any:`BasicType.DEFERRED`, this returns
            :any:`BasicType.DEFERRED`; otherwise the value of
            ``procedure_type.procedure`` is returned unchanged.
        """
        procedure_type = self.procedure_type
        if procedure_type is BasicType.DEFERRED:
            return BasicType.DEFERRED
        return procedure_type.procedure

    def arg_iter(self):
        """
        Iterator that maps argument definitions in the target :any:`Subroutine`
        to arguments and keyword arguments in the call.

        Positional arguments are paired with the routine's dummy arguments by
        position; keyword arguments are looked up by (case-insensitive) name.

        Returns
        -------
        iterator
            An iterator that traverses the mapping ``(arg name, call arg)`` for
            all positional and then keyword arguments.
        """
        routine = self.routine
        assert routine is not BasicType.DEFERRED
        r_args = CaseInsensitiveDict((arg.name, arg) for arg in routine.arguments)
        args = zip(routine.arguments, self.arguments)
        kwargs = ((r_args[kw], arg) for kw, arg in as_tuple(self.kwarguments))
        return chain(args, kwargs)

    @property
    def arg_map(self):
        """
        A full map of all qualified argument matches from arguments
        and keyword arguments.
        """
        return dict(self.arg_iter())

    def _sort_kwarguments(self):
        """
        Helper routine to sort the kwarguments according to the order of the
        arguments (``self.routine.arguments``).

        Keyword arguments whose name does not match any dummy argument of the
        routine are not included in the result.
        """
        routine = self.routine
        assert routine is not BasicType.DEFERRED
        kwargs = CaseInsensitiveDict(self.kwarguments)
        r_arg_names = [arg.name for arg in routine.arguments if arg.name in kwargs]
        new_kwarguments = tuple((arg_name, kwargs[arg_name]) for arg_name in r_arg_names)
        return new_kwarguments

    def is_kwargs_order_correct(self):
        """
        Check whether kwarguments are correctly ordered
        in respect to the arguments (``self.routine.arguments``).
        """
        return self.kwarguments == self._sort_kwarguments()

    def sort_kwarguments(self):
        """
        Sort and update the kwarguments according to the order of the
        arguments (``self.routine.arguments``).
        """
        new_kwarguments = self._sort_kwarguments()
        self._update(kwarguments=new_kwarguments)

    def convert_kwargs_to_args(self):
        """
        Convert keyword arguments to positional arguments and update the call
        accordingly.

        Keyword arguments are appended to :attr:`arguments` in the order of the
        dummy arguments of the called routine (``self.routine.arguments``),
        starting with the first dummy argument not already passed positionally.
        Conversion stops at the first dummy argument for which no keyword
        argument is given (e.g., an omitted ``OPTIONAL`` argument), because
        passing any later keyword argument positionally would bind its value to
        the wrong dummy argument. Keyword arguments that cannot be converted are
        retained in :attr:`kwarguments` in their original order. If all keyword
        arguments can be converted, :attr:`kwarguments` becomes empty.
        """
        routine = self.routine
        assert routine is not BasicType.DEFERRED
        kwargs = CaseInsensitiveDict(self.kwarguments)

        new_args = ()
        converted = set()
        for r_arg in tuple(routine.arguments)[len(self.arguments):]:
            if r_arg.name not in kwargs:
                # Gap in the argument list: all later arguments must stay keywords
                break
            new_args += (kwargs[r_arg.name],)
            converted.add(r_arg.name.lower())

        remaining = tuple(
            (kw, arg) for kw, arg in as_tuple(self.kwarguments)
            if kw.lower() not in converted
        )
        self._update(arguments=self.arguments + new_args, kwarguments=remaining)


@dataclass_strict(frozen=True)
class _AllocationBase():
    """
    Type definitions for :any:`Allocation` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Variables to allocate
    variables: Tuple[Expression, ...]
    # Fortran ``SOURCE`` allocation option
    data_source: Optional[Expression] = None
    # Fortran ``STAT`` allocation option
    status_var: Optional[Expression] = None


@dataclass_strict(frozen=True)
class Allocation(LeafNode, _AllocationBase):
    """
    IR node for the allocation of one or more variables.

    Parameters
    ----------
    variables : tuple of :any:`pymbolic.primitives.Expression`
        Variables to allocate.
    data_source : :any:`pymbolic.primitives.Expression` or str
        Fortran's ``SOURCE`` option of the allocation.
    status_var : :any:`pymbolic.primitives.Expression`
        Fortran's ``STAT`` option of the allocation.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['variables', 'data_source', 'status_var']

    def __post_init__(self):
        super().__post_init__()
        assert is_iterable(self.variables)
        assert all(isinstance(v, Expression) for v in self.variables)
        # Both allocation options are optional but must be expressions if given
        for option in (self.data_source, self.status_var):
            assert option is None or isinstance(option, Expression)

    def __repr__(self):
        names = ', '.join(map(str, self.variables))
        return f'Allocation:: {names}'


@dataclass_strict(frozen=True)
class _DeallocationBase():
    """
    Type definitions for :any:`Deallocation` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Variables to deallocate
    variables: Tuple[Expression, ...]
    # Fortran ``STAT`` deallocation option
    status_var: Optional[Expression] = None


@dataclass_strict(frozen=True)
class Deallocation(LeafNode, _DeallocationBase):
    """
    IR node for the deallocation of one or more variables.

    Parameters
    ----------
    variables : tuple of :any:`pymbolic.primitives.Expression`
        Variables to deallocate.
    status_var : :any:`pymbolic.primitives.Expression`
        Fortran's ``STAT`` option of the deallocation.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['variables', 'status_var']

    def __post_init__(self):
        super().__post_init__()
        targets = self.variables
        assert is_iterable(targets)
        assert all(isinstance(v, Expression) for v in targets)
        assert self.status_var is None or isinstance(self.status_var, Expression)

    def __repr__(self):
        names = ', '.join(map(str, self.variables))
        return f'Deallocation:: {names}'


@dataclass_strict(frozen=True)
class _NullifyBase():
    """
    Type definitions for :any:`Nullify` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Pointer variables to nullify
    variables: Tuple[Expression, ...]


@dataclass_strict(frozen=True)
class Nullify(LeafNode, _NullifyBase):
    """
    IR node for the nullification of pointers.

    Parameters
    ----------
    variables : tuple of :any:`pymbolic.primitives.Expression`
        Pointer variables to nullify.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['variables']

    def __post_init__(self):
        super().__post_init__()
        pointers = self.variables
        assert is_iterable(pointers)
        assert all(isinstance(p, Expression) for p in pointers)

    def __repr__(self):
        names = ', '.join(map(str, self.variables))
        return f'Nullify:: {names}'


@dataclass_strict(frozen=True)
class _CommentBase():
    """
    Type definitions for :any:`Comment` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Comment text; may be empty to represent a blank line
    text: str


@dataclass_strict(frozen=True)
class Comment(LeafNode, _CommentBase):
    """
    Internal representation of a single comment.

    Parameters
    ----------
    text : str
        The content of the comment. Can be empty to represent empty lines
        in the original source.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    def __post_init__(self):
        # Run the parent initialisation like every other node type does
        super().__post_init__()
        assert isinstance(self.text, str)

    def __repr__(self):
        return f'Comment:: {truncate_string(self.text)}'


@dataclass_strict(frozen=True)
class _CommentBlockBase():
    """
    Type definitions for :any:`CommentBlock` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # The consecutive single-line comments forming the block
    comments: Tuple[Node, ...]


@dataclass_strict(frozen=True)
class CommentBlock(LeafNode, _CommentBlockBase):
    """
    IR node for a comment block built from several consecutive
    single-line comments.

    Parameters
    ----------
    comments: tuple of :any:`Comment`
        The consecutive comments that form the block.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    def __post_init__(self):
        super().__post_init__()
        assert self.comments is not None and is_iterable(self.comments)

    @property
    def text(self):
        """Texts of all comments in the block, concatenated without separator"""
        return ''.join([c.text for c in self.comments])

    def __repr__(self):
        summary = truncate_string(self.text)
        return f'CommentBlock:: {summary}'


@dataclass_strict(frozen=True)
class _PragmaBase():
    """
    Type definitions for :any:`Pragma` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Pragma keyword (must be a non-empty string)
    keyword: str
    # Remainder of the pragma after the keyword, if any
    content: Optional[str] = None


@dataclass_strict(frozen=True)
class Pragma(LeafNode, _PragmaBase):
    """
    Internal representation of a pragma.

    Pragmas are assumed to appear in Fortran source code in the form of
    ``!$<keyword> <content>``.

    Parameters
    ----------
    keyword : str
        The keyword of the pragma.
    content : str, optional
        The content of the pragma after the keyword.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    def __post_init__(self):
        super().__post_init__()
        assert self.keyword and isinstance(self.keyword, str)

    def __repr__(self):
        # ``content`` is optional (default ``None``), so only truncate actual text
        content = truncate_string(self.content) if self.content else ''
        return f'Pragma:: {self.keyword} {content}'


@dataclass_strict(frozen=True)
class _PreprocessorDirectiveBase():
    """
    Type definitions for :any:`PreprocessorDirective` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Text of the directive; defaults to None, hence annotated Optional
    text: Optional[str] = None


@dataclass_strict(frozen=True)
class PreprocessorDirective(LeafNode, _PreprocessorDirectiveBase):
    """
    Internal representation of a preprocessor directive.

    Preprocessor directives are typically assumed to start at the beginning of
    a line with the character ``#`` in the original source.

    Parameters
    ----------
    text : str, optional
        The content of the directive.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    def __repr__(self):
        # ``text`` is optional (default ``None``), so only truncate actual text
        text = truncate_string(self.text) if self.text else ''
        return f'PreprocessorDirective:: {text}'


@dataclass_strict(frozen=True)
class _ImportBase():
    """
    Type definitions for :any:`Import` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Name of the module or header file imported from
    module: Optional[str]
    # Imported symbols; empty when importing everything
    symbols: Tuple[Expression, ...] = ()
    # Module nature: ``INTRINSIC`` or ``NON_INTRINSIC``
    nature: Optional[str] = None
    # C-style include
    c_import: bool = False
    # Preprocessor-style include in Fortran source
    f_include: bool = False
    # Fortran ``IMPORT`` statement
    f_import: bool = False
    # (use name, local name) rename pairs
    rename_list: Optional[Tuple[Any, ...]] = None


@dataclass_strict(frozen=True)
class Import(LeafNode, _ImportBase):
    """
    IR node for an import, include or Fortran ``IMPORT`` statement.

    Parameters
    ----------
    module : str
        Name of the module or header file imported from.
    symbols : tuple of :any:`Expression` or :any:`DataType`, optional
        Imported names. May be empty when everything is imported.
    nature : str, optional
        Module nature (``INTRINSIC`` or ``NON_INTRINSIC``).
    c_import : bool, optional
        Marks a C-style include. Defaults to `False`.
    f_include : bool, optional
        Marks a preprocessor-style include in Fortran source code.
    f_import : bool, optional
        Marks a Fortran ``IMPORT`` statement.
    rename_list: tuple of tuples (`str`, :any:`Expression`), optional
        Rename list of `(use name, local name)` pairs.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['symbols', 'rename_list']

    def __post_init__(self):
        super().__post_init__()
        assert self.module is None or isinstance(self.module, str)
        assert isinstance(self.symbols, tuple)
        assert all(isinstance(s, (Expression, DataType)) for s in self.symbols)

        # A module nature is only valid for a regular Fortran USE statement
        is_special = self.c_import or self.f_include or self.f_import
        if self.nature is not None:
            assert isinstance(self.nature, str)
            assert self.nature.lower() in ('intrinsic', 'non_intrinsic')
            assert not is_special

        if sum((self.c_import, self.f_include, self.f_import)) not in (0, 1):
            raise ValueError('Import can only be either C include, F include or F import')
        if self.rename_list and (self.symbols or is_special):
            raise ValueError('Import cannot have rename and only lists or be an include')

    def __repr__(self):
        if self.f_import:
            return f'Import:: {self.symbols}'
        if self.c_import:
            kind = 'C-'
        elif self.f_include:
            kind = 'F-'
        else:
            kind = ''
        return f'{kind}Import:: {self.module} => {self.symbols}'


@dataclass_strict(frozen=True)
class _VariableDeclarationBase():
    """
    Type definitions for :any:`VariableDeclaration` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Variables declared by this declaration
    symbols: Tuple[Expression, ...]
    # Declared shape if given as part of the declaration attributes
    dimensions: Optional[Tuple[Expression, ...]] = None
    # Inline comment following the declaration in the original source
    comment: Optional[Node] = None
    # Pragma(s) before the declaration, only set when pragmas are attached
    pragma: Optional[Node] = None


@dataclass_strict(frozen=True)
class VariableDeclaration(LeafNode, _VariableDeclarationBase):
    """
    IR node for a declaration of one or more variables.

    Parameters
    ----------
    symbols : tuple of :any:`pymbolic.primitives.Expression`
        Variables declared by this statement.
    dimensions : tuple of :any:`pymbolic.primitives.Expression`, optional
        Declared allocation size, if it is given as part of the declaration
        attributes.
    comment : :py:class:`Comment`, optional
        Inline comment that follows the declaration in the original source.
    pragma : tuple of :any:`Pragma`, optional
        Pragma(s) preceding the declaration. Normally :any:`Pragma` nodes live
        as standalone nodes in the IR; they are only attached here inside the
        context created by :py:func:`pragmas_attached`.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['symbols', 'dimensions']

    def __post_init__(self):
        super().__post_init__()
        declared = self.symbols
        assert declared is not None
        assert is_iterable(declared)
        assert all(isinstance(s, Expression) for s in declared)

        shape = self.dimensions
        if shape is None:
            return
        assert is_iterable(shape)
        assert all(isinstance(d, Expression) for d in shape)  # pylint: disable=not-an-iterable

    def __repr__(self):
        names = ', '.join(map(str, self.symbols))
        return f'VariableDeclaration:: {names}'


@dataclass_strict(frozen=True)
class _ProcedureDeclarationBase():
    """
    Type definitions for :any:`ProcedureDeclaration` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # Procedure symbols declared by this declaration
    symbols: Tuple[Expression, ...]
    # Procedure interface of the declared entities
    interface: Optional[Union[Expression, DataType]] = None
    # At most one of the following four flags may be set
    # Fortran ``EXTERNAL`` declaration
    external: bool = False
    # ``MODULE PROCEDURE`` declaration in an interface
    module: bool = False
    # Generic binding in a derived type
    generic: bool = False
    # ``FINAL`` binding of a derived type
    final: bool = False
    # Inline comment following the declaration in the original source
    comment: Optional[Node] = None
    # Pragma(s) before the declaration, only set when pragmas are attached
    pragma: Optional[Tuple[Node, ...]] = None


@dataclass_strict(frozen=True)
class ProcedureDeclaration(LeafNode, _ProcedureDeclarationBase):
    """
    IR node for a declaration of one or more procedures.

    Parameters
    ----------
    symbols : tuple of :any:`pymbolic.primitives.Expression`
        Procedure symbols declared by this statement.
    interface : :any:`pymbolic.primitives.Expression` or :any:`DataType`, optional
        Procedure interface of the declared entities.
    external : bool, optional
        Marks a Fortran ``EXTERNAL`` declaration.
    module : bool, optional
        Marks a Fortran ``MODULE PROCEDURE`` declaration in an interface,
        i.e., one that includes the keyword ``MODULE``.
    generic : bool,  optional
        Marks a generic binding procedure statement in a derived type.
    final : bool, optional
        Marks a subroutine as clean-up procedure of a derived type.
    comment : :py:class:`Comment`, optional
        Inline comment that follows the declaration in the original source.
    pragma : tuple of :any:`Pragma`, optional
        Pragma(s) preceding the declaration. Normally :any:`Pragma` nodes live
        as standalone nodes in the IR; they are only attached here inside the
        context created by :py:func:`pragmas_attached`.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['symbols', 'interface']

    def __post_init__(self):
        super().__post_init__()
        assert is_iterable(self.symbols)
        assert all(isinstance(sym, Expression) for sym in self.symbols)
        assert self.interface is None or isinstance(self.interface, (Expression, DataType))

        # The declaration kinds are mutually exclusive
        kind_flags = (self.external, self.module, self.generic, self.final)
        assert sum(kind_flags) in (0, 1)

    def __repr__(self):
        names = ', '.join(map(str, self.symbols))
        return f'ProcedureDeclaration:: {names}'


@dataclass_strict(frozen=True)
class _DataDeclarationBase():
    """
    Type definitions for :any:`DataDeclaration` node type.

    The declaration order of the fields defines the positional order of the
    generated dataclass constructor arguments, so it must not be changed.
    """

    # TODO: This should only allow Expression instances but needs frontend changes
    # TODO: Support complex statements (LOKI-23)
    # Left-hand side of the DATA statement (currently Expression, str or tuple)
    variable: Any
    # Right-hand side value list of the DATA statement
    values: Tuple[Expression, ...]


@dataclass_strict(frozen=True)
class DataDeclaration(LeafNode, _DataDeclarationBase):
    """
    IR node for a ``DATA`` statement that assigns explicit value lists.

    Parameters
    ----------
    variable : :any:`pymbolic.primitives.Expression`
        Left-hand side of the data statement.
    values : tuple of :any:`pymbolic.primitives.Expression`
        Right-hand side values of the data statement.
    **kwargs : optional
        Further parameters forwarded to the parent class constructor.
    """

    _traversable = ['variable', 'values']

    def __post_init__(self):
        super().__post_init__()
        # The frontend may still provide strings or tuples for the variable
        assert isinstance(self.variable, (Expression, str, tuple))
        rhs_values = self.values
        assert is_iterable(rhs_values)
        assert all(isinstance(v, Expression) for v in rhs_values)

    def __repr__(self):
        target = str(self.variable)
        return f'DataDeclaration:: {target}'


@dataclass_strict(frozen=True)
class _StatementFunctionBase():
    """
    Type definitions for :any:`StatementFunction` node type.

    Field order matters: it defines the positional constructor signature.
    """

    # Symbol naming the statement function
    variable: Expression
    # Dummy arguments of the statement function
    arguments: Tuple[Expression, ...]
    # Expression defining the statement function's value
    rhs: Expression
    # Type information for the function's result
    return_type: SymbolAttributes


@dataclass_strict(frozen=True)
class StatementFunction(LeafNode, _StatementFunctionBase):
    """
    Internal representation of Fortran statement function statements

    Parameters
    ----------
    variable : :any:`pymbolic.primitives.Expression`
        The name of the statement function
    arguments : tuple of :any:`pymbolic.primitives.Expression`
        The list of dummy arguments
    rhs : :any:`pymbolic.primitives.Expression`
        The expression defining the statement function
    return_type : :any:`SymbolAttributes`
        The return type of the statement function
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['variable', 'arguments', 'rhs']

    def __post_init__(self):
        super().__post_init__()
        assert isinstance(self.variable, Expression)
        assert is_iterable(self.arguments) and all(isinstance(a, Expression) for a in self.arguments)
        assert isinstance(self.return_type, SymbolAttributes)

    @property
    def name(self):
        """The name of the statement function as a string"""
        return str(self.variable)

    @property
    def is_function(self):
        """Statement functions are always functions, so this is always `True`"""
        return True

    def __repr__(self):
        # Dummy arguments are separated by a comma followed by a space, e.g. ``f(a, b)``
        arguments = ', '.join(str(a) for a in self.arguments)
        return f'StatementFunction:: {self.variable}({arguments})'


@dataclass_strict(frozen=True)
class _TypeDefBase():
    """
    Type definitions for :any:`TypeDef` node type.

    Field order matters: it defines the positional constructor signature.
    """

    # Name of the derived type
    name: Optional[str] = None
    # Nodes in the type definition body (declarations, comments, imports, ...)
    body: Optional[Tuple[Node, ...]] = None
    # Whether the type is declared ``ABSTRACT``
    abstract: bool = False
    # Name of the parent type given via ``EXTENDS``, if any
    extends: Optional[str] = None
    # Whether the type carries a ``BIND(C)`` attribute
    bind_c: bool = False
    # Whether the type is explicitly declared ``PRIVATE``
    private: bool = False
    # Whether the type is explicitly declared ``PUBLIC``
    public: bool = False


@dataclass_strict(frozen=True)
class TypeDef(ScopedNode, InternalNode, _TypeDefBase):
    """
    Internal representation of a derived type definition.

    Similar to :py:class:`Sourcefile`, :py:class:`Module`, and
    :py:class:`Subroutine`, it forms its own scope for symbols and types.
    This is required to instantiate :py:class:`TypedSymbol` instances in
    declarations, imports etc. without having them show up in the enclosing
    scope.

    Parameters
    ----------
    name : str
        The name of the type.
    body : tuple
        The body of the type definition.
    abstract : bool, optional
        Flag to indicate that this is an abstract type definition.
    extends : str, optional
        The parent type name
    bind_c : bool, optional
        Flag to indicate that this contains a ``BIND(C)`` attribute.
    private : bool, optional
        Flag to indicate that this has been declared explicitly as ``PRIVATE``
    public : bool, optional
        Flag to indicate that this has been declared explicitly as ``PUBLIC``
    parent : :any:`Scope`, optional
        The parent scope in which the type definition appears
    symbol_attrs : :any:`SymbolTable`, optional
        An existing symbol table to use
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['body']

    def __post_init__(self, parent=None):
        """
        Initialise the scope and node parts, then register the type in its parent scope.

        Parameters
        ----------
        parent : :any:`Scope`, optional
            The enclosing scope. If the node ends up with a parent scope, an
            entry ``SymbolAttributes(self.dtype)`` is stored under :attr:`name`
            in that scope's symbol table.
        """
        # Two-argument ``super`` calls: the first resolves ``__post_init__`` on the
        # class following ``ScopedNode`` in the MRO (and forwards ``parent``), the
        # second on the class following ``InternalNode``.
        super(ScopedNode, self).__post_init__(parent=parent)
        super(InternalNode, self).__post_init__()

        # Register this typedef in the parent scope
        if self.parent:
            self.parent.symbol_attrs[self.name] = SymbolAttributes(self.dtype)

    @property
    def ir(self):
        """The IR of this type definition, i.e. its :attr:`body`"""
        return self.body

    @property
    def parent_type(self):
        """
        The type definition named in :attr:`extends`

        Returns `None` if the type does not extend another type, and
        :any:`BasicType.DEFERRED` if there is no parent scope or the parent
        type cannot be found in it as a :any:`DerivedType`. Otherwise the
        ``typedef`` of the parent's :any:`DerivedType` is returned (which may
        itself be deferred if that type is unresolved - TODO confirm).
        """
        if not self.extends:
            return None
        if not self.parent:
            return BasicType.DEFERRED
        parent_type = self.parent.symbol_attrs.lookup(self.extends)
        if not (parent_type and isinstance(parent_type.dtype, DerivedType)):
            return BasicType.DEFERRED
        return parent_type.dtype.typedef

    @property
    def declarations(self):
        """
        All variable and procedure declarations of this type

        Includes the :any:`VariableDeclaration` and :any:`ProcedureDeclaration`
        nodes in :attr:`body` and, if a resolved :attr:`parent_type` exists,
        clones of the parent's declarations restricted to symbols that are not
        declared locally, with those symbols re-scoped to this type.
        """
        decls = tuple(
            c for c in as_tuple(self.body)
            if isinstance(c, (VariableDeclaration, ProcedureDeclaration))
        )

        # Inherit non-overridden symbols from parent type
        if (parent_type := self.parent_type) and parent_type is not BasicType.DEFERRED:
            local_symbols = [s for decl in decls for s in decl.symbols]
            for decl in parent_type.declarations:
                # Local declarations take precedence over the parent's
                decl_symbols = tuple(s.clone(scope=self) for s in decl.symbols if s not in local_symbols)
                if decl_symbols:
                    decls += (decl.clone(symbols=decl_symbols),)

        return decls

    @property
    def comments(self):
        """The :any:`Comment` nodes in :attr:`body`"""
        return tuple(c for c in as_tuple(self.body) if isinstance(c, Comment))

    @property
    def variables(self):
        """All symbols declared in :attr:`declarations`, including inherited ones"""
        return tuple(flatten([decl.symbols for decl in self.declarations]))

    @property
    def imported_symbols(self):
        """
        Return the symbols imported in this typedef
        """
        return tuple(flatten(c.symbols for c in as_tuple(self.body) if isinstance(c, Import)))

    @property
    def imported_symbol_map(self):
        """
        Map of imported symbol names to objects
        """
        return CaseInsensitiveDict((s.name, s) for s in self.imported_symbols)

    def __contains__(self, name):
        """
        Check if a symbol with the given name is declared in this type
        """
        # Relies on symbol equality against ``name`` (e.g. comparing with a string)
        return name in self.variables

    @property
    def interface_symbols(self):
        """
        Return the list of symbols declared via interfaces in this unit

        This returns always an empty tuple since there are no interface declarations
        allowed in typedefs.
        """
        return ()

    @property
    def dtype(self):
        """
        Return the :any:`DerivedType` representing this type

        A new :any:`DerivedType` object is created on every access.
        """
        return DerivedType(name=self.name, typedef=self)

    def __repr__(self):
        return f'TypeDef:: {self.name}'

    def clone(self, **kwargs):
        """
        Create a copy of this type definition, optionally overriding attributes

        Unless ``body`` is given explicitly, the body is rebuilt by visiting it
        with a default :any:`Transformer` (presumably producing copies of the
        nodes - TODO confirm).
        """
        from loki.ir.transformer import Transformer  # pylint: disable=import-outside-toplevel,cyclic-import
        if 'body' not in kwargs:
            kwargs['body'] = Transformer().visit(self.body)
        return super().clone(**kwargs)


@dataclass_strict(frozen=True)
class _MultiConditionalBase():
    """
    Type definitions for :any:`MultiConditional` node type.

    Field order matters: it defines the positional constructor signature.
    """

    # Selector expression evaluated to choose a case
    expr: Expression
    # One tuple of case values per case
    values: Tuple[Tuple[Expression, ...], ...]
    # One body per case, matching the order of ``values``
    bodies: Tuple[Any, ...]
    # Body of the ``DEFAULT`` case
    else_body: Tuple[Node, ...]
    # Optional construct name from the source
    name: Optional[str] = None


@dataclass_strict(frozen=True)
class MultiConditional(LeafNode, _MultiConditionalBase):
    """
    IR node for a conditional that selects between multiple values,
    such as Fortran's ``SELECT CASE`` construct.

    Parameters
    ----------
    expr : :any:`pymbolic.primitives.Expression`
        The selector expression whose value decides which case is taken.
    values : tuple of tuple of :any:`pymbolic.primitives.Expression`
        For each case, the tuple of values that select it.
    bodies : tuple of tuple
        For each case, the body to execute.
    else_body : tuple
        The body executed in the ``DEFAULT`` case.
    name : str, optional
        The construct-name of the multi conditional in the original source.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['expr', 'values', 'bodies', 'else_body']

    @field_validator('else_body', mode='before')
    @classmethod
    def ensure_tuple(cls, value):
        # Normalise the default body before validation
        return _sanitize_tuple(value)

    @field_validator('values', 'bodies', mode='before')
    @classmethod
    def ensure_nested_tuple(cls, value):
        # Normalise every per-case entry individually before validation
        sanitized = [_sanitize_tuple(entry) for entry in as_tuple(value)]
        return tuple(sanitized)

    def __post_init__(self):
        super().__post_init__()
        assert isinstance(self.expr, Expression)
        assert is_iterable(self.values)
        # Each case contributes a tuple of expressions
        assert all(
            isinstance(case, tuple) and all(isinstance(c, Expression) for c in case)
            for case in self.values
        )
        assert is_iterable(self.bodies)
        assert all(is_iterable(body) for body in self.bodies)
        assert is_iterable(self.else_body)

    def __repr__(self):
        label = '' if not self.name else f' {self.name}'
        return f'MultiConditional::{label} {self.expr!s}'


@dataclass_strict(frozen=True)
class _TypeConditionalBase():
    """
    Type definitions for :any:`TypeConditional` node type.

    Field order matters: it defines the positional constructor signature.
    """

    # Selector expression whose dynamic type chooses a case
    expr: Expression
    # One (type expression, polymorphic flag) pair per case
    values: Tuple[Tuple[Expression, bool], ...]
    # One body per case, matching the order of ``values``
    bodies: Tuple[Any, ...]
    # Body of the ``DEFAULT`` case
    else_body: Tuple[Node, ...]
    # Optional construct name from the source
    name: Optional[str] = None


@dataclass_strict(frozen=True)
class TypeConditional(LeafNode, _TypeConditionalBase):
    """
    IR node for a conditional that selects on the dynamic type of an
    expression, such as Fortran's ``SELECT TYPE`` construct.

    Parameters
    ----------
    expr : :any:`pymbolic.primitives.Expression`
        The selector expression whose type decides which case is taken.
    values : tuple of 2-tuple with (:any:`pymbolic.primitives.Expression`, bool)
        For each case, the type name and a flag that is `True` for a
        derived/polymorphic type case (the Fortran ``CLASS IS`` form rather
        than ``TYPE IS``).
    bodies : tuple of tuple
        For each case, the body to execute.
    else_body : tuple
        The body executed in the ``DEFAULT`` case.
    name : str, optional
        The construct-name of the multi conditional in the original source.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['expr', 'values', 'bodies', 'else_body']

    def __post_init__(self):
        super().__post_init__()
        assert isinstance(self.expr, Expression)
        assert is_iterable(self.values)
        # Each case is a (type expression, polymorphic flag) pair
        assert all(
            isinstance(case, tuple) and len(case) == 2
            and isinstance(case[0], Expression) and isinstance(case[1], bool)
            for case in self.values
        )
        assert is_iterable(self.bodies)
        assert all(is_iterable(body) for body in self.bodies)
        assert is_iterable(self.else_body)

    def __repr__(self):
        label = '' if not self.name else f' {self.name}'
        return f'TypeConditional::{label} {self.expr!s}'


@dataclass_strict(frozen=True)
class _ForallBase():
    """
    Type definition for :any:`Forall` node type.

    Field order matters: it defines the positional constructor signature.
    """

    # (variable, range) pairs for the FORALL index variables
    named_bounds: Tuple[Tuple[Expression, Expression], ...]
    # Nodes in the FORALL body
    body: Tuple[Node, ...]
    # Optional mask condition
    mask: Optional[Expression] = None
    # Optional construct name (only allowed for the multi-line form)
    name: Optional[str] = None
    # Whether this is a single-line FORALL statement
    inline: bool = False


@dataclass_strict(frozen=True)
class Forall(InternalNode, _ForallBase):
    """
    Internal representation of a FORALL statement or construct.

    Parameters
    ----------
    named_bounds : tuple of pairs (variable, range) of type :any:`pymbolic.primitives.Expression`
        The collection of named variables with bounds (ranges).
    body : tuple of :any:`Node`
        The collection of assignment statements, nested FORALLs, and/or comments.
    mask : :any:`pymbolic.primitives.Expression`, optional
        The condition that defines the mask.
    name : str, optional
        The name of the multi-line FORALL construct in the original source.
    inline : bool, optional
        Flag to indicate a single-line FORALL statement.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """
    _traversable = ['named_bounds', 'mask', 'body']

    def __post_init__(self):
        super().__post_init__()
        assert is_iterable(self.named_bounds) and all(isinstance(c, tuple) for c in self.named_bounds), \
            "FORALL named bounds must be tuples of (variable, range)"
        assert is_iterable(self.body), "FORALL body must be iterable"
        if self.inline:
            # The single-line statement form holds exactly one assignment and has no construct name
            assert len(self.body) == 1, "FORALL statement must contain exactly one assignment"
            assert self.name is None, "FORALL statement cannot have a name label"

    def __repr__(self):
        return f"Forall:: {', '.join([e[0].name for e in self.named_bounds])}"


@dataclass_strict(frozen=True)
class _MaskedStatementBase():
    """
    Type definitions for :any:`MaskedStatement` node type.

    Field order matters: it defines the positional constructor signature.
    """

    # Mask conditions, one per WHERE / ELSEWHERE(mask) branch
    conditions: Tuple[Expression, ...]
    # Bodies matching ``conditions`` one-to-one
    bodies: Tuple[Tuple[Node, ...], ...]
    # Body of the unconditional ``ELSEWHERE`` branch, if any
    default: Optional[Tuple[Node, ...]] = None
    # Whether this is a one-line WHERE statement
    inline: bool = False


@dataclass_strict(frozen=True)
class MaskedStatement(LeafNode, _MaskedStatementBase):
    """
    IR node for masked array assignments, i.e. Fortran ``WHERE`` clauses.

    Parameters
    ----------
    conditions : tuple of :any:`pymbolic.primitives.Expression`
        The mask conditions, one per branch.
    bodies : tuple of tuple of :any:`Node`
        The assignments guarded by the corresponding condition.
    default : tuple of :any:`Node`, optional
        Assignments applied to array entries not selected by any mask
        (the ``ELSEWHERE`` branch).
    inline : bool, optional
        Set for the one-line ``WHERE`` statement form.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    _traversable = ['conditions', 'bodies', 'default']

    def __post_init__(self):
        super().__post_init__()
        assert is_iterable(self.conditions) and all(isinstance(cond, Expression) for cond in self.conditions)
        assert is_iterable(self.bodies) and all(isinstance(body, tuple) for body in self.bodies)
        # Every condition needs exactly one matching body
        assert len(self.conditions) == len(self.bodies)
        assert is_iterable(self.default)

        # The one-line form holds a single condition with a single assignment and no ELSEWHERE
        assert not self.inline or (len(self.bodies) == 1 and len(self.bodies[0]) == 1 and not self.default)

    def __repr__(self):
        first_condition = self.conditions[0]
        return f'MaskedStatement:: {first_condition!s}'


# NOTE(review): unlike the other node base classes, this uses plain ``dataclass``
# rather than ``dataclass_strict`` - confirm whether that is intentional.
@dataclass(frozen=True)
class _IntrinsicBase():
    """ Type definitions for :any:`Intrinsic` node type. """

    # Verbatim text of the statement
    text: str


@dataclass_strict(frozen=True)
class Intrinsic(LeafNode, _IntrinsicBase):
    """
    Fallback IR node that keeps a statement as raw text.

    Statements without a dedicated IR representation end up here. They
    are either language features that are not supported yet, or
    statements that do not matter for Loki's use cases. The text of the
    statement in the original source is stored unchanged.

    Parameters
    ----------
    text : str
        The statement as a string.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    def __post_init__(self):
        super().__post_init__()
        assert isinstance(self.text, str)

    def __repr__(self):
        summary = truncate_string(self.text)
        return f'Intrinsic:: {summary}'


@dataclass_strict(frozen=True)
class _EnumerationBase():
    """ Type definitions for :any:`Enumeration` node type. """

    # Named constants declared by the enum
    symbols: Tuple[Expression, ...]


@dataclass_strict(frozen=True)
class Enumeration(LeafNode, _EnumerationBase):
    """
    IR node for a Fortran ``ENUM`` block.

    Each declared constant is a :any:`Variable`. If a value is given
    explicitly, it is stored as the ``initial`` property of the symbol's type.

    Parameters
    ----------
    symbols : list of :any:`Expression`
        The named constants declared by the enum.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    def __post_init__(self):
        super().__post_init__()
        if self.symbols is None:
            return
        assert all(isinstance(s, Expression) for s in self.symbols)  # pylint: disable=not-an-iterable

    def __repr__(self):
        names = ', '.join(map(str, as_tuple(self.symbols)))
        return f'Enumeration:: {names}'


@dataclass_strict(frozen=True)
class _RawSourceBase():
    """ Type definitions for :any:`RawSource` node type. """

    # Unparsed source code text
    text: str


@dataclass_strict(frozen=True)
class RawSource(LeafNode, _RawSourceBase):
    """
    IR node that holds a section of source code that has not been parsed.

    The :any:`REGEX` frontend stores unparsed code sections in these nodes.
    For now, their only purpose is to keep the complete string content of
    the original Fortran source in the IR.

    Parameters
    ----------
    text : str
        The source code as a string.
    **kwargs : optional
        Other parameters that are passed on to the parent class constructor.
    """

    def __repr__(self):
        stripped = self.text.strip()
        return f'RawSource:: {truncate_string(stripped)}'
loki-ecmwf-0.3.6/loki/program_unit.py0000664000175000017500000010076115167130205017764 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from abc import abstractmethod

from loki.expression import Variable, parse_expr
from loki.frontend import (
    Frontend, parse_omni_source, parse_fparser_source,
    RegexParserClass, preprocess_cpp, sanitize_input
)
from loki.ir import (
    nodes as ir, FindNodes, Transformer, ExpressionTransformer
)
from loki.logging import debug
from loki.tools import CaseInsensitiveDict, as_tuple, flatten
from loki.types import BasicType, DerivedType, ProcedureType, Scope


__all__ = ['ProgramUnit']


class ProgramUnit(Scope):
    """
    Common base class for :any:`Module` and :any:`Subroutine`

    Parameters
    ----------
    name : str
        Name of the program unit.
    docstring : tuple of :any:`Node`, optional
        The docstring in the original source.
    spec : :any:`Section`, optional
        The spec of the program unit.
    contains : :any:`Section`, optional
        The internal-subprogram part following a ``CONTAINS`` statement
        declaring module or member procedures
    ast : optional
        Parse tree node from the frontend for this program unit
    source : :any:`Source`
        Source object representing the raw source string information from the
        original file.
    parent : :any:`Scope`, optional
        The enclosing parent scope of the program unit. Declarations from
        the parent scope remain valid within the program unit's scope
        (unless shadowed by local declarations).
    rescope_symbols : bool, optional
        Ensure that the type information for all :any:`TypedSymbol` in the
        IR exist in this program unit's scope or the scope's parents.
        Defaults to `False`.
    symbol_attrs : :any:`SymbolTable`, optional
        Use the provided :any:`SymbolTable` object instead of creating a new one.
    incomplete : bool, optional
        Mark the object as incomplete, i.e. only partially parsed. This is
        typically the case when it was instantiated using the :any:`Frontend.REGEX`
        frontend and a full parse using one of the other frontends is pending.
    parser_classes : :any:`RegexParserClass`, optional
        Provide the list of parser classes used during incomplete regex parsing
    """

    def __initialize__(self, name, docstring=None, spec=None, contains=None,
                       ast=None, source=None, rescope_symbols=False, incomplete=False,
                       parser_classes=None):
        """
        Set up the state shared by all program units

        Wraps :data:`spec` and :data:`contains` in :any:`Section` objects if
        necessary, stores the IR components, registers the object in its parent
        scope and, if requested, rescopes all symbols. See the class docstring
        for a description of the parameters.
        """
        # Common properties
        assert name and isinstance(name, str)
        self.name = name
        self._ast = ast
        self._source = source
        self._incomplete = incomplete
        self._parser_classes = parser_classes

        # Bring arguments into shape
        if spec is not None and not isinstance(spec, ir.Section):
            spec = ir.Section(body=as_tuple(spec))
        if contains is not None:
            if not isinstance(contains, ir.Section):
                contains = ir.Section(body=as_tuple(contains))
            # Scan the section until either an existing ``CONTAINS`` intrinsic or the
            # first program unit is found; in the latter case, prepend a ``CONTAINS``
            for node in contains.body:
                if isinstance(node, ir.Intrinsic) and 'contains' in node.text.lower():  # pylint: disable=no-member
                    break
                if isinstance(node, ProgramUnit):
                    contains.prepend(ir.Intrinsic(text='CONTAINS'))
                    break

        # Primary IR components
        self.docstring = as_tuple(docstring)
        self.spec = spec
        self.contains = contains

        # Finally, register this object in the parent scope
        # (done last, presumably because derived implementations need the attributes set above)
        self.register_in_parent_scope()

        if rescope_symbols:
            self.rescope_symbols()

    @classmethod
    def from_source(cls, source, definitions=None, preprocess=False,
                    includes=None, defines=None, xmods=None, omni_includes=None,
                    frontend=Frontend.FP, parser_classes=None, parent=None):
        """
        Instantiate an object derived from :any:`ProgramUnit` from raw source string

        The actual construction is delegated to the frontend-specific factory
        method of the derived class (e.g. :any:`Module` or :any:`Subroutine`).

        Parameters
        ----------
        source : str
            Fortran source string
        definitions : list of :any:`Module`, optional
            :any:`Module` object(s) that may supply external type or procedure
            definitions.
        preprocess : bool, optional
            Flag to trigger CPP preprocessing (by default `False`).

            .. attention::
                Please note that, when using the OMNI frontend, C-preprocessing
                will always be applied, so :data:`includes` and :data:`defines`
                may have to be defined even when disabling :data:`preprocess`.

        includes : list of str, optional
            Include paths to pass to the C-preprocessor.
        defines : list of str, optional
            Symbol definitions to pass to the C-preprocessor.
        xmods : str, optional
            Path to directory to find and store ``.xmod`` files when using the
            OMNI frontend.
        omni_includes: list of str, optional
            Additional include paths to pass to the preprocessor run as part of
            the OMNI frontend parse. If set, this **replaces** (!)
            :data:`includes`, otherwise :data:`omni_includes` defaults to the
            value of :data:`includes`.
        frontend : :any:`Frontend`, optional
            Frontend to use for producing the AST (default :any:`FP`). The
            frontend name may also be given as a string.
        parser_classes : :any:`RegexParserClass`, optional
            Parser classes to use; only forwarded when using the :any:`REGEX` frontend.
        parent : :any:`Scope`, optional
            The parent scope this module or subroutine is nested into
        """
        # Allow the frontend to be named by a (case-insensitive) string
        if isinstance(frontend, str):
            frontend = Frontend[frontend.upper()]

        if preprocess:
            # Explicit CPP run; for OMNI, dedicated include paths (if given) take precedence
            cpp_includes = omni_includes if (frontend == Frontend.OMNI and omni_includes) else includes
            source = preprocess_cpp(source=source, includes=cpp_includes, defines=defines)

        if frontend == Frontend.REGEX:
            return cls.from_regex(raw_source=source, parser_classes=parser_classes, parent=parent)

        if frontend == Frontend.OMNI:
            omni_ast = parse_omni_source(source, xmods=xmods)
            # Index OMNI's type table entries by their type hash
            omni_types = {entry.attrib['type']: entry for entry in omni_ast.find('typeTable')}
            return cls.from_omni(
                ast=omni_ast, raw_source=source, definitions=definitions,
                type_map=omni_types, parent=parent
            )

        if frontend == Frontend.FP:
            # Apply internal frontend-specific rules that sanitize the input
            # and work around known parser problems before parsing
            sanitized_source, pp_info = sanitize_input(source=source, frontend=frontend)
            fp_ast = parse_fparser_source(sanitized_source)
            return cls.from_fparser(
                ast=fp_ast, raw_source=sanitized_source, definitions=definitions,
                pp_info=pp_info, parent=parent
            )

        raise NotImplementedError(f'Unknown frontend: {frontend}')

    @classmethod
    @abstractmethod
    def from_omni(cls, ast, raw_source, definitions=None, parent=None, type_map=None):
        """
        Create the :any:`ProgramUnit` object from an :any:`OMNI` parse tree.

        This method must be implemented by the derived class.

        Parameters
        ----------
        ast :
            The OMNI parse tree
        raw_source : str
            Fortran source string
        definitions : list, optional
            List of external :any:`Module` to provide derived-type and procedure declarations
        parent : :any:`Scope`, optional
            The enclosing parent scope of the program unit
        type_map : dict, optional
            A mapping from type hash identifiers to type definitions, as provided in
            OMNI's ``typeTable`` parse tree node
        """

    @classmethod
    @abstractmethod
    def from_fparser(cls, ast, raw_source, definitions=None, pp_info=None, parent=None):
        """
        Create the :any:`ProgramUnit` object from an :any:`FP` parse tree.

        This method must be implemented by the derived class.

        Parameters
        ----------
        ast :
            The FParser parse tree
        raw_source : str
            Fortran source string (as passed to the parser, i.e. after
            :any:`sanitize_input` when called from :meth:`from_source`)
        definitions : list, optional
            List of external :any:`Module` to provide derived-type and procedure declarations
        pp_info : optional
            Preprocessing info as obtained by :any:`sanitize_input`
        parent : :any:`Scope`, optional
            The enclosing parent scope of the program unit.
        """

    @classmethod
    @abstractmethod
    def from_regex(cls, raw_source, parser_classes=None, parent=None):
        """
        Create the :any:`ProgramUnit` object from source regex'ing.

        This method must be implemented by the derived class.

        Parameters
        ----------
        raw_source : str
            Fortran source string
        parser_classes : :any:`RegexParserClass`, optional
            The regex parser classes to apply
        parent : :any:`Scope`, optional
            The enclosing parent scope of the program unit.
        """

    @abstractmethod
    def register_in_parent_scope(self):
        """
        Insert the type information for this object in the parent's symbol table

        If :attr:`parent` is `None`, this does nothing.

        This is called at the end of :meth:`__initialize__`, and by
        :meth:`make_complete` when the object's name is missing from the
        parent's symbol table.

        This method must be implemented by the derived class.
        """

    def make_complete(self, **frontend_args):
        """
        Trigger a re-parse of the object if incomplete to produce a full Loki IR

        If the object is marked to be incomplete, i.e. when using the `lazy` constructor
        option, this triggers a new parsing of all :any:`ProgramUnit` objects and any
        :any:`RawSource` nodes in the :attr:`ir`.

        Existing :any:`Module` and :any:`Subroutine` objects continue to exist and references
        to them stay valid, as they will only be updated instead of replaced.

        Parameters
        ----------
        **frontend_args : optional
            Frontend options. ``frontend`` (default :any:`Frontend.FP`, may be a
            string), ``definitions``, ``xmods`` and ``parser_classes`` (default
            :any:`RegexParserClass.AllClasses`) are forwarded to :meth:`from_source`;
            other entries are ignored.
        """
        if not self._incomplete:
            return
        frontend = frontend_args.pop('frontend', Frontend.FP)
        if isinstance(frontend, str):
            frontend = Frontend[frontend.upper()]
        definitions = frontend_args.get('definitions')
        xmods = frontend_args.get('xmods')
        parser_classes = frontend_args.get('parser_classes', RegexParserClass.AllClasses)
        if frontend == Frontend.REGEX and self._parser_classes:
            # Nothing to do if all requested parser classes have already been applied
            if self._parser_classes == (self._parser_classes | parser_classes):
                return
            # Re-parse with the union of previously applied and newly requested classes
            parser_classes = parser_classes | self._parser_classes

        # If this object does not have a parent, we create a temporary parent scope
        # and make sure the node exists in the parent scope. This way, the existing
        # object is re-used while converting the parse tree to Loki-IR.
        has_parent = self.parent is not None
        if not has_parent:
            self._reset_parent(Scope(parent=None))
        try:
            if self.name not in self.parent.symbol_attrs:
                self.register_in_parent_scope()

            ir_ = self.from_source(
                self.source.string, frontend=frontend, definitions=definitions, xmods=xmods,
                parser_classes=parser_classes, parent=self.parent
            )
            assert ir_ is self
        finally:
            # Detach the temporary parent scope again, even if re-parsing failed
            if not has_parent:
                self._reset_parent(None)

    def enrich(self, definitions, recurse=False):
        """
        Enrich the current scope with inter-procedural annotations

        This updates the :any:`SymbolAttributes` in the scope's :any:`SymbolTable`
        with :data:`definitions` for all imported symbols. Local symbols declared
        with an imported derived type (matched case-insensitively by type name)
        are updated to use the type from the definition as well.

        Note that :any:`Subroutine.enrich` expands this to interface-declared calls.

        Parameters
        ----------
        definitions : list of :any:`ProgramUnit`
            A list of all available definitions
        recurse : bool, optional
            Enrich contained scopes
        """
        definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions))

        # Enrich type info from all known imports (including parent scopes)
        for imprt in self.all_imports:
            if not (module := definitions_map.get(imprt.module)):
                # Skip modules that are not available in the definitions list
                continue

            # Build a list of symbols that are imported
            if imprt.symbols:
                # Import only symbols listed in the only list
                symbols = imprt.symbols
            else:
                # Import all symbols, applying local names from the rename list
                rename_list = CaseInsensitiveDict((k, v) for k, v in as_tuple(imprt.rename_list))
                symbols = [
                    Variable(name=str(rename_list.get(symbol.name, symbol.name)), scope=self)
                    for symbol in module.symbols
                ]

            updated_symbol_attrs = {}
            for symbol in symbols:
                # Take care of renaming upon import
                local_name = symbol.name
                remote_name = symbol.type.use_name or local_name
                try:
                    remote_node = module[remote_name]
                except KeyError:
                    remote_node = None

                if remote_node and hasattr(remote_node, 'procedure_type'):
                    # This is a subroutine/function defined in the remote module
                    updated_symbol_attrs[local_name] = symbol.type.clone(
                        dtype=remote_node.procedure_type, imported=True, module=module
                    )
                elif remote_node and hasattr(remote_node, 'dtype'):
                    # This is a derived type defined in the remote module.
                    # Update dtype for local variables declared with this type. Fortran
                    # names are case-insensitive, so type names are compared lower-cased,
                    # and only symbols that carry a derived type are considered (this
                    # prevents e.g. intrinsic or deferred types from matching by name).
                    type_name = str(remote_node.dtype.name).lower()
                    updated_symbol_attrs.update({
                        name: type_.clone(dtype=remote_node.dtype)
                        for name, type_ in self.symbol_attrs.items()
                        if isinstance(type_.dtype, DerivedType) and str(type_.dtype.name).lower() == type_name
                    })
                    # Register the imported type itself last, so that its import
                    # information (``imported``, ``module``) is not overwritten above
                    updated_symbol_attrs[local_name] = symbol.type.clone(
                        dtype=remote_node.dtype, imported=True, module=module
                    )
                elif remote_node and hasattr(remote_node, 'type'):
                    # This is a global variable or interface import
                    updated_symbol_attrs[local_name] = remote_node.type.clone(
                        imported=True, module=module, use_name=symbol.type.use_name
                    )
                else:
                    debug('Cannot enrich import of %s from module %s', local_name, module.name)
            self.symbol_attrs.update(updated_symbol_attrs)

            if imprt.symbols:
                # Rebuild the symbols in the import's symbol list to obtain the correct
                # expression nodes
                imprt._update(symbols=tuple(symbol.clone() for symbol in imprt.symbols))

        # Update any symbol table entries that have been inherited from the parent
        if self.parent:
            updated_symbol_attrs = {}
            for name, attrs in self.symbol_attrs.items():
                if name not in self.parent.symbol_attrs:
                    continue

                if attrs.imported and not attrs.module:
                    updated_symbol_attrs[name] = self.parent.symbol_attrs[name]
                elif isinstance(attrs.dtype, ProcedureType) and attrs.dtype.procedure is BasicType.DEFERRED:
                    updated_symbol_attrs[name] = self.parent.symbol_attrs[name]
                elif isinstance(attrs.dtype, DerivedType) and attrs.dtype.typedef is BasicType.DEFERRED:
                    updated_symbol_attrs[name] = attrs.clone(dtype=self.parent.symbol_attrs[name].dtype)
            self.symbol_attrs.update(updated_symbol_attrs)

        # Rebuild local symbols to ensure correct symbol types
        self.spec = ExpressionTransformer(inplace=True).visit(self.spec)

        if recurse:
            for routine in self.subroutines:
                routine.enrich(definitions, recurse=True)

    def clone(self, **kwargs):
        """
        Create a deep copy of the object with the option to override individual
        parameters

        Parameters
        ----------
        **kwargs :
            Any parameters from the constructor of the class.

        Returns
        -------
        Object of type ``self.__class__``
            The cloned object.
        """
        # Collect all properties that have not been overriden
        if self.name is not None and 'name' not in kwargs:
            kwargs['name'] = self.name
        if self.docstring and 'docstring' not in kwargs:
            kwargs['docstring'] = self.docstring
        if self.spec and 'spec' not in kwargs:
            kwargs['spec'] = self.spec
        if self.contains and 'contains' not in kwargs:
            contains_needs_clone = True
            kwargs['contains'] = self.contains
        else:
            contains_needs_clone = False
        if self._ast is not None and 'ast' not in kwargs:
            kwargs['ast'] = self._ast
        if self._source is not None and 'source' not in kwargs:
            kwargs['source'] = self._source
        kwargs.setdefault('incomplete', self._incomplete)
        kwargs.setdefault('parser_classes', self._parser_classes)

        # Rebuild IRs
        rebuild = Transformer({}, rebuild_scopes=True)
        if 'docstring' in kwargs:
            kwargs['docstring'] = rebuild.visit(kwargs['docstring'])
        if 'spec' in kwargs:
            kwargs['spec'] = rebuild.visit(kwargs['spec'])
        if 'contains' in kwargs:
            kwargs['contains'] = rebuild.visit(kwargs['contains'])

        # Rescope symbols if not explicitly disabled
        kwargs.setdefault('rescope_symbols', True)

        # Escalate to Scope's clone function
        obj = super().clone(**kwargs)

        # Update contained routines with new parent scope
        # TODO: Convert ProgramUnit to an IR node(-like) object and make this
        #       work via `Transformer`
        if obj.contains:
            if contains_needs_clone:
                contains = [
                    node.clone(parent=obj, rescope_symbols=kwargs['rescope_symbols'])
                    if isinstance(node, ProgramUnit) else node
                    for node in obj.contains.body
                ]
                obj.contains = obj.contains.clone(body=as_tuple(contains))
            else:
                for node in obj.contains.body:
                    if isinstance(node, ProgramUnit):
                        node._reset_parent(obj)
                        node.register_in_parent_scope()

            # Rescope to ensure that symbol references are up to date
            obj.rescope_symbols()

        obj.register_in_parent_scope()

        return obj

    @property
    def typedefs(self):
        """
        Return the :any:`TypeDef` defined in the :attr:`spec` of this unit
        """
        return as_tuple(FindNodes(ir.TypeDef).visit(self.spec))

    @property
    def typedef_map(self):
        """
        Map of names and :any:`TypeDef` defined in the :attr:`spec` of this unit
        """
        return CaseInsensitiveDict((td.name, td) for td in self.typedefs)

    @property
    def declarations(self):
        """
        Return the declarations from the :attr:`spec` of this unit
        """
        return as_tuple(FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(self.spec))

    @property
    def variables(self):
        """
        Return the variables declared in the :attr:`spec` of this unit
        """
        return as_tuple(flatten(decl.symbols for decl in self.declarations))

    @variables.setter
    def variables(self, variables):
        """
        Set the variables property and ensure that the internal declarations match.
        """
        # First map variables to existing declarations
        decl_map = dict((v, decl) for decl in self.declarations for v in decl.symbols)

        for v in as_tuple(variables):
            if v not in decl_map:
                # By default, append new variables to the end of the spec
                if isinstance(v.type.dtype, ProcedureType):
                    new_decl = ir.ProcedureDeclaration(symbols=(v, ))
                else:
                    new_decl = ir.VariableDeclaration(symbols=(v, ))
                self.spec.append(new_decl)

        # Run through existing declarations and check that all variables still exist
        dmap = {}
        for decl in self.declarations:
            new_vars = as_tuple(v for v in decl.symbols if v in variables)
            if len(new_vars) > 0:
                decl._update(symbols=new_vars)
            else:
                dmap[decl] = None  # Mark for removal

        # Remove all redundant declarations
        self.spec = Transformer(dmap).visit(self.spec)

    @property
    def variable_map(self):
        """
        Map of variable names to :any:`Variable` objects
        """
        return CaseInsensitiveDict((v.name, v) for v in self.variables)

    @property
    def imports(self):
        """
        Return the list of :any:`Import` in this unit
        """
        return as_tuple(FindNodes(ir.Import).visit(self.spec or ()))

    @property
    def import_map(self):
        """
        Map of imported symbol names to :any:`Import` objects
        """
        return CaseInsensitiveDict((s.name, imprt) for imprt in self.imports for s in imprt.symbols)

    @property
    def imported_symbols(self):
        """
        Return the symbols imported in this unit
        """
        imports = self.imports
        return as_tuple(flatten(
            imprt.symbols or [s[1] for s in imprt.rename_list or []]
            for imprt in imports
        ))

    @property
    def imported_symbol_map(self):
        """
        Map of imported symbol names to objects
        """
        return CaseInsensitiveDict((s.name, s) for s in self.imported_symbols)

    @property
    def all_imports(self):
        """
        Return the list of :any:`Import` in this unit and any parent scopes
        """
        imports = self.imports
        scope = self
        while (scope := scope.parent):
            imports += scope.imports
        return imports

    @property
    def all_imported_symbols(self):
        """
        Return the symbols imported in this unit and any parent scopes
        """
        imports = self.all_imports
        return as_tuple(flatten(
            imprt.symbols or [s[1] for s in imprt.rename_list or []]
            for imprt in imports
        ))

    @property
    def all_imported_symbol_map(self):
        """
        Map of imported symbol names to objects for this unit and any parent scopes
        """
        return CaseInsensitiveDict((s.name, s) for s in self.all_imported_symbols)

    @property
    def interfaces(self):
        """
        Return the list of :any:`Interface` declared in this unit
        """
        return as_tuple(FindNodes(ir.Interface).visit(self.spec))

    @property
    def interface_symbols(self):
        """
        Return the list of symbols declared via interfaces in this unit
        """
        return as_tuple(flatten(intf.symbols for intf in self.interfaces))

    @property
    def interface_map(self):
        """
        Map of declared interface names to :any:`Interface` nodes
        """
        return CaseInsensitiveDict(
            (s.name, intf) for intf in self.interfaces for s in intf.symbols
        )

    @property
    def interface_symbol_map(self):
        """
        Map of declared interface names to symbols
        """
        return CaseInsensitiveDict(
            (s.name, s) for s in self.interface_symbols
        )

    @property
    def enum_symbols(self):
        """
        List of symbols defined via an enum
        """
        return as_tuple(flatten(enum.symbols for enum in FindNodes(ir.Enumeration).visit(self.spec or ())))

    @property
    def definitions(self):
        """
        The list of IR nodes defined by this program unit.

        Returns an empty tuple by default and can be overwritten by derived nodes.
        """
        return ()

    @property
    def symbols(self):
        """
        Return list of all symbols declared or imported in this module scope
        """

        #Find all nodes that may contain symbols
        nodelist = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration,
                    ir.Import, ir.Interface, ir.Enumeration)).visit(self.spec or ())

        #Return all symbols found in nodelist as well as any procedure_symbols
        #in contained subroutines
        return as_tuple(flatten(n.symbols for n in nodelist)) + \
               tuple(routine.procedure_symbol for routine in self.subroutines)

    @property
    def symbol_map(self):
        """
        Map of symbol names to symbols
        """
        return CaseInsensitiveDict(
            (s.name, s) for s in self.symbols
        )

    def get_symbol(self, name):
        """
        Returns the symbol for a given name as defined in its declaration.

        The returned symbol might include dimension symbols if it was
        declared as an array.

        Parameters
        ----------
        name : str
            Base name of the symbol to be retrieved
        """
        return self.get_symbol_scope(name).variable_map.get(name)

    def Variable(self, **kwargs):
        """
        Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes.

        This invokes the :any:`Variable` with this node as the scope.

        Parameters
        ----------
        name : str
            The name of the variable.
        type : optional
            The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
        parent : :any:`Scalar` or :any:`Array`, optional
            The derived type variable this variable belongs to.
        dimensions : :any:`ArraySubscript`, optional
            The array subscript expression.
        """
        kwargs['scope'] = self
        return Variable(**kwargs)

    def parse_expr(self, expr_str, strict=False, evaluate=False, context=None):
        """
        Uses :meth:`parse_expr` to convert expression(s) represented
        in a string to Loki expression(s)/IR.

        Parameters
        ----------
        expr_str : str
            The expression as a string
        strict : bool, optional
            Whether to raise exception for unknown variables/symbols when
            evaluating an expression (default: `False`)
        evaluate : bool, optional
            Whether to evaluate the expression or not (default: `False`)
        context : dict, optional
            Symbol context, defining variables/symbols/procedures to help/support
            evaluating an expression

        Returns
        -------
        :any:`Expression`
            The expression tree corresponding to the expression
        """
        return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context)

    @property
    def subroutines(self):
        """
        List of :class:`Subroutine` objects that are declared in this unit
        """
        from loki.subroutine import Subroutine  # pylint: disable=import-outside-toplevel,cyclic-import
        if self.contains is None:
            return ()
        return as_tuple([
            routine for routine in self.contains.body if isinstance(routine, Subroutine)
        ])

    routines = subroutines

    @property
    def subroutine_map(self):
        """
        Map of subroutine names to :any:`Subroutine` objects in :attr:`subroutines`
        """
        return CaseInsensitiveDict(
            (s.name, s) for s in self.subroutines
        )

    @property
    def spec_parts(self):
        """
        Return the :attr:`spec` subdivided into the parts the Fortran standard
        describes and requires to appear in a specific order

        The parts are:

        * import statements (such as module imports via ``USE``)
        * implicit-part (such as ``IMPLICIT NONE``)
        * declaration constructs (such as access statements, variable declarations etc.)

        This can be useful when adding or looking for statements that have to appear
        in a certain position.

        Note that comments at the interface between parts may be allocated to the
        previous or next part.

        Returns
        -------
        tuple of tuple of :class:`ir.Node`
            The parts of the spec, with empty parts represented by empty tuples.
        """
        if not self.spec:
            return ((),(),())

        intrinsic_nodes = FindNodes(ir.Intrinsic).visit(self.spec)
        implicit_nodes = [node for node in intrinsic_nodes if node.text.lstrip().lower().startswith('implicit')]

        if implicit_nodes:
            # Use 'IMPLICIT' statements as divider
            implicit_start_index = self.spec.body.index(implicit_nodes[0])
            if len(implicit_nodes) == 1:
                implicit_end_index = implicit_start_index
            else:
                implicit_end_index = self.spec.body.index(implicit_nodes[-1])

            return (
                self.spec.body[:implicit_start_index],
                self.spec.body[implicit_start_index:implicit_end_index+1],
                self.spec.body[implicit_end_index+1:]
            )

        # No 'IMPLICIT' statements: find the end of imports
        import_nodes = FindNodes(ir.Import).visit(self.spec)

        if not import_nodes:
            return ((), (), self.spec.body)

        import_nodes_end_index = self.spec.body.index(import_nodes[-1])
        return (
            self.spec.body[:import_nodes_end_index+1],
            (),
            self.spec.body[import_nodes_end_index+1:]
        )

    @property
    def ir(self):
        """
        All components of the intermediate representation in this unit
        """
        return (self.docstring, self.spec, self.contains)

    @property
    def source(self):
        """
        The :any:`Source` object for this unit
        """
        return self._source

    def to_fortran(self, conservative=False, cuf=False):
        """
        Convert this unit to Fortran source representation
        """
        if cuf:
            from loki.backend.cufgen import cufgen # pylint: disable=import-outside-toplevel
            return cufgen(self)
        from loki.backend.fgen import fgen  # pylint: disable=import-outside-toplevel
        return fgen(self, conservative=conservative)

    def __repr__(self):
        """
        Short string representation
        """
        return f'{self.__class__.__name__}:: {self.name}'

    def __contains__(self, name):
        """
        Check if a symbol, type or subroutine with the given name is declared
        inside this unit
        """
        return name in self.symbols or name in self.typedef_map

    def __getitem__(self, name):
        """
        Get the IR node of the subroutine, typedef, imported symbol or declared
        variable corresponding to the given name
        """
        if not isinstance(name, str):
            raise TypeError('Name lookup requires a string!')

        item = self.subroutine_map.get(name)
        if item is None:
            item = self.typedef_map.get(name)
        if item is None:
            item = self.symbol_map[name]
        return item

    def __iter__(self):
        """
        Make :any:`ProgramUnit`s non-iterable
        """
        raise TypeError('ProgramUnit nodes can not be traversed. Try `ir` or `subroutines` instead.')

    def __bool__(self):
        """
        Ensure existing objects register as True in boolean checks, despite
        raising exceptions in :meth:`__iter__`.
        """
        return True

    def apply(self, op, **kwargs):
        """
        Apply a given transformation to this program unit

        Note that the dispatch routine ``op.apply(source)`` will ensure
        that all entities of this :any:`ProgramUnit` are correctly traversed.
        """
        # TODO: Should type-check for an `Operation` object here
        op.apply(self, **kwargs)

    def resolve_typebound_var(self, name, variable_map=None):
        """
        A small convenience utility to resolve type-bound variables.

        Parameters
        ----------
        name : str
            The full name of the variable to be resolved, e.g., a%b%c%d.
        variable_map : dict
            A map of the variables defined in the current scope.
        """

        if not (_variable_map := variable_map):
            _variable_map = self.variable_map

        name_parts = name.split('%', maxsplit=1)
        var = _variable_map[name_parts[0]]
        if len(name_parts) > 1:
            var = var.get_derived_type_member(name_parts[1])
        return var
loki-ecmwf-0.3.6/loki/config.py0000664000175000017500000000722215167130205016521 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import os
from collections import OrderedDict
from contextlib import contextmanager


class Configuration(OrderedDict):
    """
    Dictionary class that holds global configuration parameters.

    In addition to sanity checking this dict also allows callbacks
    to be used to propagate values to the relevant parts of the
    system.

    Example usage:

    .. code-block::

        config = Configuration('Loki')
        config.register('log-level', 'INFO', env_variable='LOKI_LOGGING')
        ...

        config.initialize()
        logging = config['log-level']
    """

    def __init__(self, name=None):
        super().__init__()
        self.name = name

        self._defaults = {}
        self._env_variables = {}
        self._preprocess_functions = {}
        self._callback_functions = {}

    def initialize(self):
        """
        Initialize all registered entries by either using the value given
        via environemnt variables or the default.
        """
        for key in self.keys():
            if self._env_variables[key]:
                env_val = os.environ.get(self._env_variables[key], None)
                if env_val is None:
                    self[key] = self._defaults[key]
                else:
                    self[key] = env_val

    def register(self, key, default, env_variable=None, preprocess=None, callback=None):
        """
        Register configuration option with optional default value
        and callback function.

        Parameters
        ----------
        key : str
            Internal name of the configuration option
        default :
            Default value if unspecified in environment
        env_variable : str
            Name of environment variable to check for value
        preprocess :
            Optional preprocess function that turns string-based
            values into the correct format (eg. for env variables)
        callback :
            Optional callback function to trigger on updates
        """
        super().__setitem__(key, default)

        self._defaults[key] = default
        self._env_variables[key] = env_variable
        self._preprocess_functions[key] = preprocess
        self._callback_functions[key] = callback

    def print_state(self):
        """
        Print the current configuration state.
        """
        from loki.logging import info  # pylint: disable=import-outside-toplevel
        info("[Loki] global config:")
        for k, v in self.items():
            info(f'  {k}: {v}')

    def _updated(self, key, value):
        # Execute callback function for ``key``
        if self._callback_functions[key]:
            self._callback_functions[key](value)

    def __setitem__(self, key, value):
        # Preprocess any given value
        if self._preprocess_functions[key]:
            value = self._preprocess_functions[key](value)

        super().__setitem__(key, value)

        # Trigger configured callbacks
        self._updated(key, value)


config = Configuration('Loki configuration')


@contextmanager
def config_override(settings):
    """
    Simple context manager for testing purposes that temporarily overrides
    config options with :param:`settings` and restores the original after.
    """
    original = tuple(config.items())
    config.update(settings)

    yield

    config.update(dict(original))
loki-ecmwf-0.3.6/loki/cli/0000775000175000017500000000000015167130205015446 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/cli/__init__.py0000664000175000017500000000076515167130205017567 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Command-line utilities and entry-points for various Loki invocation methods.
"""

from loki.cli.common import *  # noqa
loki-ecmwf-0.3.6/loki/cli/tests/0000775000175000017500000000000015167130205016610 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/cli/tests/__init__.py0000664000175000017500000000057015167130205020723 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
loki-ecmwf-0.3.6/loki/cli/tests/test_loki_transform.py0000664000175000017500000001240515167130205023254 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest
import tomli_w

from click.testing import CliRunner

from loki.cli.loki_transform import cli
from loki.logging import log_levels


@pytest.fixture(scope='module', name='here')
def fixture_here():
    return Path(__file__).parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    return here.parent.parent/'tests'


@pytest.fixture(name='config')
def fixture_config():
    """
    Default configuration dict with basic options.
    """
    return {
        'default': {
            'mode': 'idem',
            'role': 'kernel',
            'expand': True,
            'strict': True,
            'disable': ['abort'],
            'enable_imports': True,
        },
        'routines': {},
        'transformations': {
            'Idem': {
                'classname': 'IdemTransformation',
                'module': 'loki.transformations',
            },
            'ModuleWrap': {
                'classname': 'ModuleWrapTransformation',
                'module': 'loki.transformations.build_system',
                'options': {'module_suffix': '_MOD'},
            },
            'Dependency': {
                'classname': 'DependencyTransformation',
                'module': 'loki.transformations.build_system',
                'options': {'suffix': '_LOKI', 'module_suffix': '_MOD'},
            },
        },
        'pipelines': {
            'idem': {
                'transformations': ['Idem', 'ModuleWrap', 'Dependency']
            }
        }
    }


def test_loki_transform_plan(testdir, config, caplog, tmp_path):
    """ Test the CLI invocation of the "plan" mode """

    projA = testdir/'sources/projA'
    projA_files = [
        'driverA_mod.f90', 'kernelA_mod.F90', 'compute_l1_mod.f90',
        'another_l1.F90', 'compute_l2_mod.f90', 'another_l2.F90'
    ]

    # Create final config
    config['routines'] = {
        'driverA': {'role': 'driver'},
        'another_l1': {'role': 'driver'},
    }
    (tmp_path/'my.config').write_text(tomli_w.dumps(config))
    plan_file = tmp_path/'plan.cmake'
    assert not plan_file.exists()

    caplog.clear()
    with caplog.at_level(log_levels['INFO']):
        # Execute command in separate runner
        result = CliRunner().invoke(
            cli, [
                '--debug', 'plan', '--mode=idem', f'--config={tmp_path}/my.config',
                '--frontend=fp', f'--source={projA}', f'--root={projA}',
                f'--build={tmp_path}/build', f'--header={projA}/module/header_mod.f90',
                f'--plan-file={tmp_path}/plan.cmake'
            ]
        )

        # Check execution and logs for certain messages
        assert result.exit_code == 0
        logout = ''.join(str(r) for r in caplog.records)
        assert '[Loki::Scheduler] Performed initial source scan' in logout
        assert '[Loki] Scheduler writing CMake plan' in logout
        assert '[Loki::Scheduler] Applied transformation ' in logout

        # Check generated plan file
        assert plan_file.exists()
        plan_str = plan_file.read_text()
        for fname in projA_files:
            # Check that each file is named twice and the modified once
            assert plan_str.count(fname) == 2
            fname_mod = fname.replace('.f90', '.idem.f90').replace('.F90', '.idem.F90')
            assert plan_str.count(fname_mod) == 1

    assert not plan_file.unlink()


def test_loki_transform_convert(testdir, config, caplog, tmp_path):
    """ Test the CLI invocation of the "convert" mode """

    projA = testdir/'sources/projA'
    projA_files = [
        'driverA_mod.f90', 'kernelA_mod.F90', 'compute_l1_mod.f90',
        'another_l1.F90', 'compute_l2_mod.f90', 'another_l2.F90'
    ]

    # Create final config and build directory
    config['routines'] = {
        'driverA': {'role': 'driver'},
        'another_l1': {'role': 'driver'},
    }
    (tmp_path/'my.config').write_text(tomli_w.dumps(config))
    (tmp_path/'build').mkdir()

    caplog.clear()
    with caplog.at_level(log_levels['INFO']):
        # Execute command in separate runner
        result = CliRunner().invoke(
            cli, [
                '--debug', 'convert', '--mode=idem', f'--config={tmp_path}/my.config',
                '--frontend=fp', f'--source={projA}', f'--root={projA}',
                f'--build={tmp_path}/build', f'--header={projA}/module/header_mod.f90',
            ]
        )

        # Check execution and logs for certain messages
        assert result.exit_code == 0
        logout = ''.join(str(r) for r in caplog.records)
        assert '[Loki::Scheduler] Performed initial source scan' in logout
        assert '[Loki::Scheduler] Applied transformation ' in logout

        for fname in projA_files:
            # Ensure all files have been generated
            fname_mod = fname.replace('.f90', '.idem.f90').replace('.F90', '.idem.F90')
            assert (tmp_path/'build'/fname_mod).exists()
loki-ecmwf-0.3.6/loki/cli/tests/test_loki_lint.py0000664000175000017500000000557015167130205022214 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from click.testing import CliRunner

from loki.cli.loki_lint import cli
from loki.logging import log_levels


@pytest.fixture(scope='module', name='here')
def fixture_here():
    return Path(__file__).parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    return here.parent.parent/'tests'


def test_loki_lint_rules(caplog):
    """ Test the CLI invocation of the loki-lint "rules" and "default-config" mode """

    caplog.clear()
    with caplog.at_level(log_levels['DEBUG']):
        # Execute command in separate runner
        result = CliRunner().invoke(cli, ['--debug', 'rules'])

        # Check execution and logs for certain messages
        assert result.exit_code == 0
        logout = ''.join(str(r) for r in caplog.records)
        assert 'MissingImplicitNoneRule' in logout
        assert 'MissingIntfbRule' in logout
        assert 'OnlyParameterGlobalVarRule' in logout

    caplog.clear()
    with caplog.at_level(log_levels['DEBUG']):
        # Execute command in separate runner
        result = CliRunner().invoke(cli, ['--debug', 'default-config'])

        # Check execution and logs for certain messages
        assert result.exit_code == 0
        logout = ''.join(str(r) for r in caplog.records)
        assert 'MissingImplicitNoneRule' in logout
        assert 'MissingIntfbRule' in logout
        assert 'OnlyParameterGlobalVarRule' in logout


def test_loki_lint_check(testdir, caplog):
    """ Test the CLI invocation of the loki-lint "rules" mode """

    projA = testdir/'sources/projA'
    projInlineCalls = testdir/'sources/projInlineCalls'

    caplog.clear()
    with caplog.at_level(log_levels['WARNING']):
        # Execute command on a clean project
        result = CliRunner().invoke(
            cli, [
                'check', '--no-scheduler', f'--basedir={projA}', '--include=*.F90'
            ]
        )
        # Check that nothing triggered
        assert result.exit_code == 0
        assert not caplog.records

    caplog.clear()
    with caplog.at_level(log_levels['INFO']):
        # Execute check command in an unclean project
        result = CliRunner().invoke(
            cli, [
                '--debug', 'check', '--no-scheduler', f'--basedir={projInlineCalls}', '--include=*.F90'
            ]
        )

        # Check execution and logs for certain messages
        assert result.exit_code == 0
        logout = ''.join(str(r) for r in caplog.records)
        assert logout.count('[L3] OnlyParameterGlobalVarRule') == 4
loki-ecmwf-0.3.6/loki/cli/common.py0000664000175000017500000001057115167130205017314 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from dataclasses import dataclass, asdict, field
from functools import wraps
from pathlib import Path
from typing import Tuple

import click
from click_option_group import optgroup

from loki.frontend import Frontend
from loki.tools.util import auto_post_mortem_debugger, set_excepthook


__all__ = ['cli', 'frontend_options', 'scheduler_options']


@click.group()
@click.option(
    '--debug/--no-debug', default=False, show_default=True,
    help=('Enable / disable debug mode. This automatically attaches '
          'a debugger when exceptions occur')
)
def cli(debug):
    # Top-level command group. Intentionally no docstring: click would
    # turn it into the group's help text.
    # Nothing to set up unless debug mode was requested
    if not debug:
        return
    # Route uncaught exceptions into a post-mortem debugger
    set_excepthook(hook=auto_post_mortem_debugger)


@dataclass
class FrontendOptions:
    """
    Storage object for frontend options that can be passed to the :any:`Scheduler`.

    The values are populated from the command line by :func:`frontend_options`
    and are typically forwarded as keyword arguments (via :attr:`asdict`),
    e.g. to ``Sourcefile.from_file`` and ``Scheduler``.

    The tuple-valued fields hold an arbitrary number of entries, hence the
    variadic ``Tuple[..., ...]`` annotations.
    """

    frontend: Frontend = Frontend.FP
    preprocess: bool = False
    includes: Tuple[Path, ...] = ()
    defines: Tuple[str, ...] = ()
    xmods: Tuple[Path, ...] = ()
    omni_includes: Tuple[str, ...] = ()

    @property
    def asdict(self):
        """
        The options as a plain ``dict`` (field name -> value).
        """
        # Inside the method body ``asdict`` resolves to the module-level
        # ``dataclasses.asdict`` (class scope is not visible here)
        return asdict(self)


def frontend_options(func):
    """
    Option group configuring the Loki frontend options, including preprocessing.

    Decorates a click command with the frontend-related options, collects
    their values in a :class:`FrontendOptions` object (obtained through
    ``ctx.ensure_object``) and passes that object to ``func`` as an extra
    positional argument instead of the individual keyword arguments.
    """

    @optgroup.group('Loki frontend options',
                    help='Frontend parsing options for Loki.')
    @optgroup.option('--frontend', default='fp', type=click.Choice(['fp', 'ofp', 'omni']),
                     help='Frontend parser to use (default FP)')
    @optgroup.option('--cpp/--no-cpp', default=False,
                     help='Trigger C-preprocessing of source files.')
    @optgroup.option('--include', '-I', type=click.Path(), multiple=True,
                     help='Path for header file(s) that provide type '
                     'information without being part of the call tree')
    @optgroup.option('--define', '-D', multiple=True,
                     help='Additional symbol definitions for the C-preprocessor')
    @optgroup.option('--xmod', '-M', type=click.Path(), multiple=True,
                     help='Path for additional .xmod file(s) for OMNI')
    @optgroup.option('--omni-include', type=click.Path(), multiple=True,
                     help='Additional path for header files, specifically for OMNI')
    @click.pass_context
    @wraps(func)
    def _with_frontend_options(ctx, *args, **kwargs):
        opts = ctx.ensure_object(FrontendOptions)
        # The frontend name needs translating into the ``Frontend`` enum
        opts.frontend = Frontend[kwargs.pop('frontend').upper()]
        # All other options are copied verbatim: (attribute, CLI keyword)
        for attr_name, option_name in (
                ('preprocess', 'cpp'),
                ('includes', 'include'),
                ('defines', 'define'),
                ('xmods', 'xmod'),
                ('omni_includes', 'omni_include'),
        ):
            setattr(opts, attr_name, kwargs.pop(option_name))
        return ctx.invoke(func, *args, opts, **kwargs)

    return _with_frontend_options


@dataclass
class SchedulerOptions:
    """
    Storage object for scheduler options to instantiate a :any:`Scheduler`.

    Attributes
    ----------
    build : :class:`pathlib.Path`
        Build directory for source generation. Defaults to the current working
        directory at the time the object is created.
    source : tuple
        Paths to search during source exploration (any number of entries).
    header : tuple
        Paths of additional header files (any number of entries).
    """

    # A factory makes the default reflect the working directory when the
    # object is created, not the one at module import time
    build: Path = field(default_factory=Path.cwd)
    source: Tuple[Path, ...] = ()
    header: Tuple[str, ...] = ()


def scheduler_options(func):
    """
    Option group configuring the Loki batch scheduler.

    Decorates a click command with ``--build``, ``--source`` and ``--header``,
    collects their values in a :class:`SchedulerOptions` object (obtained
    through ``ctx.ensure_object``) and passes that object to ``func`` as an
    extra positional argument instead of the individual keyword arguments.
    """

    @optgroup.group('Loki batch scheduler options',
                    help='Batch scheduler options for Loki.')
    @optgroup.option('--build', '-b', '--out-path', type=click.Path(), default=None,
                     help='Path to build directory for source generation.')
    @optgroup.option('--source', '-s', '--path', type=click.Path(), multiple=True,
                     help='Path to search during source exploration.')
    @optgroup.option('--header', '-h', type=click.Path(), multiple=True,
                     help='Path for additional header file(s).')
    @click.pass_context
    @wraps(func)
    def _with_scheduler_options(ctx, *args, **kwargs):
        opts = ctx.ensure_object(SchedulerOptions)
        # CLI keywords and dataclass attributes share the same names here
        for name in ('build', 'source', 'header'):
            setattr(opts, name, kwargs.pop(name))
        return ctx.invoke(func, *args, opts, **kwargs)

    return _with_scheduler_options
loki-ecmwf-0.3.6/loki/cli/loki_transform.py0000664000175000017500000001421715167130205021056 0ustar  alastairalastair#!/usr/bin/env python

# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Loki head script for source-to-source transformations for batch processing
via the :any:`Scheduler`.
"""

from pathlib import Path
import sys
import click

from loki import config as loki_config, Sourcefile, as_tuple, info
from loki.batch import Scheduler, SchedulerConfig, ProcessingStrategy
from loki.cli.common import cli, frontend_options, scheduler_options

from loki.transformations.build_system import FileWriteTransformation


@cli.command()
@frontend_options
@scheduler_options
@click.option('--mode', '-m', default='idem', type=click.STRING,
              help='Transformation mode, selecting which code transformations to apply.')
@click.option('--config', default=None, type=click.Path(),
              help='Path to custom scheduler configuration file')
@click.option('--plan-file', type=click.Path(), default=None,
              help='Process pipeline in planning mode and generate CMake "plan" file.')
@click.option('--callgraph', '-g', type=click.Path(), default=None,
              help='Generate and display the subroutine callgraph.')
@click.option('--root', type=click.Path(), default=None,
              help='Root path to which all paths are relative to.')
@click.option('--log-level', '-l', default='info', envvar='LOKI_LOGGING',
              type=click.Choice(['debug', 'detail', 'perf', 'info', 'warning', 'error']),
              help='Log level to output during batch processing')
def convert(
        frontend_opts, scheduler_opts, mode, config, plan_file, callgraph, root, log_level
):
    """
    Batch-processing mode for Fortran-to-Fortran transformations that
    employs a :class:`Scheduler` to process large numbers of source
    files.

    Based on the given "mode" string, configuration file, source file
    paths and build arguments the :any:`Scheduler` will perform
    automatic call-tree exploration and apply a set of
    :any:`Transformation` objects to this call tree.
    """

    loki_config['log-level'] = log_level

    # Planning mode is selected implicitly by passing ``--plan-file``
    if plan_file is not None:
        processing_strategy = ProcessingStrategy.PLAN
        info(f'[Loki] Creating CMake plan file from config: {config}')
    else:
        processing_strategy = ProcessingStrategy.DEFAULT
        info(f'[Loki] Batch-processing source files using config: {config} ')

    config = SchedulerConfig.from_file(config)

    # set default transformation mode in Scheduler config
    config.default['mode'] = mode

    # The requested mode must name a pipeline from the config. Check this
    # before any (potentially expensive) header parsing or source discovery,
    # so that an invalid mode fails fast.
    if mode not in config.pipelines:
        msg = f'[Loki] ERROR: Pipeline or transformation mode {mode} not found in config file.\n'
        msg += '[Loki] Please provide a config file with configured transformation or pipelines instead.\n'
        sys.exit(msg)

    # Note: to get function inlining correct, we need full knowledge of any
    # imported symbols and functions. Since this cannot yet be retro-fitted
    # after creation, header files are parsed in order, each one seeing the
    # definitions of the previous ones, to build a coherent stack of type
    # definitions. With the new scheduler this is not strictly necessary
    # anymore; however, the header directories are still added to the search
    # paths below so that the scheduler can find the dependencies.
    definitions = []
    for h in scheduler_opts.header:
        sfile = Sourcefile.from_file(filename=h, definitions=definitions, **frontend_opts.asdict)
        definitions = definitions + list(sfile.definitions)

    # Create a scheduler to bulk-apply source transformations
    paths = [Path(p) for p in as_tuple(scheduler_opts.source)]
    paths += [Path(h).parent for h in as_tuple(scheduler_opts.header)]
    # Skip full source parse for planning mode
    full_parse = processing_strategy == ProcessingStrategy.DEFAULT
    scheduler = Scheduler(
        paths=paths, config=config, full_parse=full_parse,
        definitions=definitions, output_dir=scheduler_opts.build, **frontend_opts.asdict
    )

    # Apply the custom pipeline selected by ``mode`` from the scheduler config
    info(f'[Loki-transform] Applying custom pipeline {mode} from config:')
    info(str(config.pipelines[mode]))

    scheduler.process(config.pipelines[mode], proc_strategy=processing_strategy)

    mode = mode.replace('-', '_')  # Sanitize mode string

    # Write out all modified source files into the build directory
    file_write_trafo = scheduler.config.transformations.get('FileWriteTransformation', None)
    if not file_write_trafo:
        file_write_trafo = FileWriteTransformation(cuf='cuf' in mode)
    scheduler.process(transformation=file_write_trafo, proc_strategy=processing_strategy)

    if plan_file is not None:
        scheduler.write_cmake_plan(plan_file, rootpath=root)

    if callgraph:
        scheduler.callgraph(callgraph)


@cli.command('plan')
@frontend_options
@scheduler_options
@click.option('--mode', '-m', default='idem', type=click.STRING,
              help='Transformation mode, selecting which code transformations to apply.')
@click.option('--config', '-c', type=click.Path(),
              help='Path to configuration file.')
@click.option('--root', type=click.Path(), default=None,
              help='Root path to which all paths are relative to.')
@click.option('--callgraph', '-g', type=click.Path(), default=None,
              help='Generate and display the subroutine callgraph.')
@click.option('--plan-file', type=click.Path(),
              help='CMake "plan" file to generate.')
@click.option('--log-level', '-l', default='info', envvar='LOKI_LOGGING',
              type=click.Choice(['debug', 'detail', 'perf', 'info', 'warning', 'error']),
              help='Log level to output during batch processing')
@click.pass_context
def plan(ctx, *_args, **_kwargs):
    """
    Create a "plan", a schedule of files to inject and transform for a
    given configuration.
    """
    # Re-dispatch to ``convert`` with this context's parameters; ``convert``
    # switches to planning mode when ``--plan-file`` was given.
    forwarded_result = ctx.forward(convert)
    return forwarded_result
loki-ecmwf-0.3.6/loki/cli/loki_lint.py0000664000175000017500000002111015167130205017777 0ustar  alastairalastair#!/usr/bin/env python3

# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import sys
import importlib
from logging import FileHandler
from pathlib import Path
import click
from codetiming import Timer
import yaml

from loki.logging import logger, DEBUG, warning, info, debug, error
from loki.lint import Linter, lint_files
from loki.tools import yaml_include_constructor, auto_post_mortem_debugger, as_tuple


def get_rules(module):
    """
    Look up all lint rules provided by ``lint_rules.<module>``.

    Parameters
    ----------
    module : str
        Name of a submodule of the ``lint_rules`` package.

    Returns
    -------
    list
        The rules found by :meth:`Linter.lookup_rules` in that module.
    """
    qualified_name = f'lint_rules.{module}'
    imported = importlib.import_module(qualified_name)
    return Linter.lookup_rules(imported)


@click.group()
@click.option('--debug/--no-debug', default=False, show_default=True,
              help=('Enable / disable debug mode. This incurs more verbose '
                    'output and automatically attaches a debugger on exceptions.'))
@click.option('--log', type=click.Path(writable=True),
              help='Write more detailed information to a log file.')
@click.option('--rules-module', default='ifs_arpege_coding_standards', show_default=True,
              help='Select Python module with rules in lint_rules.')
@click.pass_context
def cli(ctx, debug, log, rules_module):  # pylint:disable=redefined-outer-name
    # Top-level command group: shared settings are stored in ``ctx.obj`` for
    # the sub-commands. (No docstring, as click would use it as help text.)
    if ctx.obj is None:
        ctx.obj = {}
    ctx.obj['DEBUG'] = debug
    ctx.obj['rules_module'] = rules_module
    if debug:
        # More verbose logging and a post-mortem debugger for uncaught exceptions
        logger.setLevel(DEBUG)
        sys.excepthook = auto_post_mortem_debugger
    if log:
        # Additionally write log records to the given file; the handler itself
        # does not filter anything at or above DEBUG level
        file_handler = FileHandler(log, mode='w')
        file_handler.setLevel(DEBUG)
        logger.addHandler(file_handler)


@cli.command(help='Get default configuration of linter and rules.')
@click.option('--output-file', '-o', type=click.File(mode='w'),
              help='Write default configuration to file.')
@click.pass_context
def default_config(ctx, output_file):  # pylint: disable=unused-argument
    """
    Emit the default configuration of the linter and the selected rules as YAML.

    Entries with an empty configuration are omitted. The YAML text is written
    to ``output_file`` if given, otherwise it is logged at info level.
    """
    available_rules = get_rules(ctx.obj['rules_module'])
    defaults = Linter.default_config(rules=available_rules)
    non_empty = {name: values for name, values in defaults.items() if values}
    as_yaml = yaml.dump(non_empty, default_flow_style=False)

    if not output_file:
        logger.info(as_yaml)
        return
    output_file.write(as_yaml)


@cli.command(help='List all available rules.')
@click.option('--with-title/--without-title', default=False, show_default=True,
              help='With / without title and id from each rule\'s docs.')
@click.option('--sort-by', type=click.Choice(['title', 'id']), default='title',
              show_default=True, help='Sort rules by a specific criterion.')
@click.pass_context
def rules(ctx, with_title, sort_by):  # pylint: disable=unused-argument
    """
    Log the names of all rules in the selected rules module.

    Rules are sorted case-insensitively by class name or by their documented
    ``id``. With ``--with-title`` each name is followed by the rule's id and
    title (formatted with the rule's config, if it has one) in aligned columns.
    """

    def _id_sort_key(rule):
        # Numeric id components compare as integers ('1.2' < '1.10'), all other
        # components as text, so non-numeric ids (e.g. 'L3') or a missing id do
        # not raise. Purely numeric ids keep the order of comparing int lists.
        import re  # pylint: disable=import-outside-toplevel
        rule_id = str(rule.docs.get('id', ''))
        return [(0, int(number)) if number else (1, text)
                for number, text in re.findall(r'(\d+)|([^\d.]+)', rule_id)]

    rule_list = get_rules(ctx.obj['rules_module'])
    sort_keys = {'title': lambda rule: rule.__name__.lower(),
                 'id': _id_sort_key}
    rule_list.sort(key=sort_keys[sort_by])

    rule_names = [rule.__name__ for rule in rule_list]
    # ``default=0`` avoids a ValueError if the module provides no rules
    max_width_name = max((len(name) for name in rule_names), default=0)

    if with_title:
        rule_ids = [rule.docs.get('id', '') for rule in rule_list]
        max_width_id = max((len(id_) for id_ in rule_ids), default=0)
        rule_titles = [rule.docs.get('title', '').format(**rule.config)
                       if rule.config else rule.docs.get('title', '')
                       for rule in rule_list]

        fmt_string = '{name:<{name_width}}  {id:^{id_width}}  {title}'
        output_string = '\n'.join(
            fmt_string.format(name=name, name_width=max_width_name,
                              id=id_, id_width=max_width_id, title=title)
            for name, id_, title in zip(rule_names, rule_ids, rule_titles))

    else:
        output_string = '\n'.join(rule_names)

    logger.info(output_string)


@cli.command(help='Check for syntax errors and compliance to coding rules.')
@click.option('--include', '-I', type=str, multiple=True,
              help=('File name or pattern for file names to be checked. '
                    'Allows for relative and absolute paths/glob patterns.'))
@click.option('--exclude', '-X', type=str, multiple=True,
              help=('File name or pattern for file names to be excluded. '
                    'This allows to exclude files that were included by '
                    '--include.'))
@click.option('--basedir', type=click.Path(exists=True, file_okay=False),
              help=('Base directory relative to which --include/--exclude '
                    'patterns are interpreted.'))
@click.option('--config', '-c', type=click.File(),
              help='Configuration file for behaviour of linter and rules.')
@click.option('--fix/--no-fix', default=False, show_default=True,
              help='Attempt to fix problems where possible.')
@click.option('--backup-suffix', type=str,
              help=('When fixing, create a backup of the original file with '
                    'the given suffix.'))
@click.option('--worker', type=int, default=4, show_default=True,
              help=('Number of worker processes to use. With --debug enabled '
                    'this option is ignored and only one worker is used.'))
@click.option('--write-violations-file', is_flag=False, flag_value='violations.yml', default=None,
              help=('Write a YAML file that lists for every file the violated rules. '
                    'The file can be included into a config file to disable reporting '
                    'these violations in subsequent linting runs.'))
@click.option('--scheduler/--no-scheduler', default=False, show_default=True,
              help='Use a Scheduler to plan source file traversal.')
@click.option('--junitxml', type=click.Path(dir_okay=False, writable=True),
              help='Enable output in JUnit XML format to the given file.')
@click.pass_context
def check(ctx, include, exclude, basedir, config, fix, backup_suffix, worker,
          write_violations_file, scheduler, junitxml):
    """
    Lint all files selected via include/exclude patterns.

    Command-line values are merged into those read from ``--config``:
    include/exclude patterns are appended to the configured ones, ``--basedir``
    overrides the configured base directory, and the remaining options are set
    directly. The merged configuration is passed to :func:`lint_files`.
    """
    # Register the ``!include`` tag (handled by ``yaml_include_constructor``) for safe loading
    yaml.add_constructor('!include', yaml_include_constructor, yaml.SafeLoader)
    # An empty config file makes ``safe_load`` return ``None``; treat it like no config
    config_values = (yaml.safe_load(config) if config else None) or {}
    if ctx.obj['DEBUG']:
        # Debug mode attaches a debugger, so use a single worker only
        worker = 1

    if include:
        if 'include' in config_values:
            info('Merging include patterns from config and command line')
            config_values['include'] = as_tuple(config_values['include']) + as_tuple(include)
        else:
            config_values['include'] = as_tuple(include)

    if 'include' not in config_values:
        error('No include pattern given')
        return

    if exclude:
        if 'exclude' in config_values:
            info('Merging exclude patterns from config and command line')
            config_values['exclude'] = as_tuple(config_values['exclude']) + as_tuple(exclude)
        else:
            config_values['exclude'] = as_tuple(exclude)

    if basedir:
        if 'basedir' in config_values:
            warning('Overwriting `basedir` value in the config file with command line argument')
        config_values['basedir'] = basedir
    elif 'basedir' not in config_values:
        config_values['basedir'] = Path.cwd()

    debug('Base directory: %s', config_values['basedir'])

    if scheduler:
        config_values.setdefault('scheduler', {
            'default': {
                'mode': 'lint',
                'role': 'kernel',
                'expand': True,
                'strict': False,
            }
        })
    else:
        debug('Include patterns:')
        for p in config_values['include']:
            debug('  - %s', p)
        if 'exclude' in config_values:
            debug('Exclude patterns:')
            for p in config_values['exclude']:
                debug('  - %s', p)
        debug('')

    rule_list = get_rules(ctx.obj['rules_module'])
    debug('%d rules available.', len(rule_list))

    config_values['fix'] = fix
    if backup_suffix:
        if not backup_suffix.startswith('.'):
            backup_suffix = '.' + backup_suffix
        config_values['backup_suffix'] = backup_suffix

    config_values['max_workers'] = worker

    if write_violations_file:
        config_values['violations_file'] = write_violations_file
    if junitxml:
        config_values['junitxml_file'] = junitxml

    with Timer(logger=info, text='Files checking completed in {:.2f}s'):
        checked_count = lint_files(rule_list, config_values)
    info('%d files checked', checked_count)
loki-ecmwf-0.3.6/loki/lint/0000775000175000017500000000000015167130205015645 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/lint/utils.py0000664000175000017500000001674415167130205017373 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from hashlib import md5
import re

from loki.ir import Comment, CommentBlock, LeafNode, FindNodes, Transformer
from loki.module import Module
from loki.sourcefile import Sourcefile
from loki.subroutine import Subroutine


__all__ = ['Fixer', 'get_filename_from_parent', 'get_location_hash', 'is_rule_disabled']


class Fixer:
    """
    Operator class to fix problems reported by fixable rules.

    All functionality is provided via classmethods; :meth:`fix` is the entry
    point that dispatches on the type of the given IR object.
    """

    @classmethod
    def fix_module(cls, module, reports, config):  # pylint: disable=unused-argument
        """
        Apply rule fixes on module level.

        Not implemented yet: currently this only invalidates the source object
        of :data:`module` if any :data:`reports` are given.
        """
        # TODO: implement this!
        if reports:
            module.source.invalidate()
        return module

    @classmethod
    def fix_subroutine(cls, subroutine, reports, config):
        """
        Call `fix_subroutine` for all rules and apply the transformations.

        Each report's rule may contribute a node mapping via its
        ``fix_subroutine`` method (called with the rule's own entry from
        :data:`config`). The combined mapping is applied to the spec and body
        with a :any:`Transformer`, after which the source objects of the
        subroutine and of all its parents are invalidated.

        Parameters
        ----------
        subroutine : :any:`Subroutine`
            The subroutine to fix.
        reports : list of :any:`RuleReport`
            The fixable rule reports.
        config : dict
            Configuration values, keyed by rule class name.

        Returns
        -------
        :any:`Subroutine`
            The subroutine object that was passed in.
        """
        mapper = {}
        for report in reports:
            rule_config = config[report.rule.__name__]
            # A rule may return None if it has nothing to change
            mapper.update(report.rule.fix_subroutine(subroutine, report, rule_config) or {})

        if mapper:
            # Apply the changes and invalidate source objects
            subroutine.spec = Transformer(mapper).visit(subroutine.spec)
            subroutine.body = Transformer(mapper).visit(subroutine.body)
            subroutine.source.invalidate()
            # Walk up the parent chain; an object without ``parent`` ends the walk
            parent = subroutine.parent
            while parent is not None:
                parent.source.invalidate(children=True)
                parent = getattr(parent, 'parent', None)

        return subroutine

    @classmethod
    def fix_sourcefile(cls, sourcefile, reports, config):  # pylint: disable=unused-argument
        """
        Apply rule fixes on source file level.

        Not implemented yet: currently this only invalidates the source objects
        of :data:`sourcefile` and of its IR (including children) if any
        :data:`reports` are given.
        """
        # TODO: implement this!
        if reports:
            sourcefile.source.invalidate(children=True)
            sourcefile.ir.source.invalidate(children=True)
        return sourcefile

    @classmethod
    def fix(cls, ast, reports, config):
        """
        Attempt to fix problems flagged by fixable rules in the given IR object.

        This routine calls :meth:`fix_subroutine`, :meth:`fix_module` and
        :meth:`fix_sourcefile` as applicable, visiting contained objects before
        the object itself:

        * :any:`Sourcefile`: its subroutines, then its modules, then the file
        * :any:`Module`: its subroutines, then the module
        * :any:`Subroutine`: its member routines, then the subroutine

        Objects of any other type are returned unchanged.

        Parameters
        ----------
        ast : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
            The IR object to be fixed.
        reports : list of :any:`RuleReport`
            The fixable rule reports.
        config : dict
            A `dict` with the config values.

        Returns
        -------
        :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
            The IR object that was passed in.
        """

        # Fix on source file level
        if isinstance(ast, Sourcefile):
            # Depth-first traversal
            for routine in getattr(ast, 'subroutines', None) or ():
                cls.fix_subroutine(routine, reports, config)
            # NOTE(review): procedures inside modules are not passed to
            # fix_subroutine here (fix_module is a stub) - confirm if intended
            for module in getattr(ast, 'modules', None) or ():
                cls.fix_module(module, reports, config)

            cls.fix_sourcefile(ast, reports, config)

        # Fix on module level
        elif isinstance(ast, Module):
            # Depth-first traversal
            for routine in getattr(ast, 'subroutines', None) or ():
                cls.fix_subroutine(routine, reports, config)

            cls.fix_module(ast, reports, config)

        # Fix on subroutine level
        elif isinstance(ast, Subroutine):
            # Depth-first traversal
            for routine in getattr(ast, 'members', None) or ():
                cls.fix_subroutine(routine, reports, config)

            cls.fix_subroutine(ast, reports, config)

        return ast


def get_filename_from_parent(obj):
    """
    Determine the source file path an IR object belongs to

    Follows the chain of ``parent`` attributes until an object without a
    (truthy) ``parent`` is reached, and returns that object's ``path``.

    Parameters
    ----------
    obj : :any:`Sourcefile`, :any:`Subroutine` or :any:`Module`
        A source file, module or subroutine object.

    Returns
    -------
    str or NoneType
        The ``path`` of the outermost object, or `None` if it has none.
    """
    outermost = obj
    while getattr(outermost, 'parent', None):
        outermost = outermost.parent
    return getattr(outermost, 'path', None)


def get_location_hash(location):
    """
    Compute an identifier hash from the first source line of an IR object

    Parameters
    ----------
    location : :class:`Node` or :class:`ProgramUnit`
        The IR object for which to produce the hash

    Returns
    -------
    str or None
        The hex MD5 digest of the first line of the object's source string,
        or `None` if :data:`location` has no source or an empty source string.
    """
    source = getattr(location, 'source', None)
    if not source or not source.string:
        return None
    text = source.string
    newline_pos = text.find('\n')
    # NOTE(review): without a newline ``find`` returns -1, so the slice drops
    # the final character. Kept unchanged on purpose: altering it would change
    # hashes that may already be persisted (presumably in violations files).
    first_line = text[:newline_pos]
    return str(md5(first_line.encode()).hexdigest())


_disabled_rules_re = re.compile(r'^\s*!\s*loki-lint\s*:(?:.*?)disable=(?P[\w\.,]*)')

def is_rule_disabled(ir, identifiers, disabled_line_hashes=None):
    """
    Check if a Linter rule is disabled in the provided context via user annotations

    A rule counts as disabled if the hash of the first source line of
    :data:`ir` is contained in :data:`disabled_line_hashes`, or if a comment
    of the form

    .. code-block:: fortran

        ! loki-lint: disable=RuleName

    names one of the given :data:`identifiers` (several names may be given,
    separated by commas).

    For a :class:`LeafNode` only its attached in-line comment is inspected;
    for any other IR object all comments in the subtree below it are searched.

    Parameters
    ----------
    ir : :class:`Node` or :class:`ProgramUnit`
        The IR object for which to check if a rule is disabled
    identifiers : list
        A list of string identifiers via which the rule can be disabled
    disabled_line_hashes : list, optional
        A list of hashes. If the hash of the first line of :data:`ir` is
        in this list, the rule is disabled

    Returns
    -------
    bool
        `True` if the rule is disabled, otherwise `False`
    """
    def _comment_disables(comment):
        # Does this single comment name one of the identifiers?
        found = _disabled_rules_re.match(comment.text)
        if not found:
            return False
        return any(name in identifiers for name in found.group('rules').split(','))

    if disabled_line_hashes:
        location_hash = get_location_hash(ir)
        if location_hash and location_hash in disabled_line_hashes:
            return True

    # Leaf nodes: only the attached in-line comment counts
    if isinstance(ir, LeafNode):
        inline_comment = getattr(ir, 'comment', None)
        return bool(inline_comment) and _comment_disables(inline_comment)

    # Anything else: search all comments (and comment blocks) in the subtree
    candidates = FindNodes((Comment, CommentBlock)).visit(ir)
    return any(
        _comment_disables(single)
        for node in candidates
        for single in getattr(node, 'comments', [node])
    )
loki-ecmwf-0.3.6/loki/lint/reporter.py0000664000175000017500000004662015167130205020071 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path

try:
    import yaml
    HAVE_YAML = True
except ImportError:
    HAVE_YAML = False

try:
    from junit_xml import TestSuite, TestCase, to_xml_report_string
    HAVE_JUNIT_XML = True
except ImportError:
    HAVE_JUNIT_XML = False

from loki.ir import Node
from loki.lint.utils import get_filename_from_parent, is_rule_disabled, get_location_hash
from loki.logging import logger, error
from loki.module import Module
from loki.sourcefile import Sourcefile
from loki.subroutine import Subroutine
from loki.tools import filehash

__all__ = [
    'ProblemReport', 'RuleReport', 'FileReport', 'Reporter',
    'GenericHandler', 'DefaultHandler', 'ViolationFileHandler',
    'JunitXmlHandler', 'LazyTextfile'
]


class ProblemReport:
    """
    A single problem found by a rule, tied to the IR location where it occurs

    Parameters
    ----------
    msg : str
        Description of the problem.
    location : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine` or :any:`Node`
        The IR component in which the problem exists.
    """

    def __init__(self, msg, location):
        self.msg, self.location = msg, location


class RuleReport:
    """
    Collection of all individual problems reported by one rule

    The :class:`RuleReport` instances for a file are gathered in a
    :class:`FileReport`.

    Parameters
    ----------
    rule : :any:`GenericRule`
        The rule that generated the report
    reports : list of :any:`ProblemReport`, optional
        Initial list of problem reports for this rule
    disabled : bool or list, optional
        `True` suppresses all problems of this rule. A list is interpreted as
        line hashes: problems whose location's first line hashes to one of
        them are not recorded.
    """

    def __init__(self, rule, reports=None, disabled=None):
        self.rule = rule
        self.problem_reports = reports if reports else []
        self.disabled = disabled if disabled else []
        # Initialised to zero here; presumably filled in with the rule's runtime elsewhere
        self.elapsed_sec = 0.

    def add(self, msg, location):
        """
        Record a problem for this rule unless the rule is disabled for it.

        Parameters
        ----------
        msg : str
            The message describing the problem.
        location : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine` or :any:`Node`
            The IR node or expression node in which the problem exists.

        Raises
        ------
        TypeError
            If :data:`location` is not one of the supported IR types.
        """
        if self.disabled is True:
            # The whole rule is switched off; nothing is checked or stored
            return
        supported = (Sourcefile, Module, Subroutine, Node)
        if not isinstance(location, supported):
            raise TypeError(f'Invalid type for report location: {type(location).__name__}')
        if is_rule_disabled(location, self.rule.identifiers(), self.disabled):
            return
        self.problem_reports.append(ProblemReport(msg, location))


class FileReport:
    """
    Collection of all rule reports for one file

    Parameters
    ----------
    filename : str
        The filename of the file the report is for
    hash : str, optional
        Hash of the file's content identifying the file version. If not
        given, it is computed from the file's current content.
    reports : list, optional
        Initial list of :py:class:`RuleReport`.
    """

    def __init__(self, filename, hash=None, reports=None):  # pylint: disable=redefined-builtin
        self.filename = filename
        if hash:
            self.hash = hash
        else:
            self.hash = filehash(Path(filename).read_text())
        self.reports = reports if reports else []

    def add(self, rule_report):
        """
        Append a rule report to the list of reports.

        Parameters
        -----------
        rule_report : :any:`RuleReport`
            The report to be stored.

        Raises
        ------
        TypeError
            If :data:`rule_report` is not a :any:`RuleReport`.
        """
        if isinstance(rule_report, RuleReport):
            self.reports.append(rule_report)
            return
        raise TypeError(f'{type(rule_report)} given, {RuleReport} expected')

    @property
    def fixable_reports(self):
        """
        List of those rule reports whose rule is fixable and that contain problems.
        """
        return list(filter(lambda rep: rep.rule.fixable and rep.problem_reports, self.reports))


class Reporter:
    """
    Manager for problem reports and the handlers that process them.

    Every file report is handed to all registered handlers as soon as it is
    added, and whatever a handler returns is stored for that handler.
    Processing reports immediately serves two purposes:

    #. It enables immediate output, i.e., problems can be printed as soon as
       they are detected rather than only at the very end of a (lengthy)
       multi-file linter run.
    #. It allows parallel processing: the locations stored in problem reports
       are not picklable, so they must be turned into a picklable form by the
       handlers first.

    Handler results are kept in a `dict` mapping each handler to the list of
    its results. In a parallel setting, call :meth:`init_parallel` to replace
    this with shared data structures created by a multiprocessing manager.

    Parameters
    ----------
    handlers : list of :any:`GenericHandler`, optional
        The enabled handlers. If none are given, a :any:`DefaultHandler` is used.
    """

    def __init__(self, handlers=None):
        active_handlers = handlers if handlers else [DefaultHandler()]
        self.handlers_reports = {}
        for handler in active_handlers:
            self.handlers_reports[handler] = []

    def init_parallel(self, manager):
        """
        Move the per-handler result storage into manager-backed shared objects.

        Results collected so far are copied over.

        Parameters
        ----------
        manager : :any:`multiprocessing.Manager`
            The manager used to create the shared dict and lists.
        """
        shared_reports = manager.dict()
        for handler, collected in self.handlers_reports.items():
            shared_reports[handler] = manager.list(collected)
        self.handlers_reports = shared_reports

    def add_file_report(self, file_report):
        """
        Pass a file report to every handler and store each handler's result.

        Parameters
        ----------
        file_report : :any:`FileReport`
            The file report to process.

        Raises
        ------
        TypeError
            If :data:`file_report` is not a :any:`FileReport`.
        """
        if not isinstance(file_report, FileReport):
            raise TypeError(f'{type(file_report)} given, {FileReport} expected')
        for handler, collected in self.handlers_reports.items():
            result = handler.handle(file_report)
            collected.append(result)

    def add_file_error(self, filename, rule, msg):
        """
        Report a single problem for a file, e.g., a rule that failed to run.

        The message is wrapped in a :any:`ProblemReport` without location,
        inside a :any:`RuleReport` for :data:`rule`, inside a :any:`FileReport`
        for :data:`filename`, which is then added via :meth:`add_file_report`.

        Parameters
        ----------
        filename : str
            Name of the affected file.
        rule : :any:`GenericRule`
            The rule that exposed the problem, or `None`.
        msg : str
            Description of the problem.
        """
        # NOTE(review): no hash is passed, so FileReport reads `filename` to
        # compute one; this raises if the file cannot be read.
        self.add_file_report(
            FileReport(filename, reports=[RuleReport(rule, reports=[ProblemReport(msg, None)])])
        )

    def output(self):
        """
        Let every handler emit the results it has collected.
        """
        for handler, collected in self.handlers_reports.items():
            handler.output(collected)


class GenericHandler:
    """
    Base class for report handlers.

    Subclasses implement :meth:`handle` and :meth:`output`.

    Parameters
    ----------
    basedir : str, optional
        Base directory path relative to which file paths are given.
    """

    def __init__(self, basedir=None):
        self.basedir = basedir

    def get_relative_filename(self, filename):
        """
        Express :data:`filename` relative to :attr:`basedir`, if possible.

        Parameters
        ----------
        filename : str or :any:`pathlib.Path`
            The file name to convert.

        Returns
        -------
        :any:`pathlib.Path` or the unchanged input
            The relative path if :attr:`basedir` is set and :data:`filename`
            lies below it. Otherwise :data:`filename` is returned unchanged,
            including when it is empty or `None`.
        """
        if not (filename and self.basedir):
            return filename
        try:
            return Path(filename).relative_to(self.basedir)
        except ValueError:
            # Not below basedir: keep the name as given
            return filename

    def format_location(self, filename, location):
        """
        Build a string describing where a `ProblemReport` was raised.

        The string combines:
            - the file name (looked up from :data:`location` if not given),
            - the first source line, if :data:`location` carries source information,
            - the routine or module name, if :data:`location` is itself a
              :any:`Subroutine` or :any:`Module`.

        Parameters
        ----------
        filename : str
            Name of the source file; may be empty.
        location : :any:`Node` or :any:`Subroutine` or :any:`Sourcefile` or :any:`Module`
            The IR object that triggered the problem report.

        Returns
        -------
        str
            ``<filename>[ (l. <line>)][ in routine|module "<name>"]``, where
            bracketed parts are omitted when unavailable.
        """
        if not filename:
            filename = get_filename_from_parent(location) or ''
        filename = self.get_relative_filename(filename)

        # Use `_source` when the attribute exists, otherwise fall back to `source`
        source = getattr(location, '_source', getattr(location, 'source', None))
        line = '' if source is None else f' (l. {source.lines[0]})'

        scope = ''
        if isinstance(location, Subroutine):
            scope = f' in routine "{location.name}"'
        elif isinstance(location, Module):
            scope = f' in module "{location.name}"'
        return f'{filename}{line}{scope}'

    def handle(self, file_report):  # pylint: disable=unused-argument
        """
        Process the given :attr:`file_report`; must be implemented by subclasses.

        An implementation either prints/saves the report right away, or returns
        a picklable object that :meth:`output` later prints/saves. The only
        requirement is that :meth:`output` accepts a list of the objects
        returned by :meth:`handle`.
        """
        raise NotImplementedError()

    def output(self, handler_reports):
        """
        Emit the list of objects produced by :meth:`handle`; must be implemented
        by subclasses.
        """
        raise NotImplementedError()


class DefaultHandler(GenericHandler):
    """
    Report handler that turns problems into formatted strings for command line output.

    Parameters
    ----------
    target : optional
        Callback that receives each formatted message as a string.
        Defaults to :attr:`loki.logging.logger.warning`
    immediate_output : bool, optional
        If `True` (default), messages are passed to :data:`target` while a
        file report is handled. Otherwise they are only collected, and passed
        on when :meth:`output` is called.
    basedir : str, optional
        Base directory path relative to which file paths are given.
    """

    fmt_string = '{rule}: {location} - {msg}'

    def __init__(self, target=logger.warning, immediate_output=True, basedir=None):
        super().__init__(basedir)
        self.target = target
        self.immediate_output = immediate_output

    def handle(self, file_report):
        """
        Format every problem report of a file and, by default, emit it immediately.

        Each message is labelled with the rule's class name, prefixed by
        ``[<id>]`` when the rule's ``docs`` provide an ``id``.

        Parameters
        ----------
        file_report : :any:`FileReport`
            The file report to process.

        Returns
        -------
        list of str
            The formatted messages, in report order.
        """
        filename = file_report.filename
        messages = []
        for rule_report in file_report.reports:
            rule_type = rule_report.rule
            label = rule_type.__name__
            docs = getattr(rule_type, 'docs', None)
            if docs and 'id' in docs:
                label = f'[{docs["id"]}] {label}'
            for problem in rule_report.problem_reports:
                where = self.format_location(filename, problem.location)
                text = self.fmt_string.format(rule=label, location=where, msg=problem.msg)
                if self.immediate_output:
                    self.target(text)
                messages.append(text)
        return messages

    def output(self, handler_reports):
        """
        Pass all collected messages to `target`, unless they were already
        emitted during :meth:`handle`.

        Parameters
        ----------
        handler_reports : list of list of str
            One list of messages per handled file.
        """
        if self.immediate_output:
            return
        for message in (msg for reports in handler_reports for msg in reports):
            self.target(message)


class ViolationFileHandler(GenericHandler):
    """
    Report handler that writes a YAML file listing the rules violated per file.

    The content of the YAML file has the form

    .. code-block:: yaml

        'path/to/my/file.F90':
            filehash: 'abc123'
            rules:
            - SomeRule
            - SomeOtherRule

    This YAML file can be included in a linter config file to disable reporting
    of these rules for this version of the file (source modifications are
    identified by :any:`filehash`) in future runs.

    Alternatively, the file can take the following form

    .. code-block:: yaml

        'path/to/my/file.F90':
            rules:
            - SomeRule:
              - <line hash>
              - <line hash>
            - SomeOtherRule:
              - <line hash>

    This disables reporting violations of the given rules for the specified file
    only where the line hash matches one of the listed hashes. This format is
    produced when :data:`use_line_hashes` is enabled.

    Parameters
    ----------
    target :
        The output destination
    basedir : str, optional
        Base directory path relative to which file paths are given.
    use_line_hashes : bool, optional
        List per-line hashes for each rule instead of a file hash
    """
    def __init__(self, target=logger.warning, basedir=None, use_line_hashes=False):
        if not HAVE_YAML:
            error('Pyyaml is not available')
            raise RuntimeError

        super().__init__(basedir)
        self.target = target
        self.use_line_hashes = use_line_hashes

    def handle(self, file_report):
        """
        Build the YAML block describing the rule violations of one file.

        Parameters
        ----------
        file_report : :any:`FileReport`
            The file report to process

        Returns
        -------
        str
            YAML block for this file, or an empty string if no rule reported
            a problem.
        """
        reports_with_problems = [
            rule_report for rule_report in file_report.reports if rule_report.problem_reports
        ]

        if self.use_line_hashes:
            violated_rules = []
            for rule_report in reports_with_problems:
                rule_name = rule_report.rule.__name__
                # Problems whose location yields no (truthy) hash are left out
                line_hashes = [
                    location_hash for location_hash in (
                        get_location_hash(problem.location) for problem in rule_report.problem_reports
                    ) if location_hash
                ]
                violated_rules.append({rule_name: line_hashes})
        else:
            violated_rules = [rule_report.rule.__name__ for rule_report in reports_with_problems]

        if not violated_rules:
            return ''

        violations_report = {'rules': violated_rules}
        if not self.use_line_hashes:
            violations_report['filehash'] = file_report.hash
        relative_name = str(self.get_relative_filename(file_report.filename))
        return yaml.dump({relative_name: violations_report})

    def output(self, handler_reports):
        """
        Join the non-empty YAML blocks and pass them to `target`.
        """
        non_empty = [report for report in handler_reports if report]
        self.target('\n'.join(non_empty))


class JunitXmlHandler(GenericHandler):
    """
    Report handler that produces JUnit-compatible XML, as understood by CI
    platforms such as Jenkins or Bamboo.

    Each file becomes a test suite, with one test case per rule and one failure
    entry per problem reported by that rule.

    Parameters
    ----------
    target :
        The output destination
    basedir : str, optional
        Base directory path relative to which file paths are given.
    """

    fmt_string = '{location} - {msg}'

    def __init__(self, target=logger.warning, basedir=None):
        if not HAVE_JUNIT_XML:
            error('junit_xml is not available')
            raise RuntimeError

        super().__init__(basedir)
        self.target = target

    def handle(self, file_report):
        """
        Convert a file report into picklable arguments for `junit_xml.TestCase`.

        Parameters
        ----------
        file_report : :any:`FileReport`
            The file report to process

        Returns
        -------
        tuple(str, list)
            ``(filename, [(kwargs, messages)])``, where :attr:`kwargs` are the
            constructor arguments for `junit_xml.TestCase` and :attr:`messages`
            is the list of formatted problem strings for that rule.
        """
        filename = file_report.filename
        # Test case class name: the file path without its suffix
        classname = str(Path(filename).with_suffix(''))

        def to_case_args(rule_report):
            kwargs = {
                'name': rule_report.rule.__name__,
                'classname': classname,
                'allow_multiple_subelements': True,
                'elapsed_sec': rule_report.elapsed_sec,
            }
            messages = [
                self.fmt_string.format(
                    location=self.format_location(filename, problem.location), msg=problem.msg
                )
                for problem in rule_report.problem_reports
            ]
            return kwargs, messages

        return (filename, [to_case_args(rule_report) for rule_report in file_report.reports])

    @staticmethod
    def _build_test_case(kwargs, messages):
        """Create a `junit_xml.TestCase` with one failure entry per message."""
        test_case = TestCase(**kwargs)
        for message in messages:
            test_case.add_failure_info(message)
        return test_case

    def output(self, handler_reports):
        """
        Assemble all collected reports into one XML document and pass it to `target`.
        """
        suites = [
            TestSuite(filename, [self._build_test_case(kwargs, messages) for kwargs, messages in case_args])
            for filename, case_args in handler_reports
        ]
        self.target(to_xml_report_string(suites))


class LazyTextfile:
    """
    Helper class to encapsulate opening and writing to a file.

    This exists because opening the file immediately and then passing
    its ``write`` function to a :any:`GenericHandler` makes it
    impossible to pickle it afterwards, which would make parallel
    execution infeasible.

    Instead of creating a more complicated interface for the handlers
    we opted for this way of a just-in-time/lazy file handler.

    The file is opened for writing (truncating existing content) on the first
    call to :meth:`write`, and closed when the object is garbage collected.
    Every write is flushed immediately, so the content is visible on disk
    even while the object is still alive. This matters because the time of
    finalisation is not guaranteed, and the file may be read back before it.

    Parameters
    ----------
    filename : str
        The filename of the output file
    """

    def __init__(self, filename):
        self.file_name = Path(filename)
        self.file_handle = None

    def _check_open(self):
        """
        Check if the file is open already, otherwise open it
        """
        if not self.file_handle:
            self.file_handle = self.file_name.open(mode='w')  # pylint: disable=consider-using-with

    def __del__(self):
        # getattr: __init__ may not have completed (e.g. invalid filename)
        file_handle = getattr(self, 'file_handle', None)
        if file_handle:
            file_handle.close()
            self.file_handle = None

    def write(self, msg):
        """
        Write the given :data:`msg` to the file and flush it to disk.

        The file is opened first, unless it is open already.
        """
        self._check_open()
        self.file_handle.write(msg)
        # Flush right away so readers see the data before this object is finalised
        self.file_handle.flush()
loki-ecmwf-0.3.6/loki/lint/__init__.py0000664000175000017500000000114115167130205017753 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Code linting infrastructure to allow coding standard checks using Loki.
"""

from loki.lint.utils import * # noqa
from loki.lint.rules import * # noqa
from loki.lint.linter import * # noqa
from loki.lint.reporter import * # noqa
loki-ecmwf-0.3.6/loki/lint/tests/0000775000175000017500000000000015167130205017007 5ustar  alastairalastairloki-ecmwf-0.3.6/loki/lint/tests/test_linter.py0000664000175000017500000004753715167130205021735 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import importlib
from pathlib import Path
import sys
import xml.etree.ElementTree as ET
import pytest
from fparser.two.utils import FortranSyntaxError

from loki import Sourcefile, Assignment, FindNodes, FindVariables, SourceStatus
from loki.lint import (
    GenericHandler, Reporter, Linter, GenericRule,
    LinterTransformation, lint_files, LazyTextfile
)

@pytest.fixture(scope='module', name='rules')
def fixture_rules():
    """Provide the ``rules`` module, imported by name, as the rule collection for the tests."""
    return importlib.import_module('rules')


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test module."""
    this_file = Path(__file__)
    return this_file.parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """The ``tests`` directory located two levels above this test module's directory."""
    base = here.parent.parent
    return base / 'tests'


@pytest.fixture(scope='module', name='dummy_file')
def dummy_file_fixture(tmp_path_factory):
    """
    Write a minimal Fortran source file and provide its path.

    The file is created in a fresh pytest-managed temporary directory instead of
    next to this test module. This avoids writing into the source tree, which
    may be read-only, and avoids concurrent pytest-xdist workers racing to
    create and delete the same file.
    """
    file_path = tmp_path_factory.mktemp('linter_dummy')/'test_linter_dummy_file.F90'
    fcode = """
! dummy file for linter tests
subroutine dummy
end subroutine dummy
    """.strip()
    file_path.write_text(fcode)
    yield file_path
    file_path.unlink(missing_ok=True)


@pytest.fixture(scope='module', name='dummy_rules')
def dummy_rules_fixture():
    """
    Provide two rules that verify the configuration they are given.

    ``TestRule`` expects exactly its default config, while ``TestRule2`` expects
    its default to be overridden and extended by one key. Each rule reports a
    single problem on the checked object, using its own name as the message.
    The list order is ``[TestRule2, TestRule]``.
    """
    class TestRule(GenericRule):
        config = {'key': 'default_value'}

        @classmethod
        def check(cls, ast, rule_report, config, **kwargs):
            assert dict(config) == {'key': 'default_value'}
            rule_report.add('TestRule', ast)

    class TestRule2(GenericRule):
        config = {'key': 'default_value'}

        @classmethod
        def check(cls, ast, rule_report, config, **kwargs):
            assert dict(config) == {'key': 'non_default_value', 'other_key': 'other_value'}
            rule_report.add('TestRule2', ast)

    yield [TestRule2, TestRule]


@pytest.fixture(scope='module', name='dummy_handler')
def dummy_handler_fixture(dummy_file, dummy_rules):
    """
    Provide a handler asserting that both dummy rules reported the expected
    message, for the dummy file, in rule order.
    """
    expected = (('TestRule2', dummy_rules[0]), ('TestRule', dummy_rules[1]))

    class TestHandler(GenericHandler):

        def handle(self, file_report):
            rule_reports = file_report.reports
            assert len(rule_reports) == 2
            assert len(rule_reports[0].problem_reports) == 1
            for rule_report, (message, rule) in zip(rule_reports, expected):
                first_problem = rule_report.problem_reports[0]
                assert first_problem.msg == message
                assert first_problem.location.path == dummy_file
                assert rule_report.rule == rule

        def output(self, handler_reports):
            pass

    yield TestHandler()


@pytest.mark.parametrize('rule_names, num_rules', [
    (None, 1),
    (['FooRule'], 0),
    (['DummyRule'], 1)
])
def test_linter_lookup_rules(rules, rule_names, num_rules):
    '''Check rule lookup: all rules by default, otherwise only the named ones that exist.'''
    assert num_rules == len(Linter.lookup_rules(rules, rule_names=rule_names))


def test_linter_fail(rules):
    '''Check that the linter raises a TypeError when not given a source file.'''
    expected_message = r'.*Sourcefile.*expected.*'
    with pytest.raises(TypeError, match=expected_message):
        linter = Linter(None, rules)
        linter.check(None)


def test_linter_check(dummy_file, dummy_rules, dummy_handler):
    '''Check that :meth:`Linter.check` runs every given rule with the right config.

    The actual assertions live in the dummy rules and the dummy handler.'''
    rule2_config = {'other_key': 'other_value', 'key': 'non_default_value'}
    linter = Linter(Reporter(handlers=[dummy_handler]), dummy_rules, config={'TestRule2': rule2_config})
    linter.check(Sourcefile.from_file(dummy_file))


def test_linter_transformation(dummy_file, dummy_rules, dummy_handler):
    '''Check that applying :any:`LinterTransformation` runs every given rule
    with the right config.

    The actual assertions live in the dummy rules and the dummy handler.'''
    rule2_config = {'other_key': 'other_value', 'key': 'non_default_value'}
    linter = Linter(Reporter(handlers=[dummy_handler]), dummy_rules, config={'TestRule2': rule2_config})
    LinterTransformation(linter=linter).apply(Sourcefile.from_file(dummy_file))


@pytest.mark.parametrize('file_rule,module_rule,subroutine_rule,assignment_rule,report_counts', [
    ('', '', '', '', 3),
    ('', '', '', '13.37', 3),
    ('', '', '13.37', '', 2),
    pytest.param('', '13.37', '', '', 1, marks=pytest.mark.xfail()),
    ('BlubRule', 'FooRule', 'BarRule', 'BazRule', 3),
    ('', '', '', 'AlwaysComplainRule', 3),
    ('', '', 'AlwaysComplainRule', '', 2),
    pytest.param('', 'AlwaysComplainRule', '', '', 1, marks=pytest.mark.xfail()),
    pytest.param('AlwaysComplainRule', '', '', '', 0, marks=pytest.mark.xfail()),
    pytest.param('13.37', '', '', '', 0, marks=pytest.mark.xfail()),
    # Note: Failed tests are due to the fact that rule disable lookup currently works
    # the wrong way around, see LOKI-64 for details
])
def test_linter_disable_per_scope(file_rule, module_rule, subroutine_rule, assignment_rule, report_counts):
    """
    Check that ``loki-lint: disable=...`` comments suppress a rule's reports
    in the scope where they appear.

    ``AlwaysComplainRule`` reports one problem each from its file, module and
    subroutine checks. The rule (by name or by id ``13.37``) or an unrelated
    rule name is disabled via comments at file, module, subroutine or
    assignment level. The number of remaining reports is compared with
    :data:`report_counts`.
    """
    class AlwaysComplainRule(GenericRule):
        docs = {'id': '13.37'}

        @classmethod
        def check_file(cls, sourcefile, rule_report, config, **kwargs):  # pylint: disable=unused-argument
            rule_report.add(cls.__name__, sourcefile)

        # Reuse the same unconditional check at module and subroutine level
        check_module = check_file
        check_subroutine = check_file

    class TestHandler(GenericHandler):
        def handle(self, file_report):
            # Number of problems reported by the single rule in the rule list
            return len(file_report.reports[0].problem_reports)

        def output(self, handler_reports):
            pass


    # Each placeholder receives the rule name/id to disable at that scope;
    # an empty value disables nothing
    fcode = f"""
! loki-lint: disable={file_rule}

module linter_mod
! loki-lint:disable={module_rule}

contains

subroutine linter_routine
! loki-lint: redherring=abc disable={subroutine_rule}
  integer :: i

  i = 1  ! loki-lint  : disable={assignment_rule}
end subroutine linter_routine
end module linter_mod
    """.strip()
    sourcefile = Sourcefile.from_source(fcode)

    handler = TestHandler()
    reporter = Reporter(handlers=[handler])
    rule_list = [AlwaysComplainRule]
    linter = Linter(reporter, rule_list)
    linter.check(sourcefile)

    # One handler result for the single checked file
    assert reporter.handlers_reports[handler] == [report_counts]


@pytest.mark.parametrize('rule_list,count', [
    ('', 8),
    ('NonExistentRule', 8),
    ('13.37', 5),
    ('AssignmentComplainRule', 5),
    ('NonExistentRule,AssignmentComplainRule', 5),
    ('23.42', 3),
    ('VariableComplainRule', 3),
    ('23.42,NonExistentRule', 3),
    ('13.37,23.42', 0),
    ('VariableComplainRule,13.37', 0),
    ('23.42,VariableComplainRule,AssignmentComplainRule', 0),
])
def test_linter_disable_inline(rule_list, count):
    """
    Check that inline ``loki-lint: disable=`` comments on individual statements
    suppress the listed rules, given by name or id, with flexible whitespace
    inside the annotation.

    Without suppression there are 8 reports: 3 from ``AssignmentComplainRule``
    (one per assignment) and 5 from ``VariableComplainRule`` (one per variable
    occurrence in the body).
    """
    class AssignmentComplainRule(GenericRule):
        docs = {'id': '13.37'}

        @classmethod
        def check_subroutine(cls, subroutine, rule_report, config, **kwargs):  # pylint: disable=unused-argument
            # One report per assignment, tagged with its source line
            for node in FindNodes(Assignment).visit(subroutine.ir):
                rule_report.add(cls.__name__ + '_' + str(node.source.lines[0]), node)

    class VariableComplainRule(GenericRule):
        docs = {'id': '23.42'}

        @classmethod
        def check_subroutine(cls, subroutine, rule_report, config, **kwargs):  # pylint: disable=unused-argument
            # One report per variable occurrence, located at the enclosing IR node
            for node, variables in FindVariables(with_ir_node=True).visit(subroutine.body):
                for var in variables:
                    rule_report.add(cls.__name__ + '_' + str(var), node)

    class TestHandler(GenericHandler):
        def handle(self, file_report):
            # Total number of problems across all rules
            return sum(len(report.problem_reports) for report in file_report.reports)

        def output(self, handler_reports):
            pass

    fcode = """
subroutine linter_disable_inline
integer :: a, b, c

a = 1  ! loki-lint: disable=###
b = 2  !loki-lint:disable=###
c = a + b!     loki-lint       :      disable=###
end subroutine linter_disable_inline
    """.strip()

    # Insert the parametrized rule list into every inline annotation
    fcode = fcode.replace('###', rule_list)
    sourcefile = Sourcefile.from_source(fcode)

    handler = TestHandler()
    reporter = Reporter(handlers=[handler])
    # Note: the name is reused for the rule classes from here on
    rule_list = [AssignmentComplainRule, VariableComplainRule]
    linter = Linter(reporter, rule_list)
    linter.check(sourcefile)

    assert reporter.handlers_reports[handler] == [count]


@pytest.mark.parametrize('disable_config,count', [
    ({}, 8),  # Empty 'disable' section in config should work
    ({'file.F90': {'rules': ['MyMadeUpRule']}}, 8),  # Disables non-existent rule, no effect
    ({'file.F90': {'rules': ['AssignmentComplainRule']}}, 5),  # Disables one rule
    ({'file.f90': {'rules': ['AssignmentComplainRule']}}, 8),  # Filename spelled wrong, no effect
    ({'file.F90': {'rules': ['VariableComplainRule']}}, 3),  # Disables another rule
    ({'file.F90': {'rules': ['AssignmentComplainRule', 'VariableComplainRule']}}, 0),  # Disables all rules
    ({'file.F90': {  # Disables rule with correct filehash
        'filehash': 'd0d8dd935d0e98a951cbd6c703847bac',
        'rules': ['AssignmentComplainRule']
    }}, 5),
    ({'file.F90': {  # Wrong filehash, no effect
        'filehash': 'd0d8dd935d0e98a951cbd6c703847baa',
        'rules': ['AssignmentComplainRule']
    }}, 8)
])
def test_linter_disable_config(disable_config, count):
    """
    Check that the ``disable`` section of the linter config suppresses rules
    per file name, optionally restricted to a matching ``filehash``.

    Without suppression there are 8 reports: 3 from ``AssignmentComplainRule``
    and 5 from ``VariableComplainRule``.
    """
    class AssignmentComplainRule(GenericRule):
        docs = {'id': '13.37'}

        @classmethod
        def check_subroutine(cls, subroutine, rule_report, config, **kwargs):  # pylint: disable=unused-argument
            # One report per assignment, tagged with its source line
            for node in FindNodes(Assignment).visit(subroutine.ir):
                rule_report.add(cls.__name__ + '_' + str(node.source.lines[0]), node)

    class VariableComplainRule(GenericRule):
        docs = {'id': '23.42'}

        @classmethod
        def check_subroutine(cls, subroutine, rule_report, config, **kwargs):  # pylint: disable=unused-argument
            # One report per variable occurrence, located at the enclosing IR node
            for node, variables in FindVariables(with_ir_node=True).visit(subroutine.body):
                for var in variables:
                    rule_report.add(cls.__name__ + '_' + str(var), node)

    class TestHandler(GenericHandler):
        def handle(self, file_report):
            # Total number of problems across all rules
            return sum(len(report.problem_reports) for report in file_report.reports)

        def output(self, handler_reports):
            pass

    # NOTE: the "correct" filehash in the parametrization presumably depends on
    # this exact source text; any edit here requires updating it
    fcode = """
module linter_disable_config_mod
    implicit none

    integer :: modvar

contains

    subroutine linter_disable_inline
        integer :: a, b, c

        a = 1
        b = 2
        c = a + b
    end subroutine linter_disable_inline
end module linter_disable_config_mod
    """.strip()

    sourcefile = Sourcefile.from_source(fcode)
    sourcefile.path = Path('file.F90')  # specify a dummy filename
    rule_list = [AssignmentComplainRule, VariableComplainRule]

    config = Linter.default_config(rules=rule_list)
    config['disable'] = disable_config

    handler = TestHandler()
    reporter = Reporter(handlers=[handler])
    linter = Linter(reporter, rule_list, config=config)
    linter.check(sourcefile)

    assert reporter.handlers_reports[handler] == [count]

class PicklableTestHandler(GenericHandler):
    """
    Handler that records the (relative) name of every linted file and, on
    output, passes them to :data:`target` as one newline-separated string.

    It is defined at module level (unlike the handlers local to test functions)
    so that instances can be pickled for parallel linting.
    """

    def __init__(self, basedir, target):
        super().__init__(basedir)
        self.target = target

    def handle(self, file_report):
        relative_name = self.get_relative_filename(file_report.filename)
        return str(relative_name)

    def output(self, handler_reports):
        joined_names = '\n'.join(handler_reports)
        self.target(joined_names)


# All Fortran files below projA, as listed relative to the sources directory
_PROJA_ALL_FILES = [
    'projA/module/compute_l1_mod.f90',
    'projA/module/compute_l2_mod.f90',
    'projA/module/driverA_mod.f90',
    'projA/module/driverB_mod.f90',
    'projA/module/driverC_mod.f90',
    'projA/module/driverD_mod.f90',
    'projA/module/driverE_mod.f90',
    'projA/module/header_mod.f90',
    'projA/module/kernelA_mod.F90',
    'projA/module/kernelB_mod.F90',
    'projA/module/kernelC_mod.f90',
    'projA/module/kernelD_mod.f90',
    'projA/module/kernelE_mod.f90',
    'projA/source/another_l1.F90',
    'projA/source/another_l2.F90'
]


@pytest.mark.parametrize('max_workers', [
    None,
    1,
    pytest.param(4, marks=pytest.mark.xfail(
        sys.version_info[:2] == (3, 12),
        reason='For Python 3.12, the string read from target_file_name is empty'
    ))
])
@pytest.mark.parametrize('counter,exclude,files', [
    (15, None, _PROJA_ALL_FILES),
    (15, [], _PROJA_ALL_FILES),
    (5, ['**/kernel*', '**/driver*'], [
        'projA/module/compute_l1_mod.f90',
        'projA/module/compute_l2_mod.f90',
        'projA/module/header_mod.f90',
        'projA/source/another_l1.F90',
        'projA/source/another_l2.F90'
    ]),
    (4, ['*.f90'], [
        'projA/module/kernelA_mod.F90',
        'projA/module/kernelB_mod.F90',
        'projA/source/another_l1.F90',
        'projA/source/another_l2.F90'
    ])
])
def test_linter_lint_files_glob(tmp_path, testdir, rules, counter, exclude, files, max_workers):
    """
    Check file discovery via ``include``/``exclude`` glob patterns in
    :any:`lint_files`, both sequentially and in parallel.

    The handler writes the relative names of all linted files to a log file,
    which is read back and compared with the expected list. The comparison
    ignores order in the parallel case.
    """
    basedir = testdir/'sources'
    config = {
        'basedir': str(basedir),
        'include': ['projA/**/*.f90', 'projA/**/*.F90'],
    }
    if exclude is not None:
        config['exclude'] = exclude
    if max_workers is not None:
        config['max_workers'] = max_workers

    parallel = bool(max_workers and max_workers > 1)
    target_file_name = tmp_path/'linter_lint_files_glob.log'
    if parallel:
        # An open file object cannot be pickled for the worker processes
        target = LazyTextfile(target_file_name)
    else:
        target = target_file_name.open('w')
    handler = PicklableTestHandler(basedir=basedir, target=target.write)
    try:
        checked = lint_files(rules, config, handlers=[handler])
    finally:
        # Close the plain file even if linting fails, so the handle never leaks
        if not parallel:
            target.close()

    assert checked == counter

    checked_files = Path(target_file_name).read_text().splitlines()
    assert len(checked_files) == counter
    if parallel:
        # Cannot guarantee order anymore
        assert set(checked_files) == set(files)
    else:
        assert checked_files == files


@pytest.mark.parametrize('routines,files', [
    ({'driverA': {'role': 'driver'}}, [
        'module/driverA_mod.f90',
        'module/kernelA_mod.F90',
        'module/compute_l1_mod.f90',
        'source/another_l1.F90',
        'source/another_l2.F90',
        'module/header_mod.f90'
    ]),
    ({
        'another_l1': {'role': 'driver'},
        'compute_l1': {'role': 'driver'}
    }, [
        'source/another_l1.F90',
        'module/compute_l1_mod.f90',
        'source/another_l2.F90',
        'module/header_mod.f90'
    ]),
    ({
        'another_l1': {'role': 'driver'}
    }, [
        'source/another_l1.F90',
        'source/another_l2.F90',
        'module/header_mod.f90'
    ]),
])
def test_linter_lint_files_scheduler(testdir, rules, routines, files):
    """
    Check that :any:`lint_files` discovers files through the scheduler,
    starting from the given routines, and lints exactly the expected files
    in the expected order.
    """
    basedir = testdir/'sources/projA'

    class TestHandler(GenericHandler):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.counter = 0
            self.files = []

        def handle(self, file_report):
            self.counter += 1
            relative_path = Path(file_report.filename).relative_to(basedir)
            self.files.append(str(relative_path))

        def output(self, handler_reports):
            pass

    scheduler_defaults = {
        'mode': 'lint',
        'role': 'kernel',
        'expand': True,
        'strict': False,
        'block': ['compute_l2'],
        'enable_imports': True,
    }
    config = {
        'basedir': str(basedir),
        'scheduler': {'default': scheduler_defaults, 'routines': routines},
    }

    handler = TestHandler()
    checked = lint_files(rules, config, handlers=[handler])

    expected_count = len(files)
    assert checked == expected_count
    assert handler.counter == expected_count
    assert handler.files == files


@pytest.mark.parametrize('config', [
    {'scheduler': {
        'default': {
            'mode': 'lint',
            'role': 'kernel',
            'expand': True,
            'strict': True,
        },
        'routines': {'other_routine': {}}
    }},
    {'include': ['linter_lint_files_fix.F90']}
])
@pytest.mark.parametrize('backup_suffix', [None, '.bak'])
def test_linter_lint_files_fix(tmp_path, config, backup_suffix):
    """
    Check that a fixable rule rewrites the linted file in place (upper-casing
    subroutine names) and, if requested, keeps a backup of the original file.
    """

    class TestRule(GenericRule):

        fixable = True

        @classmethod
        def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
            if not subroutine.name.isupper():
                rule_report.add(f'Subroutine name "{subroutine.name}" is not upper case', subroutine)

        @classmethod
        def fix_sourcefile(cls, sourcefile, rule_report, config):
            # Invalidate the cached source of every routine in the file so
            # that the conservative write regenerates it
            if rule_report.problem_reports:
                for subroutine in sourcefile.subroutines:
                    subroutine.source.invalidate(children=True)
                    subroutine.ir.source.invalidate(children=True)

        @classmethod
        def fix_subroutine(cls, subroutine, rule_report, config):
            assert len(rule_report.problem_reports) == 1
            if rule_report.problem_reports[0].location is subroutine:
                subroutine.name = subroutine.name.upper()
                subroutine.source.invalidate()
                return {None: None}
            return {}

    fcode = """
subroutine some_routine
implicit none
end subroutine some_routine

subroutine OTHER_ROUTINE
implicit none
call some_routine
end subroutine OTHER_ROUTINE
    """.strip()
    assert fcode.count('some_routine') == 3
    assert fcode.count('SOME_ROUTINE') == 0

    filename = tmp_path/'linter_lint_files_fix.F90'
    filename.write_text(fcode)

    # Copy the parametrized dict: it is shared between the backup_suffix
    # parametrizations, and mutating it would leak keys into other runs
    config = {**config, 'basedir': tmp_path, 'fix': True}
    if backup_suffix:
        config['backup_suffix'] = backup_suffix

    checked_files = lint_files([TestRule], config)
    assert checked_files == 1

    fixed_fcode = filename.read_text()
    assert fixed_fcode.count('some_routine') == 1  # call statement
    assert fixed_fcode.count('SOME_ROUTINE') == 2

    if backup_suffix:
        backup_file = filename.with_suffix('.bak.F90')
        assert backup_file.read_text() == fcode


@pytest.mark.parametrize('config', [
    {'scheduler': {
        'default': {
            'mode': 'lint',
            'role': 'kernel',
            'expand': True,
            'strict': True
        },
        'routines': {'other_routine': {}}
    }},
    {'include': ['*.F90']}
])
def test_linter_fortran_syntax_error(tmp_path, config, rules):
    """
    A Fortran syntax error propagates as :any:`FortranSyntaxError` when linting
    via the scheduler, and is recorded as a JunitXML failure in glob mode.
    """
    fcode = """
subroutine some_routine
implicit none
This is invalid Fortran syntax
end subroutine some_routine

subroutine OTHER_ROUTINE
implicit none
call some_routine
end subroutine OTHER_ROUTINE
    """.strip()

    filename = tmp_path/'linter_lint_files_syntax_error.F90'
    filename.write_text(fcode)
    junitxml_file = tmp_path/'junitxml.xml'

    # Copy instead of updating in place: the parametrized dict object is
    # shared module-wide and must not carry state between test runs
    config = {**config, 'basedir': tmp_path, 'junitxml_file': str(junitxml_file)}

    if 'scheduler' in config:
        with pytest.raises(FortranSyntaxError):
            lint_files(rules, config)
    else:
        checked_files = lint_files(rules, config)
        assert checked_files == 0

        # Sanity check that this ends up in reports
        xml = ET.parse(junitxml_file).getroot()
        assert xml.attrib['tests'] == '1'
        assert xml.attrib['failures'] == '1'
        report = xml.find('testsuite/testcase/failure')
        assert 'This is invalid Fortran syntax' in report.attrib['message']
loki-ecmwf-0.3.6/loki/lint/tests/rules.py0000664000175000017500000000115415167130205020514 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.lint import GenericRule, RuleType

__all__ = ['DummyRule']


class DummyRule(GenericRule):
    """
    Minimal rule without any check routines, used to exercise the linter
    machinery (rule selection, default config) in tests.
    """

    docs = {'title': 'A dummy rule for the sake of testing the Linter'}
    config = {'dummy_key': 'dummy value'}
    type = RuleType.WARN
loki-ecmwf-0.3.6/loki/lint/tests/test_reporter.py0000664000175000017500000002023615167130205022265 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import importlib
from pathlib import Path
import xml.etree.ElementTree as ET
import pytest

try:
    import yaml
    HAVE_YAML = True
except ImportError:
    HAVE_YAML = False

from loki.ir import Intrinsic
from loki.lint.linter import lint_files
from loki.lint.reporter import (
    ProblemReport, RuleReport, FileReport,
    DefaultHandler, ViolationFileHandler,
    LazyTextfile
)
from loki.lint.rules import GenericRule, RuleType


@pytest.fixture(scope='module', name='here')
def fixture_here():
    """Directory that contains this test module."""
    this_module = Path(__file__)
    return this_module.parent


@pytest.fixture(scope='module', name='testdir')
def fixture_testdir(here):
    """The ``tests`` directory inside the grandparent of this module's directory."""
    grandparent = here.parents[1]
    return grandparent/'tests'


@pytest.fixture(scope='module', name='rules')
def fixture_rules():
    """
    Import the local ``rules`` test module (assumes it is importable by name,
    e.g. because this test directory is on ``sys.path`` — TODO confirm).
    """
    return importlib.import_module('rules')


@pytest.fixture(scope='module', name='dummy_file')
def dummy_file_fixture(tmp_path_factory):
    """
    Provide a small Fortran file for the reporter tests.

    The file lives in a pytest-managed temporary directory instead of next to
    this module, so running the tests never writes into the source tree (which
    may be read-only or shared between parallel test workers).
    """
    file_path = tmp_path_factory.mktemp('reporter')/'test_reporter_dummy_file.F90'
    fcode = "! dummy file for reporter tests"
    file_path.write_text(fcode)
    yield file_path
    file_path.unlink(missing_ok=True)


@pytest.fixture(scope='module', name='dummy_file_report')
def fixture_dummy_file_report(dummy_file):
    """A :any:`FileReport` for ``dummy_file`` holding two :any:`GenericRule` problems."""
    rule_report = RuleReport(GenericRule)
    for message, intrinsic_text in (('Some message', 'foobar'), ('Other message', 'baz')):
        rule_report.add(message, Intrinsic(intrinsic_text))

    report = FileReport(str(dummy_file))
    report.add(rule_report)
    return report


class DummyLogger:
    """Minimal output target that records every message passed to :meth:`write`."""

    def __init__(self):
        self.messages = []

    def write(self, msg):
        """Append :data:`msg` to :attr:`messages`."""
        self.messages.append(msg)


def test_reports(dummy_file):
    """Basic bookkeeping of ProblemReport, RuleReport and FileReport objects."""
    file_report = FileReport(str(dummy_file))
    assert file_report.reports is not None
    assert not file_report.reports

    class SomeRule(GenericRule):
        pass

    rule_report = RuleReport(SomeRule)
    assert rule_report.problem_reports is not None
    assert not rule_report.problem_reports

    rule_report.add('Some message', Intrinsic('foobar'))
    rule_report.add('Other message', Intrinsic('baz'))
    problems = rule_report.problem_reports
    assert len(problems) == 2
    first_problem = problems[0]
    assert isinstance(first_problem, ProblemReport)
    assert first_problem.msg == 'Some message'

    file_report.add(rule_report)
    assert len(file_report.reports) == 1


def test_default_handler_immediate(dummy_file_report):
    """With immediate output, messages are emitted during ``handle``."""
    sink = DummyLogger()
    handler = DefaultHandler(target=sink.write)
    handled = handler.handle(dummy_file_report)
    assert len(sink.messages) == 2

    # A subsequent output() call must not emit anything further
    handler.output([handled])
    assert len(sink.messages) == 2


def test_default_handler_not_immediate(dummy_file_report):
    """Without immediate output, messages are only emitted by ``output``."""
    sink = DummyLogger()
    handler = DefaultHandler(target=sink.write, immediate_output=False)
    handled = handler.handle(dummy_file_report)
    assert not sink.messages

    handler.output([handled])
    assert len(sink.messages) == 2


@pytest.mark.skipif(not HAVE_YAML, reason='Pyyaml not installed')
def test_violation_file_handler(dummy_file, dummy_file_report):
    """The violation file handler emits a single YAML document keyed by file name."""
    sink = DummyLogger()
    handler = ViolationFileHandler(target=sink.write)
    handler.output([handler.handle(dummy_file_report)])
    assert len(sink.messages) == 1

    parsed = yaml.safe_load(sink.messages[0])
    file_key = str(dummy_file)
    assert len(parsed) == 1
    assert file_key in parsed

    entry = parsed[file_key]
    assert entry['filehash'] == dummy_file_report.hash
    assert len(entry['rules']) == 1
    assert 'GenericRule' in entry['rules']


def test_lazy_textfile(tmp_path):
    """
    :any:`LazyTextfile` creates its file only on the first write, appends on
    further writes, and flushes its content when the object is deleted.
    """
    log_path = tmp_path/'lazytextfile.log'
    first_chunk = 's0me TEXT'
    second_chunk = ' AAAAND other Th1ngs!!!'

    # Construction alone must not touch the file system
    lazy_file = LazyTextfile(log_path)
    assert not log_path.exists()

    # The first write opens (and thereby creates) the file
    lazy_file.write(first_chunk)
    assert log_path.exists()

    # Further writes are appended
    lazy_file.write(second_chunk)

    # Dropping the only reference should (hopefully) trigger __del__, which is
    # expected to flush the buffers so the content can be read back
    del lazy_file
    assert log_path.read_text() == 's0me TEXT AAAAND other Th1ngs!!!'


@pytest.mark.parametrize('max_workers', [None, 1])
@pytest.mark.parametrize('fail_on,failures', [(None,0), ('kernel',4)])
def test_linter_junitxml(tmp_path, testdir, max_workers, fail_on, failures):
    """
    Lint ``projA`` with a rule that flags every routine whose name contains
    ``fail_on`` and sanity-check the resulting JunitXML summary.
    """
    class RandomFailingRule(GenericRule):
        type = RuleType.WARN
        docs = {'title': 'A dummy rule for the sake of testing the Linter'}
        config = {'dummy_key': 'dummy value'}

        @classmethod
        def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
            if fail_on and fail_on in subroutine.name:
                rule_report.add(cls.__name__, subroutine)

    junit_path = tmp_path/'linter_junitxml_outputfile.xml'
    lint_config = {
        'basedir': str(testdir/'sources'),
        'include': ['projA/**/*.f90', 'projA/**/*.F90'],
        'junitxml_file': str(junit_path)
    }
    if max_workers is not None:
        lint_config['max_workers'] = max_workers

    assert lint_files([RandomFailingRule], lint_config) == 15

    # Only a few sanity checks on the XML
    root = ET.parse(junit_path).getroot()
    assert root.tag == 'testsuites'
    assert root.attrib['tests'] == '15'
    assert root.attrib['failures'] == str(failures)


@pytest.mark.skipif(not HAVE_YAML, reason='Pyyaml not installed')
@pytest.mark.parametrize('max_workers', [None, 1])
@pytest.mark.parametrize('fail_on,failures', [(None,0), ('kernel',4)])
@pytest.mark.parametrize('use_line_hashes', [None, False, True])
def test_linter_violation_file(tmp_path, testdir, rules, max_workers, fail_on, failures, use_line_hashes):
    """
    Write a YAML violations file, check its structure for file-hash and
    line-hash modes, then feed it back as ``disable`` config and verify that
    no violations remain.
    """
    class RandomFailingRule(GenericRule):
        type = RuleType.WARN
        docs = {'title': 'A dummy rule for the sake of testing the Linter'}
        config = {'dummy_key': 'dummy value'}

        @classmethod
        def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
            if fail_on and fail_on in subroutine.name:
                rule_report.add(cls.__name__, subroutine)

    violations_path = tmp_path/'linter_violations_file.yml'
    lint_config = {
        'basedir': str(testdir/'sources'),
        'include': ['projA/**/*.f90', 'projA/**/*.F90'],
        'violations_file': str(violations_path),
    }
    if use_line_hashes is not None:
        lint_config['use_violations_file_line_hashes'] = use_line_hashes
    if max_workers is not None:
        lint_config['max_workers'] = max_workers

    assert lint_files([RandomFailingRule, rules.DummyRule], lint_config) == 15

    # Just a few sanity checks on the yaml
    yaml_report = yaml.safe_load(violations_path.read_text())
    if failures:
        assert len(yaml_report) == failures
        for path, entry in yaml_report.items():
            assert fail_on in path
            if use_line_hashes is False:
                # File-hash mode: entry pinned by file hash, rules listed by name
                assert 'filehash' in entry
                assert entry['rules'] == ['RandomFailingRule']
                continue
            # Line-hash mode: no file hash, one mapping of rule name -> line hashes
            assert 'filehash' not in entry
            assert len(entry['rules']) == 1
            rule_entry = entry['rules'][0]
            assert 'RandomFailingRule' in rule_entry
            expected_hashes = 2 if path.endswith('kernelE_mod.f90') else 1
            assert len(rule_entry['RandomFailingRule']) == expected_hashes
    else:
        assert yaml_report is None

    # Feeding the violations back in as the "disable" section must suppress
    # them all in a second linter pass
    lint_config['disable'] = yaml_report
    assert lint_files([RandomFailingRule, rules.DummyRule], lint_config) == 15
    assert yaml.safe_load(violations_path.read_text()) is None
loki-ecmwf-0.3.6/loki/lint/rules.py0000664000175000017500000001452415167130205017357 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Base class for linter rules and available rule types
"""
from enum import Enum

from loki.lint.utils import is_rule_disabled
from loki.module import Module
from loki.sourcefile import Sourcefile
from loki.subroutine import Subroutine


class RuleType(Enum):
    """
    Severity categories for linter rules.

    The integer values increase with severity, from ``INFO`` (least severe)
    to ``ERROR`` (most severe).
    """

    INFO = 1
    WARN = 2
    SERIOUS = 3
    ERROR = 4


class GenericRule:
    """
    Generic interface for linter rules providing default values and the
    general :meth:`check` routine that calls the specific entry points to rules
    (subroutines, modules, and the source file).

    When adding a new rule, it must inherit from :any:`GenericRule`
    and define :data:`type` and provide ``title`` (and ``id``, if applicable)
    in :data:`docs`.
    Optional configuration values can be defined in :data:`config` together with
    the default value for this option. Only the relevant entry points to a
    rule must be implemented.
    """

    type = None
    """
    The rule type as one of the categories in :any:`RuleType`
    """

    docs = None
    """
    :any:`dict` with description of the rule

    Typically, this should include ``"id"`` and ``"title"``. Allows for
    Python's format specification mini-language in ``"title"`` to fill values
    using data from :data:`config`, with the field name corresponding to the
    config key.
    """

    config = {}
    """
    Dict of configuration keys and their default values

    These values can be overridden externally in the linter config file and are
    passed automatically to the :meth:`check` routine.
    """

    fixable = False
    """
    Indicator for a fixable rule that implements the corresponding
    :meth:`fix_module`, :meth:`fix_subroutine` and/or :meth:`fix_sourcefile`
    routines
    """

    deprecated = False
    """
    Indicator for a deprecated rule
    """

    replaced_by = ()
    """
    List of rules that replace the deprecated rule, where applicable
    """

    @classmethod
    def identifiers(cls):
        """
        Return list of strings that identify this rule

        This is the class name, followed by ``docs['id']`` if the rule
        provides one.
        """
        if cls.docs and 'id' in cls.docs:  # pylint: disable=unsupported-membership-test
            return [cls.__name__, cls.docs['id']]  # pylint: disable=unsubscriptable-object
        return [cls.__name__]

    @classmethod
    def check_module(cls, module, rule_report, config):
        """
        Perform rule checks on module level

        Must be implemented by a rule if applicable.
        """

    @classmethod
    def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
        """
        Perform rule checks on subroutine level

        Must be implemented by a rule if applicable.
        """

    @classmethod
    def check_file(cls, sourcefile, rule_report, config):
        """
        Perform rule checks on file level

        Must be implemented by a rule if applicable.
        """

    @classmethod
    def check(cls, ast, rule_report, config, **kwargs):
        """
        Perform checks on all entities in the given IR object

        This routine calls :meth:`check_module`, :meth:`check_subroutine`
        and :meth:`check_file` as applicable for all entities in the given
        IR object. Modules and subroutines for which the rule is disabled
        (as determined by :any:`is_rule_disabled` for any of
        :meth:`identifiers`) are skipped, including their contents.

        Parameters
        ----------
        ast : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
            The IR object to be checked.
        rule_report : :any:`RuleReport`
            The reporter object in which rule violations should be registered.
        config : dict
            The rule configuration, filled with externally provided
            configuration values or the rule's default configuration.
        **kwargs :
            Forwarded to :meth:`check_subroutine` and to recursive calls.
            A non-empty ``targets`` entry is passed to :meth:`check_subroutine`
            for the given subroutine only and is not forwarded to its members.
            Otherwise, ``targets`` is taken from the first entry in ``items``
            whose ``local_name`` matches the subroutine name case-insensitively
            (``None`` if there is no match).
        """

        # Perform checks on source file level
        if isinstance(ast, Sourcefile):
            cls.check_file(ast, rule_report, config)

            # Then recurse for all modules and subroutines in that file
            if hasattr(ast, 'modules') and ast.modules is not None:
                for module in ast.modules:
                    cls.check(module, rule_report, config, **kwargs)
            if hasattr(ast, 'subroutines') and ast.subroutines is not None:
                for subroutine in ast.subroutines:
                    cls.check(subroutine, rule_report, config, **kwargs)

        # Perform checks on module level
        elif isinstance(ast, Module):
            if is_rule_disabled(ast.spec, cls.identifiers()):
                return

            cls.check_module(ast, rule_report, config)

            # Then recurse for all subroutines in that module
            if hasattr(ast, 'subroutines') and ast.subroutines is not None:
                for subroutine in ast.subroutines:
                    cls.check(subroutine, rule_report, config, **kwargs)

        # Perform checks on subroutine level
        elif isinstance(ast, Subroutine):
            if is_rule_disabled(ast.ir, cls.identifiers()):
                return

            # ``targets`` is popped, so it is not forwarded to the recursive
            # calls for contained member procedures below
            if not (targets := kwargs.pop('targets', None)):
                routine_name = ast.name.lower()
                matching_item = next(
                    (item for item in kwargs.get('items', ()) if item.local_name.lower() == routine_name),
                    None
                )
                if matching_item is not None:
                    targets = matching_item.targets
            cls.check_subroutine(ast, rule_report, config, targets=targets, **kwargs)

            # Recurse for any procedures contained in a subroutine
            if hasattr(ast, 'members') and ast.members is not None:
                for member in ast.members:
                    cls.check(member, rule_report, config, **kwargs)

    @classmethod
    def fix_module(cls, module, rule_report, config):
        """
        Fix rule violations on module level

        Must be implemented by a rule if applicable.
        """

    @classmethod
    def fix_subroutine(cls, subroutine, rule_report, config):
        """
        Fix rule violations on subroutine level

        Must be implemented by a rule if applicable.
        """

    @classmethod
    def fix_sourcefile(cls, sourcefile, rule_report, config):
        """
        Fix rule violations on sourcefile level

        Must be implemented by a rule if applicable.
        """
loki-ecmwf-0.3.6/loki/lint/linter.py0000664000175000017500000003641415167130205017524 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
:any:`Linter` operator class definition to drive rule checking for
:any:`Sourcefile` objects
"""
from concurrent.futures import as_completed
import inspect
from multiprocessing import Manager
from pathlib import Path
import shutil
from codetiming import Timer

from loki.jit_build import workqueue
from loki.batch import Scheduler, SchedulerConfig, Item, Transformation
from loki.config import config as loki_config
from loki.lint.reporter import (
    FileReport, RuleReport, Reporter, LazyTextfile,
    DefaultHandler, JunitXmlHandler, ViolationFileHandler
)
from loki.lint.utils import Fixer
from loki.logging import logger
from loki.sourcefile import Sourcefile
from loki.tools import filehash, find_paths, CaseInsensitiveDict


__all__ = ['Linter', 'LinterTransformation', 'lint_files']


class Linter:
    """
    The operator class for Loki's linter functionality

    It allows to check :any:`Sourcefile` objects for compliance to rules
    specified as subclasses of :any:`GenericRule`.

    Parameters
    ----------
    reporter : :any:`Reporter`
        The reporter instance to be used for problem reporting.
    rules : list of :any:`GenericRule` or a Python module
        List of rules to check files against or a module that contains the rules.
    config : dict, optional
        Configuration (e.g., from config file) to change behaviour of rules.
        If it contains a ``'rules'`` entry, only rules with those names are used.
    """
    def __init__(self, reporter, rules, config=None):
        self.reporter = reporter
        if inspect.ismodule(rules):
            rule_names = config.get('rules') if config else None
            self.rules = Linter.lookup_rules(rules, rule_names=rule_names)
        elif config and config.get('rules') is not None:
            self.rules = [rule for rule in rules if rule.__name__ in config.get('rules')]
        else:
            self.rules = rules
        self.config = self.default_config(self.rules)
        self.update_config(config)

    @staticmethod
    def lookup_rules(rules_module, rule_names=None):
        """
        Obtain all available rule classes in a module

        Only classes whose name is listed in the module's ``__all__`` are
        considered.

        Parameters
        ----------
        rules_module : Python module
            The module in which rules are implemented.
        rule_names : list of str, optional
            Only look for rules with a name that is in this list.

        Returns
        -------
        list
            A list of rule classes.
        """
        rule_list = inspect.getmembers(
            rules_module, lambda obj: inspect.isclass(obj) and obj.__name__ in rules_module.__all__)
        if rule_names is not None:
            rule_list = [r for r in rule_list if r[0] in rule_names]
        return [r[1] for r in rule_list]

    @staticmethod
    def default_config(rules):
        """
        Return default configuration for a list of rules

        Each rule's default ``config`` dict is copied, so that later updates
        via :meth:`update_config` do not modify the class attribute of the rule
        (which would otherwise leak into every other :any:`Linter` instance).

        Parameters
        ----------
        rules : list
            List of rules for which to compile the default config.

        Returns
        -------
        dict
            Mapping of rule names to the dict of default configuration
            values for each rule, plus a ``'rules'`` entry listing all
            rule names.
        """
        # List of rules
        config = {'rules': [rule.__name__ for rule in rules]}
        # Default options for rules, copied to protect the rule class attribute
        for rule in rules:
            rule_config = rule.config
            config[rule.__name__] = dict(rule_config) if isinstance(rule_config, dict) else rule_config
        return config

    def update_config(self, config):
        """
        Update the stored configuration using the given :data:`config` dict

        Dict values for keys that already exist are merged into the existing
        entry; all other values replace or add the entry.
        """
        if config is None:
            return
        for key, val in config.items():
            # If we have a dict, update that entry
            if isinstance(val, dict) and key in self.config:
                self.config[key].update(val)
            else:
                self.config[key] = val

    def check(self, sourcefile, overwrite_rules=None, overwrite_config=None, **kwargs):
        """
        Check the given :data:`sourcefile` and compile a :any:`FileReport`.

        The file report is then stored in the :any:`Reporter` given while
        creating the :any:`Linter`. Additionally, the file report is returned,
        e.g., to use it with :meth:`fix`.

        Parameters
        ----------
        sourcefile : :any:`Sourcefile`
            The source file to check.
        overwrite_rules : list of rules, optional
            List of rules to check. This overwrites the stored list of rules.
        overwrite_config : dict, optional
            Configuration that is used to update the stored configuration.
            Note that this update persists on the :any:`Linter` instance.

        Returns
        -------
        :any:`FileReport`
            The report for this file containing any discovered violations.

        Raises
        ------
        TypeError
            If :data:`sourcefile` is not a :any:`Sourcefile`.
        """
        if not isinstance(sourcefile, Sourcefile):
            raise TypeError(f'{type(sourcefile)} given, {Sourcefile} expected')

        # Prepare config
        config = self.config
        if overwrite_config:
            config.update(overwrite_config)
        disable_config = config.get('disable')
        if not isinstance(disable_config, dict):
            disable_config = {}

        # Initialize report for this file
        filename = str(sourcefile.path) if sourcefile.path else None
        file_report = FileReport(filename, hash=filehash(sourcefile.source.string))

        # Check "disable" config section for an entry matching the file name and, if given, filehash.
        # Files without a path cannot be matched against the file name patterns.
        disabled_rules = CaseInsensitiveDict()
        disable_file_key = None
        if sourcefile.path:
            disable_file_key = next((key for key in disable_config if sourcefile.path.match(key)), None)
        if disable_file_key:
            disable_file = disable_config[disable_file_key]
            if 'filehash' not in disable_file or disable_file['filehash'] == file_report.hash:
                for rule in disable_file.get('rules', []):
                    if isinstance(rule, dict):
                        # Rule disabled only for the given line hashes
                        for name, line_hashes in rule.items():
                            disabled_rules[name] = line_hashes
                    else:
                        # Rule disabled for the entire file
                        disabled_rules[rule] = True

        # Prepare list of rules, dropping those disabled for the whole file
        rules = overwrite_rules if overwrite_rules is not None else self.rules
        rules = [rule for rule in rules if disabled_rules.get(rule.__name__) is not True]

        timer = Timer(logger=None)

        # Run all the rules on that file
        for rule in rules:
            timer.start()
            rule_report = RuleReport(rule, disabled=disabled_rules.get(rule.__name__))
            rule.check(sourcefile, rule_report, config[rule.__name__], **kwargs)
            rule_report.elapsed_sec = timer.stop()
            file_report.add(rule_report)

        # Store the file report
        self.reporter.add_file_report(file_report)
        return file_report

    def fix(self, sourcefile, file_report, backup_suffix=None, overwrite_config=None):
        """
        Fix all rule violations in :data:`file_report` that were reported by
        fixable rules and write them into the original file

        Parameters
        ----------
        sourcefile : :any:`Sourcefile`
            The source file to fix.
        file_report : :any:`FileReport`
            The report created by :meth:`check` for that file.
        backup_suffix : str, optional
            Create a copy of the original file using this file name suffix,
            inserted before the original file extension.
        overwrite_config : dict, optional
            Configuration that is used to update the stored configuration.

        Raises
        ------
        TypeError
            If :data:`sourcefile` is not a :any:`Sourcefile`.
        """
        if not isinstance(sourcefile, Sourcefile):
            raise TypeError(f'{type(sourcefile)} given, {Sourcefile} expected')
        file_path = Path(sourcefile.path)
        assert file_path == Path(file_report.filename)

        # Nothing to do if there are no fixable reports
        if not file_report.fixable_reports:
            return

        # Make a backup copy if requested (e.g. ``foo.F90`` -> ``foo.bak.F90``)
        if backup_suffix:
            backup_path = file_path.with_suffix(backup_suffix + file_path.suffix)
            shutil.copy(file_path, backup_path)

        # Extract configuration
        config = self.config
        if overwrite_config:
            config.update(overwrite_config)

        # Apply the fixes
        sourcefile = Fixer.fix(sourcefile, file_report.fixable_reports, config)

        # Write the source string for the output
        sourcefile.write(conservative=True)


class LinterTransformation(Transformation):
    """
    Apply :class:`Linter` as a :any:`Transformation` to :any:`Sourcefile`

    The :any:`FileReport` is stored in the ``trafo_data`` of an :any:`Item`
    object, if it is provided to :meth:`transform_file`, e.g., during a
    :any:`Scheduler` traversal. If the linter's configuration has ``fix``
    enabled, fixable violations are also fixed in the processed file, using
    the configured ``backup_suffix`` (if any).

    Parameters
    ----------
    linter : :class:`Linter`
        The linter instance to use
    key : str, optional
        Lookup key overwrite for stored reports in the ``trafo_data`` of :any:`Item`
    """

    _key = 'LinterTransformation'

    # This transformation is applied over the file graph
    traverse_file_graph = True

    item_filter = Item  # Include everything in the dependency tree

    def __init__(self, linter, key=None, **kwargs):
        self.linter = linter
        # Number of files processed by transform_file
        self.counter = 0
        if key:
            self._key = key
        super().__init__(**kwargs)

    def transform_file(self, sourcefile, **kwargs):
        """
        Check :data:`sourcefile` with the linter, store the report on the
        ``item`` keyword argument (if given) and optionally apply fixes
        """
        item = kwargs.get('item')
        report = self.linter.check(sourcefile, **kwargs)
        self.counter += 1
        if item:
            item.trafo_data[self._key] = report
        if self.linter.config.get('fix'):
            self.linter.fix(sourcefile, report, backup_suffix=self.linter.config.get('backup_suffix'))


def lint_files_scheduler(linter, basedir, config):
    """
    Discover files relative to :data:`basedir` using :any:`SchedulerConfig`
    from :data:`config`, and apply :data:`linter` on each of them.

    Returns the number of files the :any:`LinterTransformation` processed.
    """
    scheduler_config = SchedulerConfig.from_dict(config)
    scheduler = Scheduler(paths=[basedir], config=scheduler_config)
    lint_trafo = LinterTransformation(linter=linter)
    scheduler.process(transformation=lint_trafo)
    return lint_trafo.counter


def check_and_fix_file(path, linter, fix=False, backup_suffix=None):
    """
    Check the file at :data:`path` with :data:`linter` and, optionally,
    fix it

    Returns ``True`` on success. Any exception is recorded as a file error in
    the linter's reporter; it is then re-raised if Loki's ``debug`` config is
    set, otherwise ``False`` is returned.
    """
    try:
        sourcefile = Sourcefile.from_file(path)
        file_report = linter.check(sourcefile)
        if fix:
            linter.fix(sourcefile, file_report, backup_suffix=backup_suffix)
    except Exception as err:  # pylint: disable=broad-except
        linter.reporter.add_file_error(path, type(err), str(err))
        if loki_config['debug']:
            raise err
        return False
    else:
        return True


def lint_files_glob(linter, basedir, include, exclude=None, max_workers=1, fix=False, backup_suffix=None):
    """
    Discover files relative to :data:`basedir` using the patterns in
    :data:`include` (ignoring those matched by :data:`exclude`) and lint
    each of them with :any:`check_and_fix_file`.

    Files are processed one after another when ``max_workers == 1`` or when
    ``loki_config['debug']`` is set. Otherwise a ``Manager()`` is created,
    the linter's reporter is switched to parallel mode with it, and the files
    are dispatched to a :any:`workqueue` with ``max_workers`` workers.

    Parameters
    ----------
    linter : :any:`Linter`
        The linter applied to each file
    basedir : str or path
        Root directory for file discovery
    include : list of str
        Patterns passed to :any:`find_paths`
    exclude : list of str, optional
        Patterns passed to :any:`find_paths` as ``ignore``
    max_workers : int, optional
        Number of parallel workers (default: 1, i.e. sequential)
    fix : bool, optional
        Forwarded to :any:`check_and_fix_file`
    backup_suffix : str, optional
        Forwarded to :any:`check_and_fix_file`

    Returns
    -------
    int
        Number of files for which :any:`check_and_fix_file` reported success
    """
    files = find_paths(basedir, include, ignore=exclude)
    run_parallel = max_workers != 1 and not loki_config['debug']

    if not run_parallel:
        return sum(
            check_and_fix_file(path, linter, fix=fix, backup_suffix=backup_suffix)
            for path in files
        )

    manager = Manager()
    linter.reporter.init_parallel(manager)
    with workqueue(workers=max_workers, logger=logger, manager=manager) as queue:
        log_queue = getattr(queue, 'log_queue', None)
        futures = [
            queue.call(check_and_fix_file, path, linter, fix=fix, backup_suffix=backup_suffix, log_queue=log_queue)
            for path in files
        ]
        return sum(future.result() for future in as_completed(futures))


def lint_files(rules, config, handlers=None):
    """
    Construct a :any:`Linter` according to :data:`config` and
    check the rules in :data:`rules`

    Depending on the given config values, this will use a :any:`Scheduler`
    to discover files and drive the linting, or apply glob-based file
    discovery and apply linting to each of them.

    Common config options include:

    .. code-block::

       {
           'basedir': <path>,  # Discovery root, also passed to all output handlers
           'max_workers': <int>,  # Optional: use multiple workers (glob-based discovery only)
           'fix': <bool>,  # Optional: attempt automatic fixing of rule violations
           'backup_suffix': <str>,  # Optional: Backup original file with given suffix
           'junitxml_file': <path>,  # Optional: write JunitXML-output of lint results
           'violations_file': <path>,  # Optional: write a YAML file containing violations
           'use_violations_file_line_hashes': <bool>,  # Optional: defaults to True
           'rules': ['SomeRule', 'AnotherRule', ...],  # Optional: select only these rules
           'SomeRule': {...},  # Optional: configuration values for individual rules
       }

    The ``basedir`` option is used as the discovery path, either for the
    :any:`Scheduler` or for :any:`find_paths`.
    See :any:`SchedulerConfig` for more details on the available config options.

    See :any:`JunitXmlHandler` and :any:`ViolationFileHandler` for more details
    on the output file options. ``use_violations_file_line_hashes`` is passed
    as ``use_line_hashes`` to the :any:`ViolationFileHandler`.

    The ``rules`` option in the config allows selecting only certain rules out of
    the provided :data:`rules` argument.

    In addition, :data:`config` takes for scheduler the following options:

    .. code-block::

       {
           'scheduler': {...}  # Scheduler configuration, see SchedulerConfig
       }

    If the ``scheduler`` key is found in :data:`config`, the scheduler-based
    linting is automatically enabled.

    For glob-based file discovery, the config takes the following options:

    .. code-block::

       {
           'include': [<pattern>, <pattern>, ...],
           'exclude': [<pattern>, ...]  # Optional
       }

    The ``include`` and ``exclude`` options are provided to :any:`find_paths` to
    discover files that should be linted.

    Parameters
    ----------
    rules : list of :any:`GenericRule` or a Python module
        List of rules to check files against or a module that contains the rules.
    config : dict
        Configuration for file discovery/scheduler and linting rules
    handlers : list, optional
        Additional instances of :any:`GenericHandler` to use during linting.
        The given sequence is copied and never modified.

    Returns
    -------
    int :
        The number of checked files
    """
    basedir = config['basedir']

    # Work on a copy: appending to the caller's list would leak the handlers
    # created here into it, so repeated calls would accumulate DefaultHandlers.
    # Copying also makes a tuple of handlers work.
    handlers = list(handlers) if handlers else []
    handlers.append(DefaultHandler(basedir=basedir))
    if 'junitxml_file' in config:
        junitxml_file = LazyTextfile(config['junitxml_file'])
        handlers.append(JunitXmlHandler(target=junitxml_file.write, basedir=basedir))
    if 'violations_file' in config:
        violations_file = LazyTextfile(config['violations_file'])
        handlers.append(ViolationFileHandler(
            target=violations_file.write, basedir=basedir,
            use_line_hashes=config.get('use_violations_file_line_hashes', True)
        ))

    linter = Linter(reporter=Reporter(handlers), rules=rules, config=config)
    if 'scheduler' in config:
        checked_count = lint_files_scheduler(linter, basedir, config['scheduler'])
    else:
        checked_count = lint_files_glob(
            linter, basedir, config['include'],
            exclude=config.get('exclude'), max_workers=config.get('max_workers', 1),
            fix=config.get('fix', False), backup_suffix=config.get('backup_suffix')
        )

    linter.reporter.output()
    return checked_count
loki-ecmwf-0.3.6/loki/function.py0000664000175000017500000000757215167130205017111 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.subroutine import Subroutine


__all__ = ['Function']


class Function(Subroutine):
    """
    Class to handle and manipulate a single function, i.e. a :any:`Subroutine`
    that additionally keeps track of the name of its result variable.

    Parameters
    ----------
    name : str
        Name of the function.
    args : iterable of str, optional
        The names of the dummy args.
    docstring : tuple of :any:`Node`, optional
        The function docstring in the original source.
    spec : :any:`Section`, optional
        The spec of the function.
    body : :any:`Section`, optional
        The body of the function.
    contains : :any:`Section`, optional
        The internal-subprogram part following a ``CONTAINS`` statement
        declaring member procedures
    prefix : iterable, optional
        Prefix specifications for the procedure
    bind : optional
        Bind information (e.g., for Fortran ``BIND(C)`` annotation).
    result_name : str, optional
        The name of the result variable. If not given, the function
        name is used.
    ast : optional
        Frontend node for this function (from parse tree of the frontend).
    source : :any:`Source`
        Source object representing the raw source string information from the
        read file.
    parent : :any:`Scope`, optional
        The enclosing parent scope of the function, typically a :any:`Module`
        or :any:`Subroutine` object. Declarations from the parent scope remain
        valid within the function's scope (unless shadowed by local
        declarations).
    rescope_symbols : bool, optional
        Ensure that the type information for all :any:`TypedSymbol` in the
        function's IR exist in the function's scope or the scope's parents.
        Defaults to `False`.
    symbol_attrs : :any:`SymbolTable`, optional
        Use the provided :any:`SymbolTable` object instead of creating a new one.
    incomplete : bool, optional
        Mark the object as incomplete, i.e. only partially parsed. This is
        typically the case when it was instantiated using the :any:`Frontend.REGEX`
        frontend and a full parse using one of the other frontends is pending.
    parser_classes : :any:`RegexParserClass`, optional
        Provide the list of parser classes used during incomplete regex parsing
    """

    is_function = True

    def __init__(self, *args, parent=None, symbol_attrs=None, **kwargs):
        super().__init__(*args, parent=parent, symbol_attrs=symbol_attrs, **kwargs)

        # NOTE(review): if ``Subroutine.__init__`` already dispatches to
        # ``__initialize__``, this runs initialization a second time - confirm
        self.__initialize__(*args, **kwargs)

    def __initialize__(self, name, *args, result_name=None, **kwargs):
        # Fall back to the function's own name as result variable name;
        # set before delegating to the parent initializer
        self.result_name = name if result_name is None else result_name
        super().__initialize__(name, *args, **kwargs)

    def __repr__(self):
        """ String representation """
        return f'Function:: {self.name}'

    @property
    def return_type(self):
        """ Type of the result variable, as looked up via ``symbol_attrs.get`` """
        return self.symbol_attrs.get(self.result_name)

    def clone(self, **kwargs):
        """
        Create a copy of the function with the option to override
        individual parameters.

        The current ``result_name`` is carried over unless the caller
        provides one explicitly.

        Parameters
        ----------
        **kwargs :
            Any parameters from the constructor of :any:`Function`.

        Returns
        -------
        :any:`Function`
            The cloned function object.
        """
        if self.result_name:
            kwargs.setdefault('result_name', self.result_name)
        return super().clone(**kwargs)
loki-ecmwf-0.3.6/conftest.py0000664000175000017500000000256415167130205016147 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Local pytest plugin to add bespoke pytest extensions for Loki

See the pytest documentation for more details:
https://docs.pytest.org/en/stable/how-to/writing_plugins.html#local-conftest-plugins
"""

from loki.config import config as loki_config


def pytest_addoption(parser, pluginmanager):  # pylint: disable=unused-argument
    """
    Register Loki-specific options with the pytest command line.

    Options are added through ``parser.addoption``, which accepts the same
    arguments as :any:`argparse.ArgumentParser.add_argument`.

    Currently this registers ``--loki-log-level`` (stored as
    ``config.option.LOKI_LOG_LEVEL``, default ``'INFO'``), which overrides the
    log level in :any:`loki.logging`.
    """
    log_level_help = 'Change the Loki log level (ERROR, WARNING, INFO, PERF, DETAIL, DEBUG)'
    parser.addoption(
        '--loki-log-level',
        dest='LOKI_LOG_LEVEL',
        default='INFO',
        help=log_level_help,
    )


def pytest_configure(config):
    """
    Apply Loki configuration derived from the pytest command line.

    Invoked by pytest after all command line options have been processed;
    copies the ``--loki-log-level`` value into Loki's ``log-level`` setting.
    """
    requested_level = config.option.LOKI_LOG_LEVEL
    loki_config['log-level'] = requested_level
loki-ecmwf-0.3.6/codecov.yml0000664000175000017500000000040115167130205016101 0ustar  alastairalastaircodecov:
  require_ci_to_pass: yes
  notify:
    wait_for_ci: yes

coverage:
  precision: 2
  round: down

  status:
    project:
      default:
        enabled: yes
        target: auto
        threshold: 0.1
    patch:
      default:
        enabled: off
loki-ecmwf-0.3.6/README.md0000664000175000017500000000533515167130205015226 0ustar  alastairalastair# Loki: Freely programmable source-to-source translation

[![license](https://img.shields.io/github/license/ecmwf-ifs/loki)](https://www.apache.org/licenses/LICENSE-2.0.html)
[![code-checks](https://github.com/ecmwf-ifs/loki/actions/workflows/code_checks.yml/badge.svg)](https://github.com/ecmwf-ifs/loki/actions/workflows/code_checks.yml)
[![tests](https://github.com/ecmwf-ifs/loki/actions/workflows/tests.yml/badge.svg)](https://github.com/ecmwf-ifs/loki/actions/workflows/tests.yml)
[![regression-tests](https://github.com/ecmwf-ifs/loki/actions/workflows/regression_tests.yml/badge.svg)](https://github.com/ecmwf-ifs/loki/actions/workflows/regression_tests.yml)
[![codecov](https://codecov.io/gh/ecmwf-ifs/loki/branch/main/graph/badge.svg?token=9ZDS95SFWI)](https://codecov.io/gh/ecmwf-ifs/loki)

**Loki is an experimental tool** to explore the possible use of
source-to-source translation for ECMWF's Integrated Forecasting System (IFS) and
associated Fortran software packages.

Loki is based on compiler technology (visitor patterns and ASTs) and aims to
provide an abstract, language-agnostic representation of a kernel, as well as a
programmable (pythonic) interface that allows developers to experiment with
different kernel implementations and optimizations.  The aim is to allow changes
to programming models and coding styles to be encoded and automated instead of
hand-applying them, enabling advanced experimentation with large kernels as well
as bulk processing of large numbers of source files to evaluate different kernel
implementations and programming models.

*This package is made available to support research collaborations and is not
officially supported by ECMWF.*

## Contact

Michael Lange (michael.lange@ecmwf.int),
Balthasar Reuter (balthasar.reuter@ecmwf.int)

## License

Loki is distributed under the [Apache License 2.0](LICENSE). In applying this
licence, ECMWF does not waive the privileges and immunities granted to it by
virtue of its status as an intergovernmental organisation nor does it submit to
any jurisdiction.

## Installation

See [INSTALL.md](INSTALL.md).

## Documentation

Loki has comprehensive [documentation](https://sites.ecmwf.int/docs/loki) that
describes the API and how to use it to write custom transformations.  There are
also a number of Jupyter notebooks available in the
[example directory](https://github.com/ecmwf-ifs/loki/blob/main/example) that help
users get up to speed with the core functionality of the package.

## Contributing

Contributions to Loki are welcome. In order to do so, please open an issue where
a feature request or bug can be discussed. Then create a pull request with your
contribution and sign the
[contributors license agreement (CLA)](https://bol-claassistant.ecmwf.int/ecmwf-ifs/loki).
loki-ecmwf-0.3.6/example/0000775000175000017500000000000015167130205015374 5ustar  alastairalastairloki-ecmwf-0.3.6/example/02_working_with_the_ir.ipynb0000664000175000017500000004623615167130205023020 0ustar  alastairalastair{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Working with Loki's internal representation\n",
    "\n",
    "The objective of this notebook is to give an impression of how Loki's internal representation (IR) can be traversed, searched and manipulated using the provided visitor utilities.\n",
    "\n",
    "We are again going to work with the `phys_kernel_LITE_LOOP` routine. Let's start by parsing the source file and extracting the routine from it. Note that we can also access the routine directly by its name, even though it is wrapped inside a `Module` object, as we have seen in the previous notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Subroutine:: phys_kernel_LITE_LOOP"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import Sourcefile\n",
    "source = Sourcefile.from_file('src/phys_mod.F90')\n",
    "routine_lite_loop = source['phys_kernel_LITE_LOOP']\n",
    "routine_lite_loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going to manipulate this routine and want to try two different ways of doing that, so we start by creating a copy. That way, we don't change the original object in the subsequent steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE phys_kernel_LITE_LOOP (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "  INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: out1\n",
      "  \n",
      "  INTEGER(KIND=ip) :: i, k\n",
      "  DO k=1,dim2\n",
      "    DO i=i1,i2\n",
      "      out1(i, k) = (in1(i, k) + in2(i, k) + in3(i, k) + in4(i, k) + in5(i, k) + in6(i, k) + in7(i, k) + in8(i, k) + in9(i, k) +  &\n",
      "      & in10(i, k))*0.1\n",
      "      in1(i, k) = out1(i, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END SUBROUTINE phys_kernel_LITE_LOOP\n"
     ]
    }
   ],
   "source": [
    "routine = routine_lite_loop.clone()\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The routine body consists of two nested loops. What we want to try first is to change the order of the loops (i.e., have the `i` loop outermost and the `k` loop innermost) but leave the loop body untouched.\n",
    "\n",
    "For that, we first need to find the loops in the IR, which can be done using the [_FindNodes_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.find.html#loki.visitors.find.FindNodes) visitor. As argument to the constructor we provide the node type (or a tuple of multiple types) that we want to look for and call the `visit` method with the tree to search.\n",
    "The visitor traverses the IR and collects all matching nodes into a list that is returned. For our purposes we are interested in the [_Loop_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Loop) nodes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Loop:: k=1:dim2, Loop:: i=i1:i2]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import FindNodes, Loop\n",
    "loops = FindNodes(Loop).visit(routine.body)\n",
    "loops"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the visitor has found both loops. Next, we create a substitution map - essentially a dictionary that maps the original node to its replacement. To exchange the two loops, we clone the outer loop with the inner loop's body, and then replace the outer loop with a clone of the inner loop whose body is this new loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{Loop:: k=1:dim2: Loop:: i=i1:i2}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outer_loop, inner_loop = loops\n",
    "new_inner_loop = outer_loop.clone(body=inner_loop.body)\n",
    "loop_map = {outer_loop: inner_loop.clone(body=(new_inner_loop,))}\n",
    "loop_map"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the substitution map in place, we can call the [_Transformer_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.transform.html#loki.visitors.transform.Transformer). It takes the map as argument to the constructor and applies it to the control flow tree given to the `visit` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import Transformer\n",
    "routine.body = Transformer(loop_map).visit(routine.body)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result is the original routine with the exchanged loop order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE phys_kernel_LITE_LOOP (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "  INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: out1\n",
      "  \n",
      "  INTEGER(KIND=ip) :: i, k\n",
      "  DO i=i1,i2\n",
      "    DO k=1,dim2\n",
      "      out1(i, k) = (in1(i, k) + in2(i, k) + in3(i, k) + in4(i, k) + in5(i, k) + in6(i, k) + in7(i, k) + in8(i, k) + in9(i, k) +  &\n",
      "      & in10(i, k))*0.1\n",
      "      in1(i, k) = out1(i, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END SUBROUTINE phys_kernel_LITE_LOOP\n"
     ]
    }
   ],
   "source": [
    "reordered_loops = FindNodes(Loop).visit(routine.body)\n",
    "assert len(reordered_loops) == 2\n",
    "assert reordered_loops[0].variable == 'i' and reordered_loops[1].variable == 'k'\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Next, we want to start again with the original routine and this time keep the loop order as is but reverse the memory layout of all arrays. We start by creating another copy of the original routine and verify that it is indeed the original version without the above transformations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE phys_kernel_LITE_LOOP (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "  INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: out1\n",
      "  \n",
      "  INTEGER(KIND=ip) :: i, k\n",
      "  DO k=1,dim2\n",
      "    DO i=i1,i2\n",
      "      out1(i, k) = (in1(i, k) + in2(i, k) + in3(i, k) + in4(i, k) + in5(i, k) + in6(i, k) + in7(i, k) + in8(i, k) + in9(i, k) +  &\n",
      "      & in10(i, k))*0.1\n",
      "      in1(i, k) = out1(i, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END SUBROUTINE phys_kernel_LITE_LOOP\n"
     ]
    }
   ],
   "source": [
    "routine = routine_lite_loop.clone()\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This time, we want to modify variables instead of loops. Loki uses a two-level internal representation that separates expressions from control flow. This means that the IR we have worked with so far is in fact the control flow tree and, nested inside, there is a second tree level stored as properties of certain control flow nodes. For example, the loop bounds of the `Loop` node or the left and right hand side expressions in [_Assignment_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Assignment) nodes are such expression trees. The advantage of this is that it makes traversing the control flow tree a lot faster and allows recursing into expressions only when required.\n",
    "\n",
    "Since we are now looking for variables, we need to actually search the expression trees and therefore have to use a different visitor, [_FindVariables_](https://sites.ecmwf.int/docs/loki/main/loki.expression.expr_visitors.html#loki.expression.expr_visitors.FindVariables). Here, we are only interested in arrays and can further restrict ourselves to [_Array_](https://sites.ecmwf.int/docs/loki/main/loki.expression.symbols.html#loki.expression.symbols.Array) expression nodes. Once again, we build a substitution map, this time with the subscript `dimensions` of the arrays reversed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "in10(i, k) -> in10(k, i)\n",
      "in2(i, k) -> in2(k, i)\n",
      "in3(i, k) -> in3(k, i)\n",
      "in5(i, k) -> in5(k, i)\n",
      "in9(i, k) -> in9(k, i)\n",
      "in7(i, k) -> in7(k, i)\n",
      "out1(i, k) -> out1(k, i)\n",
      "in4(i, k) -> in4(k, i)\n",
      "in1(i, k) -> in1(k, i)\n",
      "in8(i, k) -> in8(k, i)\n",
      "in6(i, k) -> in6(k, i)\n"
     ]
    }
   ],
   "source": [
    "from loki import FindVariables, Array\n",
    "variable_map = {}\n",
    "for var in FindVariables().visit(routine.body):\n",
    "    if isinstance(var, Array) and var.dimensions:\n",
    "        variable_map[var] = var.clone(dimensions=var.dimensions[::-1])\n",
    "print('\\n'.join(f'{a!s} -> {b!s}' for a, b in variable_map.items()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just like we have a separate find utility for expression trees there is a separate transformer [_SubstituteExpressions_](https://sites.ecmwf.int/docs/loki/main/loki.expression.expr_visitors.html#loki.expression.expr_visitors.SubstituteExpressions). Applying this to the routine's body we obtain the following result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE phys_kernel_LITE_LOOP (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "  INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: out1\n",
      "  \n",
      "  INTEGER(KIND=ip) :: i, k\n",
      "  DO k=1,dim2\n",
      "    DO i=i1,i2\n",
      "      out1(k, i) = (in1(k, i) + in2(k, i) + in3(k, i) + in4(k, i) + in5(k, i) + in6(k, i) + in7(k, i) + in8(k, i) + in9(k, i) +  &\n",
      "      & in10(k, i))*0.1\n",
      "      in1(k, i) = out1(k, i)\n",
      "    END DO\n",
      "  END DO\n",
      "END SUBROUTINE phys_kernel_LITE_LOOP\n"
     ]
    }
   ],
   "source": [
    "from loki import SubstituteExpressions\n",
    "routine.body = SubstituteExpressions(variable_map).visit(routine.body)\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The routine's body is correctly modified, with the array subscript dimensions reversed, but the declarations are still unchanged. For that, we need to change the `shape` of the variables as well as the `dimensions` property of the variable nodes inside the declarations.\n",
    "\n",
    "There are two ways of achieving this: The first and easier way would be to modify the [_variables_ property](https://sites.ecmwf.int/docs/loki/main/loki.subroutine.html#loki.subroutine.Subroutine.variables) of the `Subroutine` object and update all array dimensions and shapes. This automatically recreates declarations for modified variables but inserts separate new declarations for each. Let's try this approach for a copy of the routine:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE phys_kernel_LITE_LOOP (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "  INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "  \n",
      "  INTEGER(KIND=ip) :: i, k\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in1(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in2(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in3(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in4(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in5(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in6(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in7(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in8(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in9(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: in10(1:dim2, 1:dim1)\n",
      "  REAL(KIND=lp), INTENT(INOUT) :: out1(1:dim2, 1:dim1)\n",
      "  DO k=1,dim2\n",
      "    DO i=i1,i2\n",
      "      out1(k, i) = (in1(k, i) + in2(k, i) + in3(k, i) + in4(k, i) + in5(k, i) + in6(k, i) + in7(k, i) + in8(k, i) + in9(k, i) +  &\n",
      "      & in10(k, i))*0.1\n",
      "      in1(k, i) = out1(k, i)\n",
      "    END DO\n",
      "  END DO\n",
      "END SUBROUTINE phys_kernel_LITE_LOOP\n"
     ]
    }
   ],
   "source": [
    "routine_variant1 = routine.clone()\n",
    "variables = []\n",
    "for var in routine_variant1.variables:\n",
    "    if isinstance(var, Array):\n",
    "        shape = var.shape[::-1]\n",
    "        variables += [var.clone(dimensions=shape, type=var.type.clone(shape=shape))]\n",
    "    else:\n",
    "        variables += [var]\n",
    "routine_variant1.variables = variables\n",
    "print(routine_variant1.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a further exercise, we also demonstrate a second approach that modifies the declarations directly instead of recreating them. For that, we search for all [_VariableDeclaration_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.VariableDeclaration) nodes in the routine's specification part (`spec`) and build a substitution map with updated declarations where necessary. This involves updating the list of variables declared in a `VariableDeclaration` node and making sure that only `Array` nodes are modified.\n",
    "\n",
    "Fortran allows array dimensions to be specified either using the `DIMENSION` attribute or as dimensions in brackets after the declared symbol's name (e.g., `var(dim1, dim2)`). Loki's default behaviour is the latter (as visible from the auto-generated declarations above). To accommodate both variants, `VariableDeclaration` nodes in Loki's IR have an optional `dimensions` property that produces the `DIMENSION` attribute syntax. Importantly, in both cases Loki also stores the `dimensions` on the declared variable nodes, so that they are always accessible in a uniform way.\n",
    "\n",
    "When building the substitution map for the declaration nodes, we honour both versions and adapt our behaviour accordingly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE phys_kernel_LITE_LOOP (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "  INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim2, 1:dim1) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim2, 1:dim1) :: out1\n",
      "  \n",
      "  INTEGER(KIND=ip) :: i, k\n",
      "  DO k=1,dim2\n",
      "    DO i=i1,i2\n",
      "      out1(k, i) = (in1(k, i) + in2(k, i) + in3(k, i) + in4(k, i) + in5(k, i) + in6(k, i) + in7(k, i) + in8(k, i) + in9(k, i) +  &\n",
      "      & in10(k, i))*0.1\n",
      "      in1(k, i) = out1(k, i)\n",
      "    END DO\n",
      "  END DO\n",
      "END SUBROUTINE phys_kernel_LITE_LOOP\n"
     ]
    }
   ],
   "source": [
    "from loki import VariableDeclaration\n",
    "decl_map = {}\n",
    "for decl in FindNodes(VariableDeclaration).visit(routine.spec):\n",
    "    if decl.dimensions:\n",
    "        shape = decl.dimensions[::-1]\n",
    "        symbols = [var.clone(dimensions=shape, type=var.type.clone(shape=shape)) for var in decl.symbols]\n",
    "        decl_map[decl] = decl.clone(dimensions=shape, symbols=symbols)\n",
    "    elif any(isinstance(var, Array) for var in decl.symbols):\n",
    "        symbols = []\n",
    "        for var in decl.symbols:\n",
    "            if isinstance(var, Array):\n",
    "                shape = var.shape[::-1]\n",
    "                symbols += [var.clone(dimensions=shape, type=var.type.clone(shape=shape))]\n",
    "            else:\n",
    "                symbols += [var]\n",
    "        decl_map[decl] = decl.clone(symbols=symbols)\n",
    "routine.spec = Transformer(decl_map).visit(routine.spec)\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, we have achieved the same result while retaining the compact notation for declarations. As the following checks confirm, both variants declare the arrays with reversed shapes and have identical bodies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import fgen\n",
    "fcode = fgen(routine.spec)\n",
    "assert '(1:dim2, 1:dim1)' in fcode\n",
    "assert '(1:dim1, 1:dim2)' not in fcode\n",
    "fcode = fgen(routine_variant1.spec)\n",
    "assert '(1:dim2, 1:dim1)' in fcode\n",
    "assert '(1:dim1, 1:dim2)' not in fcode\n",
    "assert fgen(routine.body) == fgen(routine_variant1.body)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For further details on how to work with Loki's internal representation, have a look at the [relevant section in the documentation](https://sites.ecmwf.int/docs/loki/main/visitors.html)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.8 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "49643b39beb1b0a7ebd0b57318d9385a5a724f398f0bc0540a61bbc4360c8e5d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
loki-ecmwf-0.3.6/example/src/0000775000175000017500000000000015167130205016163 5ustar  alastairalastairloki-ecmwf-0.3.6/example/src/phys_mod.F900000664000175000017500000002711115167130205020267 0ustar  alastairalastairmodule phys_mod

use :: iso_fortran_env

use omp_lib

implicit none

! Kind parameters: sp and dp are the 32- and 64-bit real kinds from
! iso_fortran_env; lp is the working ("local") real kind used by the kernels,
! chosen at preprocessing time (sp if FLOAT_SINGLE is defined, else dp).
integer, parameter :: sp = REAL32
integer, parameter :: dp = REAL64
#ifdef FLOAT_SINGLE
integer, parameter :: lp = sp     !! lp : "local" precision
#else
integer, parameter :: lp = dp     !! lp : "local" precision
#endif

! ip: 64-bit integer kind used for array extents and loop indices.
integer, parameter :: ip = INT64

! cst1, cst2: module constants used by phys_kernel_NASTY_EXPS.
! NOTE(review): 2.5 and 3.14 are default-kind real literals, so when lp is
! double precision cst2 holds the single-precision rounding of 3.14 rather
! than 3.14_lp -- confirm whether that is intended.
real(kind=lp) :: cst1 = 2.5, cst2 = 3.14
! nspecies: size of the per-point linear system in the LU_SOLVER kernels.
integer, parameter :: nspecies = 5

contains 

subroutine phys_kernel_LITE_LOOP(dim1,dim2,i1,i2, in1,in2,in3,in4,in5,in6,in7,in8,in9,in10, out1)
  ! Lightweight elementwise kernel: for each level k = 1..dim2 and each point
  ! i = i1..i2, out1(i,k) is set to 0.1 times the sum of in1..in10 at (i,k)
  ! and the result is copied back into in1(i,k).
  ! NOTE(review): 0.1 is a default-kind real literal, so for double-precision
  ! lp the factor is the single-precision rounding of 0.1 -- confirm whether
  ! 0.1_lp was intended.
  integer(kind=ip),intent(in) :: dim1, dim2, i1,i2
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: in1,in2,in3,in4,in5,in6,in7,in8,in9,in10
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: out1

  integer(kind=ip) :: i,k
  do k=1,dim2
    do i=i1,i2
      out1(i,k) = (in1(i,k) + in2(i,k) + in3(i,k) + in4(i,k) + in5(i,k) + &
 &                 in6(i,k) + in7(i,k) + in8(i,k) + in9(i,k) + in10(i,k)) * 0.1
      in1(i,k) = out1(i,k)
    end do
  end do
end subroutine phys_kernel_LITE_LOOP

subroutine phys_kernel_VERT_SEARCH(dim1,dim2,i1,i2, in1,in2,in3,in4,in5,in6,in7,in8,in9,in10, out1)
  ! Kernel with a data-dependent vertical loop bound.
  ! For each point i = i1..i2, kmax(i) is the first level at which in1(i,:)
  ! attains its maximum, considering only values greater than -1 (strict >
  ! keeps the first occurrence). Levels 1..kmax(i) receive 0.1 times the sum
  ! of in1..in10; levels kmax(i)+1..dim2 receive 0.3 times the sum of the
  ! products in1*...*in5 and in6*...*in10. Each result is copied back to in1.
  ! kmax starts at 0, meaning no level found: the first update loop is then
  ! empty and the second covers all levels 1..dim2. A negative sentinel would
  ! make the second loop start at level 0, outside the array bounds, for any
  ! point where no in1 value exceeds -1.
  integer(kind=ip),intent(in) :: dim1, dim2, i1,i2
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: in1,in2,in3,in4,in5,in6,in7,in8,in9,in10
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: out1

  integer(kind=ip) :: i,k
  real(kind=lp) :: temp(i1:i2)
  integer :: kmax(i1:i2)

  temp = -1.
  kmax = 0
  do k=1,dim2
    do i=i1,i2
      if (in1(i,k) > temp(i)) then
        temp(i) = in1(i,k)
        kmax(i) = k
      end if
    end do
  end do

  do i=i1,i2
    do k=1,kmax(i)
      out1(i,k) = (in1(i,k) + in2(i,k) + in3(i,k) + in4(i,k) + in5(i,k) + &
 &                 in6(i,k) + in7(i,k) + in8(i,k) + in9(i,k) + in10(i,k)) * 0.1
      in1(i,k) = out1(i,k)
    end do
  end do
  
  do i=i1,i2
    do k=kmax(i)+1,dim2
      out1(i,k) = (in1(i,k) * in2(i,k) * in3(i,k) * in4(i,k) * in5(i,k) + &
 &                 in6(i,k) * in7(i,k) * in8(i,k) * in9(i,k) * in10(i,k)) * 0.3
      in1(i,k) = out1(i,k)
    end do
  end do
end subroutine phys_kernel_VERT_SEARCH

subroutine phys_kernel_NASTY_EXPS(dim1,dim2,i1,i2, in1,in2,in3,in4,in5,in6,in7,in8,in9,in10, out1)
  ! Kernel dominated by exp/min evaluations: for each level k = 1..dim2 and
  ! point i = i1..i2, two scalar candidates are formed from in1..in10 and the
  ! module constants cst1 and cst2, each one passed through exp and limited
  ! with min against a second exp term. out1(i,k) takes the smaller of the
  ! two candidates and is copied back into in1(i,k).
  ! NOTE(review): (temp_s1 - cst2) and (temp_s2 - cst1) are used as divisors
  ! without a guard against zero.
  integer(kind=ip),intent(in) :: dim1, dim2, i1,i2
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: in1,in2,in3,in4,in5,in6,in7,in8,in9,in10
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: out1

  integer(kind=ip) :: i,k
  real(kind=lp) :: temp_s1, temp_s2

  do k=1,dim2
    do i=i1,i2
      temp_s1 = (in1(i,k) + in2(i,k) + in3(i,k) + in4(i,k) + in5(i,k) + &
 &                 in6(i,k) + in7(i,k) + in8(i,k) + in9(i,k) + in10(i,k)) * 0.1
      temp_s1 = min( exp( (temp_s1 - cst1) / (temp_s1 - cst2) ), exp(in4(i,k)-in5(i,k)) )

      temp_s2 = (in1(i,k) - in2(i,k) * in3(i,k) - in4(i,k) + (in5(i,k) - &
 &                 in6(i,k)*0.5) + (in7(i,k) - in8(i,k)*0.1) - in9(i,k) - in10(i,k)) * 0.2
      temp_s2 = min( exp( (temp_s2 - cst2) / (temp_s2 - cst1) ), exp(in6(i,k)+in7(i,k)) )

      if (temp_s1 < temp_s2) then
        out1(i,k) = temp_s1
      else 
        out1(i,k) = temp_s2
      end if

      in1(i,k) = out1(i,k)
    end do
  end do
end subroutine phys_kernel_NASTY_EXPS

subroutine phys_kernel_LU_SOLVER(dim1,dim2,i1,i2, in1,in2,in3,in4,in5,in6,in7,in8,in9,in10, out1)
  ! Per-level small linear solve: for each level k = 1..dim2, an
  ! nspecies x nspecies matrix is assembled in lu_lhs for every point
  ! i = i1..i2, together with a right-hand side built from in1..in5 plus in9.
  ! The system is solved with a non-pivoting LU factorisation followed by
  ! forward and backward substitution. out1(i,k) receives the sum of the
  ! solution components, which are also kept in out_lev_m_1 (per the in-code
  ! comment, for use at level k+1).
  ! Matrices are stored with the horizontal index first: lu_lhs(i,row,col).
  ! NOTE(review): jm, jn, dp and temp_hor1 appear unused -- confirm. The
  ! local array dp also shadows the module kind parameter dp in this routine.
  integer(kind=ip),intent(in) :: dim1, dim2, i1,i2
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: in1,in2,in3,in4,in5,in6,in7,in8,in9,in10
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: out1

  integer(kind=ip) :: i,k
  real(kind=lp),dimension(dim1,nspecies,nspecies) :: lu_lhs, lu_rhs_implicit
  integer(kind=ip) :: s1,s2,s3, jm,jn
  real(kind=lp) :: dp(i1:i2), temp_hor1(i1:i2)
  real(kind=lp) :: temp_out(i1:i2,nspecies), out_lev_m_1(i1:i2,nspecies)

  ! initialise for k=1
  out_lev_m_1 = 0.

  do k=1,dim2
 
    do i=i1,i2
      lu_rhs_implicit(i,:,:) = 0.5
      lu_lhs(i,:,:) = 0.1
    end do

    if (k 0.8) then
        lu_rhs_implicit(i,1,4) = lu_rhs_implicit(i,1,4) + in6(i,k) 
        lu_rhs_implicit(i,3,4) = lu_rhs_implicit(i,3,4) + in8(i,k) 
      end if
    end do

    ! set up lhs properly
    do s2=1,nspecies
      do s1=1,nspecies
        if (s1==s2) then
          do i=i1,i2
            lu_lhs(i,s1,s1) = lu_lhs(i,s1,s1) + sum(lu_rhs_implicit(i,:,s1))  ! diagonal term 
          end do
        else 
          do i=i1,i2
            lu_lhs(i,s1,s2) = - lu_rhs_implicit(i,s1,s2) ! off-diagonal
          end do
        end if
      end do ! s1
    end do ! s2

    ! set the rhs hopefully plausibly 
    do s2=1,nspecies 
      do i=i1,i2
        if (s2==1) then
          temp_out(i,s2)= in1(i,k) + in9(i,k) ! zexplicit
        elseif (s2==2) then
          temp_out(i,s2)= in2(i,k) + in9(i,k) ! zexplicit
        elseif (s2==3) then
          temp_out(i,s2)= in3(i,k) + in9(i,k) ! zexplicit
        elseif (s2==4) then
          temp_out(i,s2)= in4(i,k) + in9(i,k) ! zexplicit
        elseif (s2==5) then
          temp_out(i,s2)= in5(i,k) + in9(i,k) ! zexplicit
        end if
      enddo
    enddo

    ! following factorization code taken straight from CLOUDSC
    ! Non pivoting recursive factorization 
    do s2 = 1, nspecies-1
      do s1 = s2+1,nspecies
        do i=i1,i2
          lu_lhs(i,s1,s2)=lu_lhs(i,s1,s2) / lu_lhs(i,s2,s2)
        enddo
        do s3=s2+1,nspecies
          do i=i1,i2
            lu_lhs(i,s1,s3)=lu_lhs(i,s1,s3)-lu_lhs(i,s1,s2)*lu_lhs(i,s2,s3)
          enddo ! do i
        enddo ! do s3
      enddo ! do s1
    enddo ! do s2

    ! backsubstitution 
    !  step 1 
    do s2=2,nspecies
      do s1 = 1,s2-1
        do i=i1,i2
          temp_out(i,s2)=temp_out(i,s2)-lu_lhs(i,s2,s1) * temp_out(i,s1)
        end do !  i
      end do ! s1
    end do ! s2
    !  step 2
    do i=i1,i2
      temp_out(i,nspecies)=temp_out(i,nspecies)/lu_lhs(i,nspecies,nspecies)
    end do !  i
    do s2=nspecies-1,1,-1
      do s1 = s2+1,nspecies
        do i=i1,i2
          temp_out(i,s2)=temp_out(i,s2)-lu_lhs(i,s2,s1) * temp_out(i,s1)
        end do !  i
      end do ! s1
      do i=i1,i2
        temp_out(i,s2)=temp_out(i,s2)/lu_lhs(i,s2,s2)
      end do !  i
    enddo ! s2

    ! extract solution values into output
    do i=i1,i2
      out1(i,k) = sum(temp_out(i,:))
    end do
    ! save k level values for use at k+1
    do s1=1,nspecies
      do i=i1,i2
        out_lev_m_1(i,s1) = temp_out(i,s1)
      end do ! i
    end do ! s1

  end do !! do k

end subroutine phys_kernel_LU_SOLVER

subroutine phys_kernel_LU_SOLVER_COMPACT(dim1,dim2,i1,i2, in1,in2,in3,in4,in5,in6,in7,in8,in9,in10, out1)
  ! To satisfy my curiosity, flip the allocation of the matrix to be compact for each grid point.
  ! Per-level assembly, non-pivoting LU factorisation and substitution, as in
  ! phys_kernel_LU_SOLVER, but the matrices are stored with the horizontal
  ! index last: lu_lhs(row,col,i). In Fortran column-major order, this keeps
  ! the nspecies x nspecies entries of each point contiguous in memory.
  ! out1(i,k) receives the sum of the solution components, which are also
  ! kept in out_lev_m_1.
  ! NOTE(review): jm, jn, dp and temp_hor1 appear unused -- confirm. The
  ! local array dp also shadows the module kind parameter dp in this routine.
  integer(kind=ip),intent(in) :: dim1, dim2, i1,i2
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: in1,in2,in3,in4,in5,in6,in7,in8,in9,in10
  real(kind=lp),dimension(1:dim1,1:dim2),intent(inout) :: out1

  integer(kind=ip) :: i,k
  ! Invert the matrix allocation, so that parallel dim is outermost
  real(kind=lp),dimension(nspecies,nspecies,dim1) :: lu_lhs, lu_rhs_implicit
  integer(kind=ip) :: s1,s2,s3, jm,jn
  real(kind=lp) :: dp(i1:i2), temp_hor1(i1:i2)
  real(kind=lp) :: temp_out(i1:i2,nspecies), out_lev_m_1(i1:i2,nspecies)

  ! initialise for k=1
  out_lev_m_1 = 0.

  do k=1,dim2
 
    do i=i1,i2
      lu_rhs_implicit(:,:,i) = 0.5
      lu_lhs(:,:,i) = 0.1
    end do

    if (k 0.8) then
        lu_rhs_implicit(1,4,i) = lu_rhs_implicit(1,4,i) + in6(i,k) 
        lu_rhs_implicit(3,4,i) = lu_rhs_implicit(3,4,i) + in8(i,k) 
      end if
    end do

    ! set up lhs properly
    do s2=1,nspecies
      do s1=1,nspecies
        if (s1==s2) then
          do i=i1,i2
            lu_lhs(s1,s1,i) = lu_lhs(s1,s1,i) + sum(lu_rhs_implicit(:,s1,i))  ! diagonal term 
          end do
        else 
          do i=i1,i2
            lu_lhs(s1,s2,i) = - lu_rhs_implicit(s1,s2,i) ! off-diagonal
          end do
        end if
      end do ! s1
    end do ! s2

    ! set the rhs hopefully plausibly 
    do s2=1,nspecies 
      do i=i1,i2
        if (s2==1) then
          temp_out(i,s2)= in1(i,k) + in9(i,k) ! zexplicit
        elseif (s2==2) then
          temp_out(i,s2)= in2(i,k) + in9(i,k) ! zexplicit
        elseif (s2==3) then
          temp_out(i,s2)= in3(i,k) + in9(i,k) ! zexplicit
        elseif (s2==4) then
          temp_out(i,s2)= in4(i,k) + in9(i,k) ! zexplicit
        elseif (s2==5) then
          temp_out(i,s2)= in5(i,k) + in9(i,k) ! zexplicit
        end if
      enddo
    enddo

    ! following factorization code taken straight from CLOUDSC
    ! Non pivoting recursive factorization 
    do s2 = 1, nspecies-1
      do s1 = s2+1,nspecies
        do i=i1,i2
          lu_lhs(s1,s2,i)=lu_lhs(s1,s2,i) / lu_lhs(s2,s2,i)
        enddo
        do s3=s2+1,nspecies
          do i=i1,i2
            lu_lhs(s1,s3,i)=lu_lhs(s1,s3,i)-lu_lhs(s1,s2,i)*lu_lhs(s2,s3,i)
          enddo ! do i
        enddo ! do s3
      enddo ! do s1
    enddo ! do s2

    ! backsubstitution 
    !  step 1 
    do s2=2,nspecies
      do s1 = 1,s2-1
        do i=i1,i2
          temp_out(i,s2)=temp_out(i,s2)-lu_lhs(s2,s1,i) * temp_out(i,s1)
        end do !  i
      end do ! s1
    end do ! s2
    !  step 2
    do i=i1,i2
      temp_out(i,nspecies)=temp_out(i,nspecies)/lu_lhs(nspecies,nspecies,i)
    end do !  i
    do s2=nspecies-1,1,-1
      do s1 = s2+1,nspecies
        do i=i1,i2
          temp_out(i,s2)=temp_out(i,s2)-lu_lhs(s2,s1,i) * temp_out(i,s1)
        end do !  i
      end do ! s1
      do i=i1,i2
        temp_out(i,s2)=temp_out(i,s2)/lu_lhs(s2,s2,i)
      end do !  i
    enddo ! s2

    ! extract solution values into output
    do i=i1,i2
      out1(i,k) = sum(temp_out(i,:))
    end do
    ! save k level values for use at k+1
    do s1=1,nspecies
      do i=i1,i2
        out_lev_m_1(i,s1) = temp_out(i,s1)
      end do ! i
    end do ! s1

  end do !! do k

end subroutine phys_kernel_LU_SOLVER_COMPACT

end module phys_mod
loki-ecmwf-0.3.6/example/src/loop_fuse.F900000664000175000017500000000467415167130205020451 0ustar  alastairalastairsubroutine loop_fuse(n,var_in,var_out)
! Loop-fusion fixture: for every (j,k) there are two consecutive i-loops over
! 1..n. The first copies var_in into var_out and the second doubles var_out,
! so the routine computes var_out = 2*var_in. The two i-loops have identical
! bounds and are candidates for fusion.
use parkind1, only : jpim,jprb
implicit none

integer(kind=jpim),intent(in) :: n
real(kind=jprb),   intent(in ) :: var_in (n,n,n)
real(kind=jprb),   intent(out) :: var_out(n,n,n)
integer(kind=jpim) :: i,j,k

do k=1,n
  do j=1,n
    do i=1,n
      var_out(i,j,k) = var_in(i,j,k)
    enddo
    do i=1,n   
      var_out(i,j,k) = 2._JPRB*var_out(i,j,k)
    enddo
  enddo
enddo

end subroutine loop_fuse

subroutine loop_fuse_v1(n,var_in,var_out)
! Loop-fusion fixture variant: within each k iteration, a call to some_kernel
! separates two (j,i) loop nests. Before the call, var_out is copied from
! var_in and doubled; after it, 1 is added and the result is doubled again.
! Only the i-loops on the same side of the call are candidates for fusion.
! some_kernel is not declared in this routine (implicit-interface call).
use parkind1, only : jpim,jprb
implicit none

integer(kind=jpim),intent(in) :: n
real(kind=jprb),   intent(in ) :: var_in (n,n,n)
real(kind=jprb),   intent(out) :: var_out(n,n,n)
integer(kind=jpim) :: i,j,k

do k=1,n
  do j=1,n
    do i=1,n
      var_out(i,j,k) = var_in(i,j,k)
    enddo
    do i=1,n   
      var_out(i,j,k) = 2._JPRB*var_out(i,j,k)
    enddo
  enddo

  call some_kernel(n,var_out(1,1,k))

  do j=1,n
    do i=1,n
      var_out(i,j,k) = var_out(i,j,k) + 1._JPRB
    enddo
    do i=1,n   
      var_out(i,j,k) = 2._JPRB*var_out(i,j,k)
    enddo
  enddo
enddo

end subroutine loop_fuse_v1

subroutine loop_fuse_v2(n,var_in,var_out)
! Loop-fusion fixture variant: like loop_fuse_v1, but the call to some_kernel
! sits inside the j-loop (one call per (j,k)). This splits the four i-loops
! of each (j,k) body into a copy/double pair before the call and an add/double
! pair after it.
! some_kernel is not declared in this routine (implicit-interface call).
use parkind1, only : jpim,jprb
implicit none

integer(kind=jpim),intent(in) :: n
real(kind=jprb),   intent(in ) :: var_in (n,n,n)
real(kind=jprb),   intent(out) :: var_out(n,n,n)
integer(kind=jpim) :: i,j,k

do k=1,n
  do j=1,n
    do i=1,n
      var_out(i,j,k) = var_in(i,j,k)
    enddo
    do i=1,n   
      var_out(i,j,k) = 2._JPRB*var_out(i,j,k)
    enddo

    call some_kernel(n,var_out(1,j,k))

    do i=1,n
      var_out(i,j,k) = var_out(i,j,k) + 1._JPRB
    enddo
    do i=1,n   
      var_out(i,j,k) = 2._JPRB*var_out(i,j,k)
    enddo
  enddo
enddo

end subroutine loop_fuse_v2

subroutine loop_fuse_pragma(n,var_in,var_out)
! Loop-fusion fixture with explicit annotations: same computation as
! loop_fuse_v2, but each i-loop is preceded by a loki loop-fusion pragma
! naming its fusion group. g1 marks the two loops before the some_kernel call
! and g2 the two loops after it. Each pragma must stay directly above the loop
! it annotates.
! some_kernel is not declared in this routine (implicit-interface call).
use parkind1, only : jpim,jprb
implicit none

integer(kind=jpim),intent(in) :: n
real(kind=jprb),   intent(in ) :: var_in (n,n,n)
real(kind=jprb),   intent(out) :: var_out(n,n,n)
integer(kind=jpim) :: i,j,k

do k=1,n
  do j=1,n

    !$loki loop-fusion group(g1)
    do i=1,n
      var_out(i,j,k) = var_in(i,j,k)
    enddo
    !$loki loop-fusion group(g1)
    do i=1,n   
      var_out(i,j,k) = 2._JPRB*var_out(i,j,k)
    enddo

    call some_kernel(n,var_out(1,j,k))

    !$loki loop-fusion group(g2)
    do i=1,n
      var_out(i,j,k) = var_out(i,j,k) + 1._JPRB
    enddo
    !$loki loop-fusion group(g2)
    do i=1,n   
      var_out(i,j,k) = 2._JPRB*var_out(i,j,k)
    enddo
    
  enddo
enddo

end subroutine loop_fuse_pragma
loki-ecmwf-0.3.6/example/src/intent_test.F900000664000175000017500000000356415167130205021013 0ustar  alastairalastair
module kernel_mod
  ! Provides some_kernel, a routine with no executable statements whose dummy
  ! arguments only carry INTENT declarations (presumably a fixture for
  ! intent analysis -- confirm).
  use parkind1, only : jpim,jprb
  implicit none
contains
  subroutine some_kernel(n,vout,var_out,var_in,var_inout,b,l,h,y)
  ! Empty body: vout, var_out, var_inout and y are declared intent(inout)
  ! but are never assigned here. y is assumed-shape, so callers need the
  ! explicit interface that this module provides.
  
  integer(kind=jpim),intent(in)  :: n,l,b
  integer(kind=jpim),intent(in)  :: h
  real(kind=jprb),   intent(in )   ::   var_in   (n)
  real(kind=jprb),   intent(inout) ::   var_out  (n)
  real(kind=jprb),   intent(inout) ::   var_inout(n)
  real(kind=jprb),   intent(inout) ::   vout(n)
  real(kind=jprb),   intent(inout) ::      y(:)
  
  end subroutine some_kernel
end module kernel_mod

subroutine intent_test(m,n,var_in,var_out,var_inout,tendency_loc)
! Fixture that exercises argument intents through nested ASSOCIATE blocks,
! a pointer alias (vout) and calls to some_kernel from kernel_mod.
! NOTE(review): several constructs look deliberate for analysis testing
! rather than production code -- confirm before changing them:
!  * vout is an array POINTER declared with explicit shape vout(n); standard
!    Fortran requires a deferred shape, i.e. vout(:).
!  * state_type is neither declared nor imported in this routine.
!  * nclv is imported but not used; locals h and l are never referenced
!    (h= and l= in the calls are keyword arguments of some_kernel), and x
!    is only allocated and deallocated.
!  * The first some_kernel call passes vout as both the vout and the var_out
!    actual argument, so two intent(inout) dummies alias each other.
!  * Inside the outer associate on tendency_loc, the name vout refers to the
!    associated section, not to the pointer declared below.
use parkind1, only : jpim,jprb
use kernel_mod, only: some_kernel
use yoecldp, only : nclv
implicit none

integer(kind=jpim),intent(in) :: m,n
integer(kind=jpim) :: i,j,k,h,l
real(kind=jprb),   intent(in )        :: var_in   (n,n,n)
real(kind=jprb),   target,intent(out) :: var_out  (n,n,n)
real(kind=jprb),   intent(inout)      :: var_inout(n,n,n)
real(kind=jprb), allocatable :: x(:),y(:)
real(kind=jprb), pointer :: vout(n)
type(state_type), intent (out) :: tendency_loc

allocate(x(n))
associate(mtmp=>m)
allocate(y(mtmp))
end associate

associate(mtmp=>n)
do k=1,mtmp
  do j=1,mtmp
    do i=1,mtmp  
      var_out(i,j,k) = 2._jprb
    enddo
 
    associate(mbuf=>mtmp) 
    var_out(m:mbuf,j,k) = var_in(m:mbuf,j,k)+var_inout(m:mbuf,j,k)+var_out(m:mbuf,j,k)
    end associate

    vout=>var_out(:,j,k)

    associate(vin=>mtmp)
    call some_kernel(vin,vout,vout,var_in(:,j,k),var_inout(:,j,k),1,h=vin,l=5,y=y)
    end associate

    nullify(vout)

    associate(vout=>tendency_loc%cld(:,j,k))

    associate(vin=>var_in(:,j,k))
    call some_kernel(mtmp,vout,var_out(:,j,k),vin,var_inout(:,j,k),1,h=mtmp,l=5,y=y)
    end associate

    end associate

    do i=1,mtmp
      var_inout(i,j,k) = var_out(i,j,k)
    enddo
  enddo
enddo
end associate

deallocate(x)
deallocate(y)

end subroutine intent_test
loki-ecmwf-0.3.6/example/src/phys_driver.F900000664000175000017500000000762415167130205021012 0ustar  alastairalastairprogram phys_layout

use omp_lib

use phys_mod, only: ip, lp, &
 &                  phys_kernel_LITE_LOOP, phys_kernel_VERT_SEARCH, phys_kernel_NASTY_EXPS, &
 &                  phys_kernel_LU_SOLVER, phys_kernel_LU_SOLVER_COMPACT

implicit none

! Benchmark driver: runs one phys_mod kernel (selected at preprocessing time
! through the PHYS_KERNEL macro below) over NPROMA-blocked arrays, then
! reports a result checksum, the wall time and a bandwidth estimate.
! Usage: phys_layout [NPROMA [TOTAL]]

!! arrays for passing to subroutine
real(kind=lp),dimension(:,:,:),allocatable :: arr1, arr2, arr3, arr4, arr5, arr6, arr7, arr8, arr9, arr10
real(kind=lp),dimension(:,:,:),allocatable :: out1

integer(kind=ip) :: nproma, npoints, nlev, nblocks, ntotal, num_main_loops
integer(kind=ip) :: nb, nml, inp, i, i1, i2, real_bytes

integer :: i_seed
integer, dimension(:), allocatable :: a_seed

integer :: iargs, lenarg
character(len=20) :: clarg

! omp_get_wtime returns double precision. The timestamps are kept in double
! precision even when lp is single precision, so timer resolution is not lost.
double precision :: time1, time2

! Constants
nlev    = 137
num_main_loops = 8

! Defaults
nproma = 32
ntotal = 32*4096

! Optional overrides: the first argument is NPROMA, the second is TOTAL.
! The whole buffer is read because lenarg is the full argument length, which
! can exceed len(clarg) and would make clarg(1:lenarg) go out of bounds.
iargs = command_argument_count()
if (iargs >= 1) then
   call get_command_argument(1, clarg, lenarg)
   read(clarg,*) nproma

   if (iargs >= 2) then
      call get_command_argument(2, clarg, lenarg)
      read(clarg,*) ntotal
   endif
endif

! Both values are used as divisors and loop steps below
if (nproma < 1 .or. ntotal < 1) then
   error stop 'NPROMA and TOTAL must be positive'
endif

npoints = nproma  ! Can be unlocked later...
! Ceiling division: exactly one block for each chunk visited by the
! initialisation and kernel loops below (inp = 1, ntotal, nproma). No block
! is therefore left uninitialised while still being included in the sums of
! the result check.
nblocks = (ntotal + nproma - 1) / nproma
write(*, '(A8,I8,A8,I8,A8,I8,A)') 'NPROMA ', nproma, "TOTAL ", ntotal, "NBLOCKS", nblocks

! Bytes per real(kind=lp) element, derived from the kind actually compiled
! (phys_mod selects lp through FLOAT_SINGLE); used for the bandwidth estimate.
real_bytes = storage_size(1.0_lp) / 8

allocate(arr1 (npoints,nlev,nblocks))
allocate(arr2 (npoints,nlev,nblocks))
allocate(arr3 (npoints,nlev,nblocks))
allocate(arr4 (npoints,nlev,nblocks))
allocate(arr5 (npoints,nlev,nblocks))
allocate(arr6 (npoints,nlev,nblocks))
allocate(arr7 (npoints,nlev,nblocks))
allocate(arr8 (npoints,nlev,nblocks))
allocate(arr9 (npoints,nlev,nblocks))
allocate(arr10(npoints,nlev,nblocks))
allocate(out1 (npoints,nlev,nblocks))

! ----- Set up random seed portably -----
call random_seed(size=i_seed)
allocate(a_seed(1:i_seed))
call random_seed(get=a_seed)
a_seed = [ (i,i=1,i_seed) ]
call random_seed(put=a_seed)
deallocate(a_seed)

! Initialize values in parallel region to avoid accidental
! NUMA issues in single-socket mode on dual-socket machines.
! NOTE(review): inside MASTER only one thread touches the arrays, so
! first-touch page placement is not spread across threads -- confirm intent.
!$omp parallel

!$omp master
do inp = 1, ntotal, nproma
  nb = (inp-1) / nproma + 1
  call random_number(arr1(:,:,nb))
  call random_number(arr2(:,:,nb))
  call random_number(arr3(:,:,nb))
  call random_number(arr4(:,:,nb))
  call random_number(arr5(:,:,nb))
  call random_number(arr6(:,:,nb))
  call random_number(arr7(:,:,nb))
  call random_number(arr8(:,:,nb))
  call random_number(arr9(:,:,nb))
  call random_number(arr10(:,:,nb))
end do
!$omp end master

! Close the initialisation region before the timed section, which opens its
! own parallel region on every iteration.
!$omp end parallel

! Pre-processor fudging to inject the relevant kernel call
#if defined(LITE_LOOP)
#define PHYS_KERNEL phys_kernel_LITE_LOOP
#elif defined(VERT_SEARCH)
#define PHYS_KERNEL phys_kernel_VERT_SEARCH
#elif defined(NASTY_EXPS)
#define PHYS_KERNEL phys_kernel_NASTY_EXPS
#elif defined(LU_SOLVER)
#define PHYS_KERNEL phys_kernel_LU_SOLVER
#elif defined(LU_SOLVER_COMPACT)
#define PHYS_KERNEL phys_kernel_LU_SOLVER_COMPACT
#endif

time1 = omp_get_wtime()

do nml=1,num_main_loops

!$OMP PARALLEL DEFAULT(SHARED), PRIVATE(nb,inp,i1,i2)

! The classic BLOCKED-NPROMA structure from the IFS

!$omp do schedule(static)
do inp = 1, ntotal, nproma
  nb = (inp-1) / nproma + 1
  call PHYS_KERNEL( nproma, nlev, 1_ip, nproma, arr1(:,:,nb), arr2(:,:,nb), arr3(:,:,nb), &
 &                  arr4(:,:,nb), arr5(:,:,nb), &
 &                  arr6(:,:,nb), arr7 (:,:,nb), arr8(:,:,nb), &
 &                  arr9(:,:,nb), arr10(:,:,nb), out1(:,:,nb) )
end do
!$omp end do

!$OMP END PARALLEL

end do ! outer nml loop

time2 = omp_get_wtime()


write(*, '(A,3F12.6)') 'Result check : ', arr1(1,1,1), &
     sum(arr1) / real(npoints*nlev*nblocks), sum(out1) / real(npoints*nlev*nblocks)
write(*, '(A,F8.4)') 'Time for kernel call : ', time2 - time1
write(*, '(A,F12.2,A,F12.2,A)') 'Bandwidth estimate : ', num_main_loops*11*npoints*nlev*nblocks*real_bytes*1.e-6, &
     ' MB transferred: ', num_main_loops*11*npoints*nlev*nblocks*real_bytes*1.e-6 / (time2-time1), ' MB/s '

! Release the benchmark arrays
deallocate(arr1, arr2, arr3, arr4, arr5, arr6, arr7, arr8, arr9, arr10)
deallocate(out1)

end program phys_layout
loki-ecmwf-0.3.6/example/README.md0000664000175000017500000000116015167130205016651 0ustar  alastairalastairUsing the notebooks
===================

The notebooks contain many `import` statements for loading various Loki modules. To ensure that the interpreter can locate the relevant modules, the Jupyter notebook server should be launched from a terminal in which the `loki_env` virtual environment has been activated:

```bash
$ source /loki-activate
(loki_env) $ jupyter notebook
```

The interpreter used by the notebook server can be checked using the following command:

```bash
(loki_env) $ jupyter kernelspec list
Available kernels:
  python3    /loki_env/share/jupyter/kernels/python3
```
loki-ecmwf-0.3.6/example/01_reading_and_writing_files.ipynb0000664000175000017500000007006015167130205024122 0ustar  alastairalastair{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "40786f36",
   "metadata": {},
   "source": [
    "# Reading and writing files with Loki\n",
    "\n",
    "This is the first introductory notebook on how to work with Loki. The intention is to give an overview of how Fortran files can be read into Loki's internal representation, so that their content can be inspected and transformed, and how Fortran source code can be generated again. It also includes a short peek at the control flow representation; details are discussed in other notebooks.\n",
    "\n",
    "Let's start by parsing the file `src/phys_mod.F90` from the `example` directory.\n",
    "Loki uses a [_Sourcefile_](https://sites.ecmwf.int/docs/loki/main/loki.sourcefile.html#module-loki.sourcefile) object to represent an entire source file, which can contain modules or subroutines. To initialize the object with the content of a file on disc, we use the `from_file` class method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fa7c571e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Loki::Sourcefile] Constructed from src/phys_mod.F90 in 7.46s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import Sourcefile\n",
    "source = Sourcefile.from_file('src/phys_mod.F90', preprocess=True)\n",
    "source"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "836efe96",
   "metadata": {},
   "source": [
    "Let's examine the content of the source file by looking at the modules and subroutines contained in that file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1c262254",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Modules: (Module:: phys_mod,)\n",
      "Subroutines: ()\n"
     ]
    }
   ],
   "source": [
    "print(f\"Modules: {source.modules}\")\n",
    "print(f\"Subroutines: {source.subroutines}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0a6d4aa",
   "metadata": {},
   "source": [
    "We can see from the above that `source` contains one module by the name \"phys_mod\" and no free subroutines.\n",
    "We can access modules either via their index in the `modules` property (i.e., `source.modules[0]`) or using a subscript operator with their name directly on the `Sourcefile` object:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5c7ba4af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Module:: phys_mod"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "phys_mod = source['phys_mod']\n",
    "phys_mod"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c4518e",
   "metadata": {},
   "source": [
    "Fortran modules are represented as [_Module_](https://sites.ecmwf.int/docs/loki/main/loki.module.html#loki.module.Module) objects in Loki. They consist of a specification part and may contain, e.g., subroutines. Let's examine this object further:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b98dc7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spec: Section::\n",
      "Subroutines: (Subroutine:: phys_kernel_LITE_LOOP, Subroutine:: phys_kernel_VERT_SEARCH, Subroutine:: phys_kernel_NASTY_EXPS, Subroutine:: phys_kernel_LU_SOLVER, Subroutine:: phys_kernel_LU_SOLVER_COMPACT)\n"
     ]
    }
   ],
   "source": [
    "print(f\"Spec: {phys_mod.spec}\")\n",
    "print(f\"Subroutines: {phys_mod.subroutines}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10c81ec9",
   "metadata": {},
   "source": [
    "The specification part consists of a [_Section_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Section) node, which acts as the root node of Loki's control flow tree. At this point, it may be useful to learn more about Loki's internal representation by reading the [relevant part](https://sites.ecmwf.int/docs/loki/main/internal_representation.html) of the documentation. But for the objectives of this notebook we can also carry on and treat them as a black box for now.\n",
    "\n",
    "To get an impression of what the IR of the specification part looks like, we can call `view()` on any of the nodes to print a representation of this node and the tree below it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4a6159d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "   ()>\n",
      "  \n",
      "   ()>\n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  "
     ]
    }
   ],
   "source": [
    "phys_mod.spec.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7fa4ba7-a0f6-4f10-a47f-d508a121653d",
   "metadata": {},
   "source": [
    "Alternatively, if `graphviz` is available, we can call `ir_graph()` on any of the nodes to view a graph representation of this node and the tree below it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6942fbb4-113c-466d-be0e-4fce35d837ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Loki::Graph Visualization] Created graph visualization in 0.01s\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "%3\n",
       "\n",
       "\n",
       "\n",
       "0\n",
       "\n",
       "<Section::>\n",
       "\n",
       "\n",
       "\n",
       "1\n",
       "\n",
       "<Import:: iso_fortran_env => ()>\n",
       "\n",
       "\n",
       "\n",
       "0->1\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "2\n",
       "\n",
       "<Import:: omp_lib => ()>\n",
       "\n",
       "\n",
       "\n",
       "0->2\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "3\n",
       "\n",
       "<Intrinsic:: IMPLICIT NONE>\n",
       "\n",
       "\n",
       "\n",
       "0->3\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "4\n",
       "\n",
       "<VariableDeclaration:: sp>\n",
       "\n",
       "\n",
       "\n",
       "0->4\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "5\n",
       "\n",
       "<VariableDeclaration:: dp>\n",
       "\n",
       "\n",
       "\n",
       "0->5\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "6\n",
       "\n",
       "<VariableDeclaration:: lp>\n",
       "\n",
       "\n",
       "\n",
       "0->6\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "7\n",
       "\n",
       "<VariableDeclaration:: ip>\n",
       "\n",
       "\n",
       "\n",
       "0->7\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "8\n",
       "\n",
       "<VariableDeclaration:: cst1, cst2>\n",
       "\n",
       "\n",
       "\n",
       "0->8\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "9\n",
       "\n",
       "<VariableDeclaration:: nspecies>\n",
       "\n",
       "\n",
       "\n",
       "0->9\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph = None\n",
    "try:\n",
    "    graph = phys_mod.spec.ir_graph()\n",
    "except ImportError:\n",
    "    print(\"Install graphviz if you want to view the graph representation!\")\n",
    "graph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80f46c51",
   "metadata": {},
   "source": [
    "We can see a number of (empty) [comments](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Comment) - which are simply empty lines and retained to be able to produce Fortran code with a formatting similar to the original source. Since [comments](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Comment) might introduce additional noise, they are ignored by default in the graph representation. Other than that, we also have some [_Import_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Import) statements, [preprocessor directives](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.PreprocessorDirective) and [declarations](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Declaration).\n",
    "\n",
    "We can also convert this representation of the specification part back into a Fortran representation using the Fortran backend via [_fgen_](https://sites.ecmwf.int/docs/loki/main/loki.backend.fgen.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d0ccd9c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "USE iso_fortran_env\n",
      "\n",
      "USE omp_lib\n",
      "\n",
      "IMPLICIT NONE\n",
      "\n",
      "INTEGER, PARAMETER :: sp = REAL32\n",
      "INTEGER, PARAMETER :: dp = REAL64\n",
      "\n",
      "\n",
      "\n",
      "INTEGER, PARAMETER :: lp = dp!! lp : \"local\" precision\n",
      "\n",
      "\n",
      "INTEGER, PARAMETER :: ip = INT64\n",
      "\n",
      "REAL(KIND=lp) :: cst1 = 2.5, cst2 = 3.14\n",
      "INTEGER, PARAMETER :: nspecies = 5\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from loki import fgen\n",
    "print(fgen(phys_mod.spec))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e00f560e",
   "metadata": {},
   "source": [
    "Comparing the Fortran code to the internal representation above makes it easy to identify the one-to-one correspondence between IR nodes and statements in the original source code.\n",
    "\n",
    "Let's pick out one of the kernel loops next:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d06624b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Subroutine:: phys_kernel_LITE_LOOP"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lite_loop = phys_mod['phys_kernel_LITE_LOOP']\n",
    "lite_loop"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aac3d17",
   "metadata": {},
   "source": [
    "Subroutines and functions are represented as a [_Subroutine_](https://sites.ecmwf.int/docs/loki/main/loki.subroutine.html#loki.subroutine.Subroutine) object. This allows us, for example, to inspect the names of the dummy arguments expected by this routine:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7e66ac24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['dim1',\n",
       " 'dim2',\n",
       " 'i1',\n",
       " 'i2',\n",
       " 'in1',\n",
       " 'in2',\n",
       " 'in3',\n",
       " 'in4',\n",
       " 'in5',\n",
       " 'in6',\n",
       " 'in7',\n",
       " 'in8',\n",
       " 'in9',\n",
       " 'in10',\n",
       " 'out1']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lite_loop.argnames"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb0b6061",
   "metadata": {},
   "source": [
    "Furthermore, every subroutine contains a specification part and a body part (either of which can, in principle, be empty):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4a1850b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  "
     ]
    }
   ],
   "source": [
    "lite_loop.spec.view()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6dad7303",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "  \n",
      "    \n",
      "      \n",
      "      "
     ]
    }
   ],
   "source": [
    "lite_loop.body.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67ad8799",
   "metadata": {},
   "source": [
    "As we can see from the above, this kernel accepts a large number of arguments and consists essentially of two nested loops. Instead of viewing the abstract representation, we can also regenerate Fortran source code, either by calling `fgen` on individual parts or on the entire `Subroutine` object, or, as in this case, by using a convenience API offered by the object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e5aba927",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE phys_kernel_LITE_LOOP (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "  INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "  REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: out1\n",
      "  \n",
      "  INTEGER(KIND=ip) :: i, k\n",
      "  DO k=1,dim2\n",
      "    DO i=i1,i2\n",
      "      out1(i, k) = (in1(i, k) + in2(i, k) + in3(i, k) + in4(i, k) + in5(i, k) + in6(i, k) + in7(i, k) + in8(i, k) + in9(i, k) +  &\n",
      "      & in10(i, k))*0.1\n",
      "      in1(i, k) = out1(i, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END SUBROUTINE phys_kernel_LITE_LOOP\n"
     ]
    }
   ],
   "source": [
    "print(lite_loop.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "773a5f1f",
   "metadata": {},
   "source": [
    "In this notebook, we will not go into detail on how to actually modify the control flow tree of this routine. Instead, we will extract this routine from the module and put it into a standalone module.\n",
    "\n",
    "Let's start by creating a clone of this routine with a new name:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c3343bd9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Subroutine:: my_routine"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_routine = lite_loop.clone(name='my_routine')\n",
    "my_routine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f00603d",
   "metadata": {},
   "source": [
    "Next, we create a new module and insert `my_routine` as a subroutine. To make sure the relevant declarations from the original module are available, we create a copy of the relevant spec:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "59e719b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Module:: my_module"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import Module\n",
    "my_module = Module(name='my_module', spec=phys_mod.spec.clone(), contains=(my_routine,))\n",
    "my_module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b368f8b2",
   "metadata": {},
   "source": [
    "Let's ensure the new module contains `my_routine`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "85448ed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(my_module.subroutines) == 1\n",
    "assert my_module.subroutines[0] is my_routine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "571c62a4",
   "metadata": {},
   "source": [
    "We can also take a look at the Fortran code of this new module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0f14e7ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MODULE my_module\n",
      "  USE iso_fortran_env\n",
      "  \n",
      "  USE omp_lib\n",
      "  \n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER, PARAMETER :: sp = REAL32\n",
      "  INTEGER, PARAMETER :: dp = REAL64\n",
      "  \n",
      "  \n",
      "  \n",
      "  INTEGER, PARAMETER :: lp = dp  !! lp : \"local\" precision\n",
      "  \n",
      "  \n",
      "  INTEGER, PARAMETER :: ip = INT64\n",
      "  \n",
      "  REAL(KIND=lp) :: cst1 = 2.5, cst2 = 3.14\n",
      "  INTEGER, PARAMETER :: nspecies = 5\n",
      "  \n",
      "  CONTAINS\n",
      "  SUBROUTINE my_routine (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "    INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "    REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "    REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: out1\n",
      "    \n",
      "    INTEGER(KIND=ip) :: i, k\n",
      "    DO k=1,dim2\n",
      "      DO i=i1,i2\n",
      "        out1(i, k) = (in1(i, k) + in2(i, k) + in3(i, k) + in4(i, k) + in5(i, k) + in6(i, k) + in7(i, k) + in8(i, k) + in9(i, k) + &\n",
      "        &  in10(i, k))*0.1\n",
      "        in1(i, k) = out1(i, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END SUBROUTINE my_routine\n",
      "END MODULE my_module\n"
     ]
    }
   ],
   "source": [
    "print(my_module.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fe01e4c",
   "metadata": {},
   "source": [
    "And, ultimately, we can write this to a separate source file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3ff6dcd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Loki::Sourcefile] Writing to my_module.F90\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "Sourcefile.to_file(fgen(my_module), Path('my_module.F90'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34e235dc",
   "metadata": {},
   "source": [
    "Finally, let's take a peek at the generated file (disregard the pylint comment, which is there only for technical reasons related to our automated testing):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "acf60783",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MODULE my_module\n",
      "  USE iso_fortran_env\n",
      "  \n",
      "  USE omp_lib\n",
      "  \n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER, PARAMETER :: sp = REAL32\n",
      "  INTEGER, PARAMETER :: dp = REAL64\n",
      "  \n",
      "  \n",
      "  \n",
      "  INTEGER, PARAMETER :: lp = dp  !! lp : \"local\" precision\n",
      "  \n",
      "  \n",
      "  INTEGER, PARAMETER :: ip = INT64\n",
      "  \n",
      "  REAL(KIND=lp) :: cst1 = 2.5, cst2 = 3.14\n",
      "  INTEGER, PARAMETER :: nspecies = 5\n",
      "  \n",
      "  CONTAINS\n",
      "  SUBROUTINE my_routine (dim1, dim2, i1, i2, in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, out1)\n",
      "    INTEGER(KIND=ip), INTENT(IN) :: dim1, dim2, i1, i2\n",
      "    REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: in1, in2, in3, in4, in5, in6, in7, in8, in9, in10\n",
      "    REAL(KIND=lp), INTENT(INOUT), DIMENSION(1:dim1, 1:dim2) :: out1\n",
      "    \n",
      "    INTEGER(KIND=ip) :: i, k\n",
      "    DO k=1,dim2\n",
      "      DO i=i1,i2\n",
      "        out1(i, k) = (in1(i, k) + in2(i, k) + in3(i, k) + in4(i, k) + in5(i, k) + in6(i, k) + in7(i, k) + in8(i, k) + in9(i, k) + &\n",
      "        &  in10(i, k))*0.1\n",
      "        in1(i, k) = out1(i, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END SUBROUTINE my_routine\n",
      "END MODULE my_module\n"
     ]
    }
   ],
   "source": [
    "# pylint: disable=undefined-variable\n",
    "%cat my_module.F90"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf2467a7",
   "metadata": {},
   "source": [
    "Loki's documentation holds further details on [how to read files](https://sites.ecmwf.int/docs/loki/main/frontends.html) and additional options (choice of frontends, preprocessing) for that as well as the [different backends](https://sites.ecmwf.int/docs/loki/main/backends.html) that are available to generate code."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.8 ('loki_env': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "5b6429b76fde06fc4400bf3c27b3ae893ffb7a047f8b8ee9418a3bc77878d107"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
loki-ecmwf-0.3.6/example/03_loop_fusion.ipynb0000664000175000017500000007755315167130205021316 0ustar  alastairalastair{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eac9ce5b",
   "metadata": {},
   "source": [
    "# Loop fusion with Loki"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27a3debb",
   "metadata": {},
   "source": [
    "The objective of this notebook is to go through examples of how loop fusion can be performed using Loki. It is a continuation in a series of notebooks, and builds on the lessons of notebooks on [`Reading and writing files with Loki`](https://github.com/ecmwf-ifs/loki/blob/main/example/01_reading_and_writing_files.ipynb) and [`Working with Loki's internal representation`](https://github.com/ecmwf-ifs/loki/blob/main/example/02_working_with_the_ir.ipynb).\n",
    "\n",
    "Let us start by parsing the file `src/loop_fuse.F90` from the `example` directory and pick out the `loop_fuse` [_Subroutine_](https://sites.ecmwf.int/docs/loki/main/loki.subroutine.html#loki.subroutine.Subroutine) from that file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2e5feac7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Subroutine:: loop_fuse"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import Sourcefile\n",
    "source = Sourcefile.from_file('src/loop_fuse.F90')\n",
    "routine = source['loop_fuse']\n",
    "routine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32e6430e",
   "metadata": {},
   "source": [
    "`loop_fuse` starts with an [_Import_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Import) statement to load the parameters `jpim` and `jprb`. Even though we have not specified where the [_Module_](https://sites.ecmwf.int/docs/loki/main/loki.module.html#loki.module.Module) `parkind1` is located, Loki is still able to successfully parse the file and treats `jpim` and `jprb` as a [_DeferredTypeSymbol_](https://sites.ecmwf.int/docs/loki/main/loki.expression.symbols.html#loki.expression.symbols.DeferredTypeSymbol). We can verify this by examining the specification [_Section_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Section) of `loop_fuse`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b2078568",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "USE parkind1, ONLY: jpim, jprb\n",
      "IMPLICIT NONE\n",
      "\n",
      "INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "INTEGER(KIND=jpim) :: i, j, k\n"
     ]
    }
   ],
   "source": [
    "from loki import fgen\n",
    "print(fgen(routine.spec))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "85708bf9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "   (DeferredTypeSymbol('jpim', ('scope', Subroutine:: loop_fuse)), \n",
      "  DeferredTypeSymbol('jprb', ('scope', Subroutine:: loop_fuse)))>\n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  \n",
      "  "
     ]
    }
   ],
   "source": [
    "routine.spec.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7903824",
   "metadata": {},
   "source": [
    "Examining the body [_Section_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Section) of `loop_fuse` reveals a nested loop with three levels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "48e2c128",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "DO k=1,n\n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = var_in(i, j, k)\n",
      "    END DO\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END DO\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(fgen(routine.body))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cb0d7e0",
   "metadata": {},
   "source": [
    "As a first exercise, let us try to merge all the loops that use `i` as the iteration variable. This will involve using Loki's visitor utilities to traverse, search and manipulate Loki's internal representation ([_IR_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#module-loki.ir)). If you are unfamiliar with these topics, then a quick read of [`Working with Loki's internal representation`](https://github.com/ecmwf-ifs/loki/blob/main/example/02_working_with_the_ir.ipynb) is highly recommended.\n",
    "\n",
    "Let us start by identifying all the instances of [_Loop_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Loop) that use `i` as the iteration variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "98ae76b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Loop:: i=1:n, Loop:: i=1:n]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import FindNodes,Loop,flatten\n",
    "iloops = [node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']\n",
    "iloops"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b8c054a",
   "metadata": {},
   "source": [
    "As the output shows, the visitor search correctly identified both loops. Merging these loops comprises three main steps. The first is to build a new loop that contains the bodies of both loops identified above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d236f472",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var_out(i, j, k) = var_in(i, j, k)\n",
      "var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n"
     ]
    }
   ],
   "source": [
    "loop_body = flatten([loop.body for loop in iloops])\n",
    "new_loop = Loop(variable=iloops[0].variable, body=loop_body, bounds=iloops[0].bounds)\n",
    "print(fgen(new_loop.body))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91b94218",
   "metadata": {},
   "source": [
    "`new_loop` now contains both the [_Assignment_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Assignment) statements of the original `iloops`. The next step is to build a transformation map - a dictionary that maps the original node to its replacement:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "40214ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "loop_map = {iloops[0]: new_loop, iloops[1]: None}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a120a02",
   "metadata": {},
   "source": [
    "Since we want to merge two loops into one, the first loop is mapped to `new_loop` and the second is mapped to `None`, i.e., it will be deleted. With the transformation map defined, we can execute the [_Transformer_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.transform.html#loki.visitors.transform.Transformer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5cb5d7c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse\n"
     ]
    }
   ],
   "source": [
    "from loki import Transformer\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9100bf60",
   "metadata": {},
   "source": [
    "We have also added an `assert` statement to programmatically check the output of our loop transformation. The `assert` will allow `pytest` to determine if this notebook continues to function as expected with future updates to Loki."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71c989bf",
   "metadata": {},
   "source": [
    "## Loops separated by kernel call"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eea4d2ed",
   "metadata": {},
   "source": [
    "Let us now try a more complex loop fusion example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5e1ca42c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v1 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "    \n",
      "    CALL some_kernel(n, var_out(1, 1, k))\n",
      "    \n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v1\n"
     ]
    }
   ],
   "source": [
    "routine = source['loop_fuse_v1']\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61b39ce5",
   "metadata": {},
   "source": [
    "In `loop_fuse_v1`, there are two `j`-loops separated by a [_CallStatement_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.CallStatement) to a kernel that modifies `var_out`. Therefore we can only merge the `i`-loops within each `j`-loop.\n",
    "\n",
    "Using the visitor to locate the `i`-loops as was done in the previous example is inappropriate in this case, because it will locate all four `i`-loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7f81d32d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Loop:: i=1:n, Loop:: i=1:n, Loop:: i=1:n, Loop:: i=1:n]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iloops = [node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']\n",
    "iloops"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03c6ae62",
   "metadata": {},
   "source": [
    "Since we know the hierarchy of the loops, we can instead run the visitor on two levels. First to locate the `j`-loops, and then locate the `i`-loops within the body of each `j`-loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "25e64b2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Loop:: i=1:n, Loop:: i=1:n] [Loop:: i=1:n, Loop:: i=1:n]\n"
     ]
    }
   ],
   "source": [
    "jloops = [node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'j']\n",
    "iloops = [[node for node in FindNodes(Loop).visit(loop.body) if node.variable == 'i'] for loop in jloops]\n",
    "print(iloops[0], iloops[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c013040f",
   "metadata": {},
   "source": [
    "We can now merge the two blocks of `i`-loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e1257b63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v1 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "    \n",
      "    CALL some_kernel(n, var_out(1, 1, k))\n",
      "    \n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v1\n"
     ]
    }
   ],
   "source": [
    "for loop_block in iloops:\n",
    "    loop_body = flatten([loop.body for loop in loop_block])\n",
    "    new_loop = Loop(variable=loop_block[0].variable, body=loop_body, bounds=loop_block[0].bounds)\n",
    "    loop_map[loop_block[0]] = new_loop\n",
    "    loop_map.update({loop: None for loop in loop_block[1:]})\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9005ba3c",
   "metadata": {},
   "source": [
    "In `loop_fuse_v1`, identifying the two blocks of `i`-loops was relatively straightforward because they were nested in different `j`-loops. Let us now try an example where all the `i`-loops and the kernel call are within the same `j`-loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aece6f3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v2 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v2\n"
     ]
    }
   ],
   "source": [
    "routine = source['loop_fuse_v2']\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f9c763e",
   "metadata": {},
   "source": [
    "The [_FindNodes_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.find.html#loki.visitors.find.FindNodes) visitor we used previously returns an ordered list of nodes that match a specified type. Previously we only searched for nodes of type [_Loop_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Loop). We can easily extend this to also search for [_CallStatement_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.CallStatement) by passing both node-types as a tuple when initializing the visitor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ddf60a6a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Loop:: i=1:n, Loop:: i=1:n, Call:: some_kernel, Loop:: i=1:n, Loop:: i=1:n]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import CallStatement\n",
    "jloops = [loop for loop in FindNodes(Loop).visit(routine.body) if loop.variable == 'j']\n",
    "assert len(jloops) == 1\n",
    "nodes = FindNodes((CallStatement,Loop)).visit(jloops[0].body)\n",
    "nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb60a77b",
   "metadata": {},
   "source": [
    "By first using `FindNodes` to locate the `j`-loop, and then applying `FindNodes` to its body, we have built an ordered list (`nodes`) containing just the `i`-loops and the kernel call. We can now identify the loops that appear before and after the kernel call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "eec8d2b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "call_loc = [count for count,node in enumerate(nodes) if isinstance(node,CallStatement)][0]\n",
    "iloops[0] = [node for node in nodes[:call_loc] if node.variable == 'i']\n",
    "iloops[1] = [node for node in nodes[call_loc+1:] if node.variable == 'i']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737c3769",
   "metadata": {},
   "source": [
    "We can now fuse the two blocks of `i`-loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7b751473",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v2 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v2\n"
     ]
    }
   ],
   "source": [
    "for loop_block in iloops:\n",
    "    loop_body = flatten([loop.body for loop in loop_block])\n",
    "    new_loop = Loop(variable=loop_block[0].variable, body=loop_body, bounds=loop_block[0].bounds)\n",
    "    loop_map[loop_block[0]] = new_loop\n",
    "    loop_map.update({loop: None for loop in loop_block[1:]})\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ec146b7",
   "metadata": {},
   "source": [
    "## Using the built-in `loop_fusion` utility"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a91b698",
   "metadata": {},
   "source": [
    "To facilitate loop fusion and make it readily available to users, Loki has a built-in `loop_fusion` transformation utility. However, currently this relies on manually annotating the loops with `!$loki` pragmas. To illustrate how it works, we outline its mechanics in the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c79326a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_pragma (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      \n",
      "!$loki loop-fusion group( g1 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "      END DO\n",
      "!$loki loop-fusion group( g1 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "!$loki loop-fusion group( g2 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "      END DO\n",
      "!$loki loop-fusion group( g2 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_pragma\n"
     ]
    }
   ],
   "source": [
    "routine = source['loop_fuse_pragma']\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99e6b49a",
   "metadata": {},
   "source": [
    "The routine `loop_fuse_pragma` is identical to `loop_fuse_v2` except for the `i`-loops being preceded by `!$loki loop-fusion` pragmas. The loops that are candidates for fusion have been assigned to the same group, i.e., `g1` or `g2`. Examining the body of `loop_fuse_pragma` reveals the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cccf8d6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "  \n",
      "  \n",
      "    \n",
      "      \n",
      "      \n",
      "      \n",
      "        \n",
      "      \n",
      "      \n",
      "        \n",
      "      \n",
      "      \n",
      "      \n",
      "      \n",
      "      \n",
      "        \n",
      "      \n",
      "      \n",
      "        \n",
      "      \n",
      "  "
     ]
    }
   ],
   "source": [
    "routine.body.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aab07764",
   "metadata": {},
   "source": [
    "One of the difficulties in parsing pragmas is that it is not always immediately clear whether they should be associated with the subsequent or preceding node, or should stand alone; as examples think of the differing behaviours of `!$omp do`, `!$omp end do` and `!$omp barrier`. Therefore in Loki, pragmas are not attached by default to other nodes. Instead, Loki treats pragmas essentially like comments, but gives them a separate node-type to easily distinguish them.\n",
    "\n",
    "In situations where we do wish to associate pragmas with certain nodes, we can do so using the [_pragmas_attached_](https://sites.ecmwf.int/docs/loki/main/loki.pragma_utils.html#loki.pragma_utils.pragmas_attached) context manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "88a095e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "  \n",
      "  \n",
      "    \n",
      "      \n",
      "      \n",
      "        \n",
      "      \n",
      "        \n",
      "      \n",
      "      \n",
      "      \n",
      "      \n",
      "        \n",
      "      \n",
      "        \n",
      "      \n",
      "  "
     ]
    }
   ],
   "source": [
    "from loki import pragmas_attached\n",
    "with pragmas_attached(routine,Loop):\n",
    "    routine.body.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e768ca6",
   "metadata": {},
   "source": [
    "We can now visit the loops and sort them into their respective fusion groups:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "16c23deb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "g1: [Loop:: i=1:n, Loop:: i=1:n], g2: [Loop:: i=1:n, Loop:: i=1:n]\n"
     ]
    }
   ],
   "source": [
    "from loki import is_loki_pragma,get_pragma_parameters,Pragma\n",
    "from collections import defaultdict\n",
    "\n",
    "fusion_groups = defaultdict(list)\n",
    "with pragmas_attached(routine,Loop):\n",
    "    for loop in FindNodes(Loop).visit(routine.body):\n",
    "        if is_loki_pragma(loop.pragma, starts_with='loop-fusion'):                         \n",
    "            parameters = get_pragma_parameters(loop.pragma, starts_with='loop-fusion')\n",
    "            group = parameters.get('group', 'default')\n",
    "            fusion_groups[group] += [loop]\n",
    "\n",
    "print(f\"g1: {fusion_groups['g1']}, g2: {fusion_groups['g2']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e327326",
   "metadata": {},
   "source": [
    "`fusion_groups` is now a dictionary with keys for the two fusion groups, and the associated `Loop` nodes are values for each key. We can now create and apply a transformation map similar to how it was done in the previous examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "625d0640",
   "metadata": {},
   "outputs": [],
   "source": [
    "for group,loops in fusion_groups.items():\n",
    "    loop_body = flatten([loop.body for loop in loops])\n",
    "    new_loop = Loop(variable=loops[0].variable, body=loop_body, bounds=loops[0].bounds)\n",
    "    loop_map[loops[0]] = new_loop\n",
    "    loop_map.update({loop: None for loop in loops[1:]})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f2f4222",
   "metadata": {},
   "source": [
    "Since `!$loki` pragmas are only intended to pass instructions/hints to Loki on source manipulations, and aren't needed for the eventual compilation, we can also remove them from the routine:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "785e12df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_pragma (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_pragma\n"
     ]
    }
   ],
   "source": [
    "routine_copy = routine.clone()\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "pragma_map = {pragma: None for pragma in FindNodes(Pragma).visit(routine.body)}\n",
    "routine.body = Transformer(pragma_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac7e3b33",
   "metadata": {},
   "source": [
    "You may have noticed that in the previous code cell we made a copy of the object `routine` before applying a transformation to it. We can now apply the `do_loop_fusion` utility directly to `routine_copy` and compare the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164b5054",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loop_fuse_pragma: fused 4 loops in 2 groups.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_pragma (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      \n",
      "      ! Loki loop-fusion group(g1)\n",
      "      DO i=1,n\n",
      "        ! Loki loop-fusion - body 0 begin\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        ! Loki loop-fusion - body 0 end\n",
      "        ! Loki loop-fusion - body 1 begin\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "        ! Loki loop-fusion - body 1 end\n",
      "      END DO\n",
      "      ! Loki loop-fusion group(g1) - loop hoisted\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      ! Loki loop-fusion group(g2)\n",
      "      DO i=1,n\n",
      "        ! Loki loop-fusion - body 0 begin\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        ! Loki loop-fusion - body 0 end\n",
      "        ! Loki loop-fusion - body 1 begin\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "        ! Loki loop-fusion - body 1 end\n",
      "      END DO\n",
      "      ! Loki loop-fusion group(g2) - loop hoisted\n",
      "      \n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_pragma\n"
     ]
    }
   ],
   "source": [
    "from loki import do_loop_fusion\n",
    "do_loop_fusion(routine_copy)\n",
    "pragma_map = {pragma: None for pragma in FindNodes(Pragma).visit(routine_copy.body)}\n",
    "routine_copy.body = Transformer(pragma_map).visit(routine_copy.body)\n",
    "print(routine_copy.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2534fd4",
   "metadata": {},
   "source": [
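    "As we can see, the built-in Loki utility `do_loop_fusion` fuses the loops in exactly the same way as our manual transformation; the only difference is the `! Loki loop-fusion` comments it inserts to mark each fusion group and the original loop bodies."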
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
loki-ecmwf-0.3.6/example/gfx/0000775000175000017500000000000015167130205016160 5ustar  alastairalastairloki-ecmwf-0.3.6/example/gfx/intent_out_map-crop.png0000664000175000017500000010266115167130205022662 0ustar  alastairalastair‰PNG


IHDR|„âb»'iCCPkCGColorSpaceAdobeRGB1998(‘c``RH,(Èa``ÈÍ+)
rwRˆˆŒR`ÆÀÁÀà ÎÀÈ ˜\\ààÃ0|»T—uAfaÊã\)©ÅÉ@úg'•000fÙÊå% v-’”
f/±‹€²·€Øéö	°ûXMH3ýÈæK³™@vñ¥CØ 6Ô^tLÉOJUù^ÃÐÒÒB“D?%©% Ú9¿ ²(3=£DÁR©
žyÉz:
FFF p‡¨þOF±31@ˆÍ‘``ð_ÊÀÀò!fÒËÀ°@‡*BLÍA@ŸaßœäÒ¢2¨1ŒLÆ„øö×JAôÎ8eXIfMM*‡i  |Ïó@IDATxìœW‘ÿksÎY»Zåœs´‚%ÙÆÙlÎæ6ŽãL:¸»?æ8àÜÙ‡
c08'YÉ–•sÎyµÚœsÞýÞªW½£™ÝÕ†™Ù™ßÓG;=ÝýÒ÷MwW׫WЮI˜H€HÀTU×Ë‘Sy^Ð6a(˜6a˜ÄÆDŦûE›ƒý¢—ì$	À p1·Tžü¯
7$ÚËFzÆúJùÑ7ï铲¼§QlI8ºàà Oˆ– ˆO7ƒõ1Ò:ÄZìÍ
ô¿.³Ç$@$@$@î&@ÃÝÄY			ø!
~8èì2			¸›wg}$@$@$à‡(pøá ³Ë$@$@$àn8ÜMœõ‘		€ ÀᇃÎ.“		€»	Pàp7qÖG$@$@~H€‡:»L$@$@î&@ÃÝÄY			ø!
~8èì2			¸›wg}$@$@$à‡(pøá ³Ë$@$@$àn8ÜMœõ‘	ø,¶Öf9sà¥é_ÞùíR]~y@ʺ‘BŠsISCõuY\í¿îÄAØö”œ„’Y¤;	Pàp'mÖE$Ð/-Í
²ù/_“æÆÚËÙðüãRpqw·çÕVæK{[ßÚ;æÏ9µI.ëÿ¾¤Ö–&©«.êÌZUzQNî~¾óû@l”ž’wž}HÚÛÛ\·oÿ	ÎsL®ö;ž7ßKóŽÊÞõ?Œ¢Y¦	Pàp#lVE$Ð?ím-Rtù€@“ÐSš¼ðS’:¾ÛÓÖýþÓªEÈíöœî:æ¿t|­ŸpswY\+Î=([^ùFçñáãWJî™ÍÒÜT×¹¯¿ÑñY2ý¦ÏI@oýýeÉü7N øÆ³0		xž´‡6ÿ—”éÛx\Ò(™Û·»4*ïÜV‰ŒM•Ö–F9¶ã7)yç¶KbúD™±ìó²Ó/¤µ¹Qömø©¤˜+“>"ùçwè”È_¤Uš‘“o•QSïKÇß•†ºr©(>'µ•y2fÆ=2bÒ-æÛžÂÜKqîa™³úïL;êkJäè¶_Keéy‰Ž&S?*1	ÃåÂÑ·µü&;ã^sÞÖW¿%çRŽnÿµÔU¡cÂÜOHêðY“*%:Å‘1zQgßö®û±Ÿ¸JÒ²ç˜};ßú'mÓ½ªYÙh¦BÂcdò‚‡%%k†i{mU¶ë $¤M4u^:±N2ÇÞ$¥ùÇääž?˜:c“FÊÌ_’°ˆxSfÁ¥=r|×o%8$B¦.yL·qõcš˜#[Ÿ’Ö3&Ë´%Ÿ‘àÐÈÎsœõ\B¢®3WeU•^Ò:þWškícú(æý1dHÀ/	àÁY©ÓKîþ¡dŒZxƒUÃ7ÖU¨† V.ÇØCL[ú˜Ñ\>ý¾ŒyŸËèéw«Vb•T—åÈÎwþYð°Ÿ8ïA9ðÞšxMÅ}ð=-É™Ód„
!PíCëà˜ÈööVÒô³]¶¾ö-	4„˜Ê¨.Ï‘j=×JE9û$$4B²µ
¡á±2I……ø”±æpdlºÔ¨cOÁzîùÃo˜]Eg¶A!aŸ)nÿž¤dN—ýfŽ£íÇv<+Yã–Ë0ZÀ"$5åvYxÇ÷Ðqöà«f?þ@ðšºø3*°¥ËÎ7`úÓyP7¶½þ‰Òc³W]Ê‹N˱¿µvÚÇ–æ:q6fÎÊ¿o}_"¢SdÒü¿’â«mîR	¿98†Ü±Á$@ €^eÉy¹ræI¹ [(xûžwËFkm4x¨™i—˜„,É¿¸K"¢’Tˆ¹`„“ð¨D)Í;fÊ1ù£‚	þc:¢¾ºøºüM
UFC!¦®ºP ÌXþE£À4FME®Ö[à´A!æ-ÂDò°©*xĘó"¢“Uèš'{â¶K‹
=—O¿g´‰iŒÐHЮX	Â4 É*ˆØ´4áÚß²‚“Ý%´ÐLQ­OMåÜÊ:³Ö×›¾)4'aqF[ÒyB7Žc檬†Ú2
Äô‘óÌ4P7ÅòÐ!@cˆ›I$Е@úÈù²øÎ’¼;dío–ÆúÊ®'ؾAH0{‚ôáîÌP´¥¹^ƒ%((ÔüŸ0çc’>ªC	TÁÄJ ÉpL^P„€€ ’7߃‚CÍgG¾£é0;zøƒi#‘Žv[§bJ(B§Z Ù€À1bÒ#½ûÜ#úà?®uw´Á:Úgéð–ÿ•Ýï>)µUùjÓ¢§´wžfÙxé”
4Vêh“˜é°J1ßh!¬ãŸÎûè8fu*¸!aêÆ^–Å1è*?«=eóïP%@c¨ŽÛM~NS x_vßU€h1Z‰E¢oöÅæ
„L?ÀN#kürUé'w[¤=dLбÁÎ(µÙ€)çÄ	LT­L†Ñ”«öBìC`ÏA¦¡¶Ü<ü­|½ÚJ8k„Œ#ÛþOœcËQ˜³W’Ò'Û”øÔŽéSh7ò/ìÔ飫°ðIóÀ·ŸŠ©!¹:õ»h{¬„)£Ð(Ó.p‚K|Êë°ù„ÖÃYÇö5ÎÊK!˜Úð¥ë»”Ï/C“F‡æ¸±Õ$à§®¾íëÞðaôˆ7ã¸äÑ×­HF£ãÍX?íZ¬Ð¸ªíÈ³ÄØZd]&ó?ô]cœùö³ªp`4Ë?òó«ç^Ó2t”ÕñÝ1?l0*ŠÏš©›9jß°ÃÏä Ú‚@û£VL·d«ÁçÙC¯Ê«¿¼S
9'ZfBHÁþù·>!ÃÆ.½:­ðøuc=qµ1„7û#¦ìaÚ3^–7Ÿ~@í.Ò®oúy­í¨Çb1rÊmr`ã/äøŽßH˜ömG7r¾òß·«Æ'H§¢¾i8Z<ÑÆ9kþÞØ²úàWf‰ò¤Ÿ”ñ³è¬×UóÎoë2fIS\–5méã²óíV$R’tš	mgÚT’¾¦+Ú}aëI€†8Ã'råÿc³Ed¹ì	Þ¼ñàFÂÔE›j
,›{&ûyöm3Ò)ŒˆqrÛ	L§ ÁÊ
Õ·tpñ«W.sL{?cÙKr<°°ÌY‚_¬q+dóK+Ç52+¸)‘qéfz¥»h§F¢‚ÕHÒU´WU}(GKkÒsäTD;]r÷“²wÃOäÕÿ¾AHŒAê¨iwKrætyû™5®JdgÌ”mEtµê‚ ±Ó/ÌÊf>~¥W›h+Ð+¬«°ŽQm3F/’ë¢âÚëEµÐZ amqÅ.-{îµ6›šy4Ùë^pû÷tYo¹‰4‹H¹~˜ü‡£ÅúÏX³§$àõz-öF:‡3šXEáDÍY„RgQ£½:‹Àê¬,gí4S5:€·•uy,¥µ··vN1Xíp¬«±®ÂØlاfìfQ&„ǰØoj‹ïŽQq±Ïª×lv]£ÑZÁÜœõמ×h/l‹³ºíP_£Åö—àà秆cð³ °ŽÂšcE(µ7Í1")ÎqÌkÞú;^Þ;³:+«ó m†•ŽÉ*? àÚíØj‡c]ÐØ8&8ó²"Ìâ"Àv&
=o%+’¬õ|LJ¾U/αØYçÛ¿;ë¯=¯£ƒ1gu[åòÓІÃÆš=%   Àá1ô¬˜H€H€ü‡ÿkö”H€H€&Ôåy<@Þ@€‡7ŒÛ@$@W	ì:\,e•MrËâaÐ#—ÖÖvY¿#Ïg¤ôx>O O Àá)ò¬—H€lJ+å½Ý²pz²d¥GÙŽôn3·°Vv*‘•óÓ%)>¬w™x	¸‘7ÂfU$@$àH ½½]¶(–ú†Y¥S#=k5˰¾·µµÓðÐ@Y2;Uú^–U&?I` Pà(’,‡H€n@ai½lÙ[(Kç¤IzrÄ
æv}º)w_‘,U¡c Ëu]#@Ï(pô̈g	À€€&âƒ}…Ò®>ÎVÌOM4'›U˜AËç¥õKs2 ga~K€‡ß=;N$à	–­Å
’ßQUmC<Áˆuú&
¾9®ì	€—Àj…¾S(îNÛöI}c‹•fôjõ‹»ÛÇú|ŸßcöHÀÃà/cï±RY£K]=é/þ=ÖïÈWga‰2ºÿÆÅê}”Xv‹HÀóàt“:ðJÔeªó¦&{¾AW[°÷X‰””5ʪE]<˜zMÙŸ$@Ã'‡•"ð4Ó«äÈér¹uÉ0¯Œy‚-ë¶çÉ”1ñ2aTœ§q±~? @Ù]$pƦVã–|Xj¤Ìœ˜è¾ŠûXÓ¡Se’[Pg|€„‡õ±f#ž	Pàè™Ï  ^8v¶BÎ\ª2Zˆðà^åñ†“àtì]
76;F¦ŽKð†&±
>H€‡*»D$à^uúÀ†­1'woåXÛñsr6§ZV«ÇÓȈ¡#0
 5ˆ(p"\M$àûj´Öœ‚Z£ÕúSM:%mGfZ¤®fIòýdÝF€‡ÛP³" _"P]Û¬qKòe²]ŽëK]3}9«SCGuŠèæâsýc‡ÜO€‡û™³F !NËJKŒV#88pˆ÷Æuó[tY/V²¤$†{Õ²^×-æo&@ÛG‡m#ð*UMF«©†‘™Ñ^Õ¶ÁlÌ¥¼Ù{´Tn^˜.	±aƒYËöa8|xpÙ5 #°óP±ÀSçêEÃüÒ58ÎmPÃØh5&]<+uàÀ²$¿!@Ão†š%è’òyO¡,š‘b)ûR†/åÉ+ª“í‹eùÜ43ÕâK}c_—ŽÁåËÒI€†(„wߪÏššÛ4àY:Ã»ÛÆlˆ.(0@–©à`;ÊMpN€‡s.ÜK$àÇ
Šëeë"Y6'UR“"ü˜D÷]/*köÊâ™)ϪL$Ð
ÝÑá1 ¿";…Í:}¨O–ÏK÷«¾÷§³[öRÔ€~’—‡Ÿ4»I$Ð=Ëê¼k×á¹Y§OÝ•éÆ”W5ª·Õ]>›$ÙÃügÏQòï³)pø÷ø³÷$à÷ZZÛä½]ÅÕò[Àjžªšfu–.¾ì£d@`ùY!8ülÀÙ] kÎ_®–ýêš|Í¢‰‹	½v€[ý"cÃŽ<™>!AÂùžÖ~ÁñãÌ8üxðÙuðWXy‚`kɉa2wJ²¿bô~ï?^*%õ&\¨Ä™t`>^`vH +S*å˜FE½eñ0‰ŽdŒ®tþ[]}‹q>aT¬L=t#é<ÿ+‘‡ÿ9{L~I ¡±U6ªVcxz¤ªúý’';}ôL¹\¼R#«fHDx°'›Âº=D€‡‡À³Z ÷ÀÃîœÚkܺ$SÂÆ~y÷‘Øš ô!ÜÈaQúí(Ç&6’†&––V¹t¥ÌcÇnçáb‘%c†ÇʨlÚkxl0lcZë¤þ‡¶ÓZùE•RWßd;Ãs›©É1î¹øpÍ8|xpÙ5ð4¢’*yøkÏIttœÛ›‚ÐêêÇKB4|<M`ÿñ2ãtÌ𙡑E[ÕqƸ±Æ)Ô‘3"Kf¥JZR¸•5ÈöÅRQÝá *&*XVÎO—¤øp)¯l”M»LXtŸÆÎ3U6}|‚>]î”Htd°,š‘"™i‘æ7qèT‡§Y§'s§ÛPàpjVD$àŒÀŸß½(O|fšÜsóp#pàœ×^”ä„0yò«³Ím›
"ÐzÜ·*[þþ§{ÍCä‡_™-é)F‰×ÐòAòòúgUpŸØy¨Xª0ñÉ»FË7¶ïºžáwóÏ_ž¥‚è5Ÿ½m¤üê§dóÞÂëÎç÷ Àá>Ö¬‰HÀ	KË1u\‚9zò|¥UíÆãŒ7ÂÆS/ž–÷÷ȃwŒ’;W—›ÕöÖ}EFØØ}¤D~úì1ÉV×åÍêY”É÷	`Ì!`N'Ëç¦	<–ÚÓýkFaã•
9Fp6>Þ´©€BÃNÊýÛœôr?sÖH$à@àϪѰҋªñ@3<Ú|ŽÖé–G?jÂ/Ó·ÜùÛ9rÛÒaæ|þñ}gsª=¦Mn»)³K‡aÛSWß"-­Hçjjmm3ŠÊ(L$À)ÂgÕ$@× Øš}Z¤¤¢Á¨Æÿýw'¤¬òúH¢ˆ6úÄÏ÷Ë´qñòÏN—ùÓ’eíÖ¼kr˧	¼ðö™?=YàÝžJ5.ö¥¨-Gq¹ÆhÑ×ꬴ(iÓ™—ZD˜2ÞœY¦Ë_×iHs¨Ì1åòðUxøüá­¶Ò¸ék¬ßl7¬ôÒúK²@µ©ú;©oh5»ßÜœk~77ÍM•ï­S+m‚•-ϾrÖÊÆOÐÁ»6zj«%ðME%Uòùï¾,A‘£zì ?q7r|…J<).Lu5BUMs—r0‡ßªsõ–oŽ.¾4Vž’×~ý¸Ã^~õó)­O’ ÂÓÃ@ÔqU
ÚߦLì)H—Q'Æ…JyUS{û9ŽÛMuòҦ˚›&9â÷ @
Ç@d$@ý'`7ò³—†	æâ%Ì×3ùgÂzï(l`WW¿gr?®Rq?sÖH$@$@~G€‡ß
9;L$@$@î'À)÷3g$à7*tþ¼±©U""h*æ7ƒÞŽzܤ?Ó~Œ^ÏY)pô̈g	ôÀÖ}…RR^#‘aíÒZ{¦%l–¸˜k±5¶d–Ö_ÇÅKé±Kâ`÷Ùßb»ÍÙ6!úiy+U»e]ýÄßI·àúq«TúYI€®'PTZ/h¬“%³R$#%òú¸‡¼ˆ–T÷¯ƒ?(8Ÿ1k ¿ Ç[Æ)—¾%"¨V@Þ™H`hØ¢9h>Á¿aÜ`ï)pÜ 0žNþB ¾AÕȪÕ-SÇñÆê/ãÎ~öŽÀÉ•rò|¥ÑvDEPï
5
½¡ÄsHÀÏ:U&9yµF«JÕ±Ÿ
?»ÛKp¶aGž¤%EÈœ)œjì	Žžñ8	øšÚf³eò˜8?’Æq~4ôìj?œ¿\-UH_­KÄc£¹DÜJ
®Èp?	ø½GKq%`LgG~6üìn?	´´¶ÉF‚Œ‹	•…3¸\ÜN
Ψp	ø
õ¨¸iW¾ÌRG£²èàȆž]—je÷á“%1ž‘gíˆ)pØip›üŒÀŽƒÅR­Ó(ð„ØÜL$@ý&€@†›4&KXh Ü4'­ßåùJ8|e$Ù¸¥¤jw,œž,YéRuèx*	ôš@Aq½l=P$Ëæ¤Jª–ú{¢Àáï¿ö߯´··Ë–}E÷ä+†Û¯Æžõ\sì-”¶6‘åM90Ð5‰8<ód­$àvˆ±u‘‰êšžÌ·-·+ôk%å
òþžBY¤¥™i‘~É‚‡_;;íë`¯ÖòH˜O~_§O‚uåÉMªÚ
ðß7,_wöÏû	l×)Äf¹yÁ5»),GTça¾®ýàÚ7ïÿ}²…$pCv.–{ö˜@•›[X+/oÈ‘édÙÜ4
7D’'“ÀÀX<+UæMM–×6嘸,¨á^C ØgzÂŽøêÚ©©mô‹¾²“K Pý¤¥Ðˆv`©z®´Æ¦)«¨õ\Xó

7„Œ'{šÀï^Ú-ï¼wRBÃ…ÑÓc1Ô꯫­–w~÷¥¡Öl¶×£'¯È÷~ú–DD2D^³»­µU†¥EQàðšaCzE^3ÂR%(2¡Wçó$°„6Ÿ²6ùé#"¢$(b˜ôƇ»ÑÒ¨ñ›Ê„6><Æì			x
Þ2l			ø0
><¸ì			x
Þ2l			ø0
><¸ì			x
Þ2l			ø0
><¸ìÚà(Î=$M
Õ7\xKsƒ^Ú{Ãù"C_Û<u»£ŒªÒKRWU莪X	@	Pàè#8fóNx¨¿ÿç¿‘æÆkOï{QÎy³Ï
n¬¯è"`ìÛðoR^xã>j+ódÇ[?¸¡v´¶4I]uÑ
åqvr_Û쬬¾ì;¾ó·²ãÏû’µWyNîýƒ\:±®Wçò$ Ï Àáî¬u´·µHqîAik½ò¹ª,Gj*rû\ã‘-OÉÙƒ/÷92¢/[^ùFŠðŠ¼ÃÆ,•‘S>äma#H€#Á¡‘eÙ7Jóiþ?˜i€Ø¤‘¦Œ°ˆxÕļ!޾#Á!a2aîƒrtû¯Í9:&Ìý„¤Ÿeйxü]¢®ÈÔÅšïç½&m*l%¦OtZ®U7òÙö´,ºãfש½”ˆ˜Éž°ªÇö÷¦Í³VþDƦ]dz²äœÑ¡}wË©½/H«z	‹–äaÓeĤ5rlÇo$:!SòÎm7ý˜±ìóÊ7DòÏï3þ"­*DŽœ|«Œšz‡iû%eÍUxT’@Ÿeu“Ÿ$à1U¥õZ*”ô‘ú݆¼óÛͽ÷_HÔpøÂ(²×8¶ó7rxËÿ˜ÿe':7Ô•ËŒe_ÔÿŸ78ܠНÔÏ%wÿP2F-ì<i#æH|ÊXIÍš%“>,2zîéÍ*¬<¤‚A¸©çn{ý;›.³W]Ê‹NË1Fp•ðà5åvYxÇ÷@qöà«úЬ4Ós4ÿ”EJDt’BÃceÒ‚‡M;¬ò¢ã†Éé}Ra©Þì´Q„>x•kåÁgscj€uîª,½ µyæ{OíwV¶c›ò„V¥uAÐÛñæ÷eüœÉØY÷›–̱K¥¹©V.GªË/Ë´¥Iî™Í*¾/ÕªÚùÎ?aÂá÷þS0ž•%çeïúŸÈØ™÷©6SJòŽtö‰$˜ö|çÙ‡¤½½Í%þNY:ËxËÿ
´ª}IŽåáþtr÷ó})jÀóÔVæK{[k¿Ê¥ÀÑ/|Ìì­BB£ÄúÔÙÌIóÒ7í*ó`
–úš#$àvåÌ×½•DD§H¸ÆmÁ[{RÆ”ÎrfÝüUFæê÷mZVž–S,EgŒFÓ aqª9Öy¾ãƈI·˜7󲂓rµ^ \àmùã’G´Á¡ª˜ªÇb:‹IÒïhÞþËO›7üŒÑ‹TSp}¹™ºÙèMû•íØf¼‰AèrÅþ£oiÙsÌÃ
$hƒæÝò„Ñâ€-l^ò/î2‚##áQ‰RšwLŠ.4†O¸YÆÌ¸W†¿¹›Þñ?ˆV×ô›>'®sý²tÌßXW¡ÑÝú{\Ù'äŽå¡ßÍMu}*o 3­ûý§õìûÔ4ÚÂ)•–å5ÆÏ~@Â"ãM{j¯®^€]Çúß?¦ñª²Ÿ,ÒÞÑÜô‘óeñÿ$'UÍÿ·üÕ3æß]g¬›Tþú1-ƒanpé#æ«’î²¼åß&ÙW›·~4&(8LV?ø¿:ÝðGYÿüãfjtg)@C­#oΩMFåš5n¹Éï¬\Çüím׿ñõ¦ýÎÊvÖæ¬qË®ãiµçcêhýï>-:M…q‚ LÑ/¤ N®õF«¤SWHT3’š=[ ak+A³ÂDvÐæA{‰©SL¿A»YQ|βcfÜ£/S¯›²LΜf4
ùv©p› šÆ¿6šÅÞæo¨-5ÂtDt²iJÁÅ=jÿõ’¾TMÜä…˜ët÷Ú'Ö..y”jìNÊ…coëoûã×µS¨‘1©R¢ZI¼PX©£=eªI=«ýÉW
è'6ðòé÷ÌËÉtÕà¶4ÕËÑmÿg4‚!zMV-iJÖŒ«,4¯¾¨`ê/Ú+ŠÏ
¦„¡µ3ýº¼ÝªÎh[›e߆Ÿš-ôÃUß:39ÙàUê
wù&}c®.Ï1oѰ°T­PÛãB\vßUxh1oÒv*°RÄ:ß~ÌÚŽŒI3Ühp3ƒA|ÊëðuŸùvê
çã2iþ';œf©­>p¡=AûJ®26
µåF(io¿*!]--{â½èw™H¶žä¬Ü«§›<ì¡áÁM
†´E9ûÍþÞ´ßYÙÎÚÜOÔ‰~ÀîeÞšoÈ88ºK˜ÎÂ
¸f_®SMÉFh,¹rÄØ… /Ūñ`";LÓA†dl—¶>-(F¨¦ã‚â®›²<½÷Oæšš»æï%6q„N“~×ü^{›š½(½ á:À4%´p3W|Q_0¶ËñÏ™c%W›ë_U*Í;j~×°¥rœBÅ‹î]ödÚ£ÂDš
Ⴆ
ÜöÚ·MÓ–>.9'7ÍgCm™DÇgʂۿ')™Óuºög¦ˆÎ¼#çɸY˜iJ‡6ÿÒØS-¼ýŒc¯S—°c=ýníϪnûfÏç¸M
‡#~â:Þåê:ÓñÖ`.¾d½ðÞ~æA5VŒ4o8Ž8ñ
UBêxìîL°ëØöÆwjóîϽjʳ4Ö[yPp¨ÌÑ›nd‡>ø•Y–‹·¼Á_KÚ¬Ž†œr›Øø9¾ã7ª‰I07Lï¼ÿç¯í
T¨Kïù‘¶y˜¶3T^ýå2ÿÖ'ô»¢³8hjÐ^ܰR²fšýÎÊÅ0@[15„àÚß>"a:}¥åëQSGOíwV¶³6;ãY¨†·¨S"êp4NE:ý4kÅ—%!mb'´W ×6gŒZdÞìÞ~öA3…„·¶åù¹ÙŸ6AÞ|Z5Y:ý›aÊ7yù‡œ1Yßäõ‰thóK“Nا,±¿Kß…9{Íý¡®ª@ µ@êM~ÇáWµ(öEÐ  ŸýQ9wèUc£dv8üÁuîØœípL(/6õ5¥r|×s2ÿCß56e¸ÇA‡fFÙ¥ùGMV\«V2yÕ†éì¡—
¦Bsϼ/ñ©c¯›Z†à SÓ¸7Æ$d-ìôͪ—‡E‚Ÿ>A DßZîÿÊ#[š½êkfÜ•ýóÆŒ·ˆööV£®‡½nDm:5b·“°òC•yïÞÒY6sþšOþº³|¨÷ñi¸
xÇŠ‰P½Ð¡±'¨Oïúì+f¦Pg@@PÇyZ6.è»>û²4©úÓA–Psǧÿ(MÕf*Â^¶W}ü—:Ó®çv2ÎÊÅyö6/Ô7ø)mˆ=õÔ~ge;ksBÚøëxB«ÀÕ0t­üÝù¦jhXýƒ%³˜½ò«Fà@¿ð¶9K¿·¨ªÂ…Õ×›îû׫ýˆÔ]µ?¦pþ!»-—™
Õëß1Á¾(PüxùŠÐ—ˆU_7-ÎëM~•·^^ÅïvdV
Ò•g¸Ç˜¤¿ëî4¦V|bºÓ^ŽuÌjO`P‡šu¿@P6l¶¼úM#”X/:ŽyñÝLW굉ûä¹Ã¯Ãuh‚¨ã*uÛ7W™t?§TºÃCC“Tö„Ѻ±oÄxhY+N°‹Î„
C‚-…µÄÕ±|ûw”cNGa££½àlmƒ
Þjn$´	Z{{qS!ÎŽÙûsœ–k«ç@0³¸t©«‡ö;-ÛI›yšº´­˜ºjÔ)¬X§vÅ—ªZ÷~4©ôËÞ.pêÂlª«Î~thoìç›ù‡z €ëÔ>e‰7ùv]I…©Lá¥é´ƒ«kE;æÇsËQ_|ê8ÒÙo4$°óº|r£$
ë0<3«h°*KÜ­äXö×ët®eb×›Ohi’ÔV
+ò µ°§Âœ}R«Z´œNÿbe
®Å¹·|£s*Êž'00ÄØZaJ´»¾Ùó8nw½3;åw @X9pû£/˜n®ΘH`pà½þª€ª‚4¦ô¬Ô±?@§Æu™²„‘èö7¾'o>u¿y	Àòóýw£mëMþÄôIrfÿ_Œ†aؘ%råì3…‹ß:„™9Z>–…ï~÷Gf*¤Ã·NGÛÛ3L—Œci,l3º¤.ýéÈkiþ´á¦½¨ÿÌ—Í´#¦R»¦ø5¶a°A6¾ð5¬-3šD,AwL™ZÞÖ×¾%Yc—éôÍw\öÍ1Ÿý{€J+ÔEÚ‰pÛ«	ü×o6˺å|Pyõ(ygã+OÉk¿v¸q{gSÙª^Øwø’<ù«í:õ;$ç	~_ U4Óú¨³4‰Ö~ä‚öÁqÊÒ,CÕó¡	4çÀ—G/ò£Ü7žú°`ÊÐrÒw¨ÃÒšõ´h´§®Úƒ• {T0¹óñ¿tÑøõÔ@àvOÐÞbÚÐcÛŸ1‚ÅLµÂOˆÍ9!–õbª«Éœ%‰ã¸¥UuÕ7Ǽ8/9²ŒËbÁð;		€o°¦0ÍtÛ5ÇuÓwŽS–ö‡0Hô6?êÃ’Rh6,ÃÕÃÓ„V²Úiê²M¡^9ûŒ™yoaÜc´Vîë§#­#¬©bKH°Ž¡]×<uìµ\	Xç8~ZeYû]õÍ:îøÉ)G"üN60ØÂ’5–zSrÕ.Wû]µ^G±ÒÅñFâê|oÚ%pÆf¿q{SûØÿ$0uɧ;4*Ð}r´mÒx]Žm'F=Auº€c¤×ÞTŒ%eP"õ%Âkoêè霞¢«ºj—«ý®êótYWíêÍþ­¯|Ó8.ê͹<‡ÜI` „„*ÇÞwho¬©"û~wlSàpeÖá1}‰ô:.|ûÛaFWí/Aæ'ð6œRñ¶a{Œ€³H¯E9œº¶*…_»_8Ï։ݿ×(¦ÛÔ%ú$³ÌV箢˜¢¬:u§¾WÝ/¹ûÿ¬ªÒKÆmñ„¹sÍuïºKLâpãªñàlZˆ®ê,Â-ê@»a5ÿÂãgÆò/`wg‚‘?9ºiî<áê¢áßõ[³œvê’ÇŒå>–Ë9‹|ëj?ÚŸ”9UrN¬×ùñkt­º:¼:Rk‹h³phæŒ-ÏÐ&,ãKËž+mNü(Xõñ“HÀûPÃá}cÂ
ÇH¯ˆŠêÊÕ°U¥£_쇓›Ž(¦÷ÅÔ*ëñ±œ-Oƒ«!!¦â}àÁë%ÇKÔNKئ,ü”ñšiEWÅ1gn±íBX›#Zê‰]¿ÃîÎäÊMsç	W7ðpŸºø3&öËÎ7 ‚L»ËÈ·®"Ê¢ý'w=]]«.WQjE›u!Þ`áŠ}úMŸ5S)Ö´—U?I€¼›ï¶®#½Ú]
ÃØ®†;\n_«ÄÑ…/Ž`9[o£˜Z%aîëÛ/ŸÚhv!¨Òˆ‰·tÍuÖʯÖÎù=9‹p‹ãhלÕg‚5žv—Æ^颽;7Íöò×¹¦¨7КÊ+ÄéŒÓÈ·=E”uŒ k¯ÃÕ¶³h³.#ΰöpã>ïÖoMŠ«r¹ŸHÀûPàð¾1a‹‰@_ÝñBx°œê8‹b
×Àpû>jA—–gOZm¦3°šoøˆJ‹ˆ«»ß}R½ü囵÷:1Ò™ÇýÔÚiE¸Å´Cƒ®‘·n¬×;Û¥KÜàâÜžìnš±ìÏî¦Ù~že˜t5újKӵȷè"ßNšÿW]"âÚ÷[eYå`ºÉ™öÁY”Z,«C„Üààp!7÷Ì&|=–ð¡;[ôÇZ®ˆ~ÛV9ZMà'	€ 
‡›ÖxøÁîöpÇ{îðkÆ>`»«a{M–ߘÄlûî.ÛЄœ;øª‰‚µëX
ã¸&ç ZãžuÿªqVVšµÿVÄÕ‘“o“:}ÑS²"Ü®úį̃øô¾?ufA<”˧ß7mÀzýøä1ǰú-7Íx@Õ1x8&L÷@k«eb¡ vÜ)ÆN‰L|?ãl¿cyξ£\+J-bØ`j$F½ŽÚ£ÍÂærS³ç8e£Ñ;aGƒàsˆˆÙÚÚä¬*îó+*f{h‰§_aîgg­W!
ýÉìÞMÀéÑÒG.pêjØÞ»߉>¾Ë»´q¸ã:Š)íi¤F¨„VÊé,â*Î7oì(ÛJW]C`qán–!èÛñŒìYû¤DÆ¥›é!V{]ºi¶êÐOÔì•ÿ¾]½éÔÑ7çCW‘c]í··ß®²ªr¥ÖY´Y†˜ç!v’ŽÅ–W¾!¯ýò.û!<2=°ªà§ŸQÇUõÒVwÖ·z¨AÛê}ªO¸³e˺6÷©aõýÎôŵ9œa©ŽßØ<€PoÜñÚ]ø:ºÖ§´™Îè(«ÉoÚ£˜:ަEìšL
8F‰µ×üŽ®‹a\ip‹s ÀÃmE@:+9–åè¦Ù:ŸÖ¹Ð>À&ÄîoŽÎ"ß:Ûo•c•íøÝÚï¥íG?#äâ|¸~Æ4˜-궦V\ÕaÕåøI׿ŽDøÝ	¬ÝzEn[šéMëw›¨áè7Bàí`˜hOŽSöcÖ¶Ýó¦£Ûaë|ÂËeOž.íÂòt±ÕÐðôHö:ðÝØCØ^Þ-" àÚ%k	Ö1äCr,˲{è8Úõ¯u.„Ç­…³àjÎö[åXe8~·ö;s8„~ B®crÆu[ýqU‡c9üN$àl:\ïh[A$@$@$à{(pøÞ˜²G$@$@$àu(pxݰA$@$@$à{(pøÞ˜²GC„ŒY/í"­uÝÌúšR©*»tÕ¯ˆëóx„HÀ¿	PàðïñgïÝ@«-°ôÔ1ÝhdWÇüŽßíQn
Ö÷cÛŸ‘/|N¶¿þ]yë×U/¥>¶Dq°À±\ðC8üpÐÙe÷(Î=hüGv­žˆr;zúÝrÇg^”Û>õ;Píäž»›,ŸH`ˆ¸¶Ænˆv€Í&o!G·ýŸ”œð™¼àaõÖ™b¢Ä"z,œVM˜û	I>Ëi“EH-Í?æ4º¬c„ÕÓû_ìåv²ÆE±œkÝök©,=/ÑqÃdÊâG%&a¸\:þ®	WQ|N m3ãžNeð`zîЫê¤UݶܸmwyÞH­„2Ë
OZ_ùI$@]PÃÑ¿@ß	4Ô–Wænÿž¤dN—ýf܃#ˆü\LRîÆ%WRE—uaÕY”[ÔGY[_û–‰T‹@o„¶¾ú-ãXÌ„Œßú´z2&#&ß*{×ÿDà$Èîµ?4îØ'ÍØL“tyÓEÐlœ9ðãvÜYÿ¸H€(pð7@Dn͇«pQYrޔ͜WÅ&T/ž&ª«Ý¡˜½ZWRá=<*Iµ&'UXˆÖpô%Æõ8<àáôš2êšYf,û¼F¦½6u‚yCm¹Y6Š)g	
L•d[n4Yã—›é+º,—Y.ÑíVGLZc"¬¢L+Ê­½Ž(µÙ€@pùô{¦ÚœÔ]y¢
ΚaöÅ%h-GBÆÛ#ÏÂÖ#M5–Ëö8R‹è·L$@$ÐvG‡ÇHà³D§9^–7Ÿ~@ÕJ˜ÊÀCÿÕ_Þ)óo}B²Æ¯¸zq];¦dŒZä4Bª³è²Î"¬¢@{”[Ø‘ A³1gõ×eÿ†ŸÉÁ÷þÓL»Ì¿íÛñV4.‰=ÚjG[$sÜM’sj£¼õôGÍùãf? ÝEž-¸°SΪ)BÙ3‘	€+ŒëŠ÷{%¾D‹ugGùÓ)°ÕÀ¶`­½­UšÕ5¾Ks#ž:‹ê4º¬‹«ö(·öŠ õ@DX{ 7Lj´ŽmvCçc:5+(ÏYäY”ÝŽ¾ª&Ç›£Åzóè°mF‹µHð“H [0,ÃPKØ@t:
Øïñšü·'k*Å쳢˺ˆ°jÕmÏmDYµûtFÕ‘Ö±-WH­öý(;ÀË…
{{¹M$à´áðwÖJ$@$@~E€‡_
7;K$@$@ž!@Ã3ÜY+			ø
~5Üì,			x†Ïpg­$@$@$àWè‡Ã¯†{èw6&:L¤©@šZ®÷>ô{××J`p”´µÔh΋õµd_Ê»/
'û2	ÐÇ46™,N”Êå‚:Y17M6í.±Ù12u\‚u˜Ÿ$@CŒ€/ûáà”Êû1²¹$ÕµÍòú¦Ë"w¯.±1¡rïªlu ¯¿wYêê[ŠH€¼Š•Œ^5l	ôL`ï±),iÛ—eJpp×w†I£ãeLVŒ¼»-O2Ó"eö䤞ä$@$à]ïVn¨U	ô@yU£¼´î’¤$„Ë˳®6¬RCCƒä.h=¢CäÕ9RΗ›@IDATUÓlâ'		xŒ5CÏŠI ÷v*–Êê&¹ou¶Æg±ù#麗±Ù±223ZÖ©¶#Y…”ùÓ’»9›‡H€H`p	PÃ1¸|Y:	ô‹@IyƒüEµÙQrëÒÌ^V¥ÁA:õ’%éIáF;RVÙhâ'		¸•5nÅÍÊH wuëþ"ijn“߀VÃUéÙâ%+=J6ìÌ—èˆ`Y<+ÕÕ©ÜO$@ƒB€ŽAÁÊBI ïò‹ëä¥õ92aTœ¬Z˜qÃZ
W5c*æ–ÅÃÌ4´&Ee
®Nå~ pÔp8RH}#ÐÖÖ.›÷ª€!ò‘[Fô­^ä–)÷¯É–÷ÔoG
!ËÔ‡BÌ3‘	À` À1˜tY6	ô’@N~­ì9R"+¤KbœzSäãæR¬ZhSÏL"L$@$0X8¥2XdY.	ô‚@Kk›¬ßž'¹µr¿j5Ü!lØ›•’n´)ç.WËFµï€–…‰H€ƒ5ƒA•e’@/à!àD™ÚUd¨ÏŒÐ^ä¼Snš“&ðóñʆ™;5IF¨‘)		$j8’&Ë"^ÀÊ“µ[®¿°Õð´°a59!6ÌhYò‹ë–6ë?I€H ß¼­ßY	ôžÀÉ•rü\…ܺ$S¢tyª·&x'ÅTÏŒ		2vD¬·6“í"Ÿ!€)ÍÓ«äJa	K;î‡ï#‘áÞ{Ÿ¸QøÔpÜ(1žO} ÐÐØ*omνêWc„Wèܢ椺®YÞþ@ÛÝÔÚ‡^3	@o	 öV޾Te> Ñ—„
pðÑ©·£ÊóHÀÍŽœ.ØkܦžBÃÂÜ\{ÿª›5)I&Œl‘·u
hüÈX™<&¾27	€SÕïδqñräL…9þÀ­#ž7”wRÃ1”Gm÷jµ"¡âC4¢+BÇ5a©S?h?V°¼ùþe©oh±ñ“H`	>A|Q»1¢Ñ¨¯üRÙ¯ PUÓ$›vÈ4½iŒãmŒFœ¾X)'ÎW×ëÑ‘!ƒQË$¿$€¨Ðq1ž]&?Xà)pY–ëwv«§PDw½eÉ0A”V_OͺlÚŽôä™3%É×»Ëþ‘	ô“Ž~dv@È÷÷T«1oZ²	#ïoD.äv80»YÍÅûè›™¿)ûKƒA€Ç`Pe™~C`û"©QãÐÕÕu(Âkmm—õ;òŒÀ±pFÊPìÛì‡~ò?dëîs A™\hkk•ç~ñˆÄÇö/ÞF]3æpI¡Ý?Ø[(Kf¥HFJÿ.B—•¡AAÆ@1aþ²î’¬œŸ.Iñƒ„n!bS½@Ye½FdIH˜ï­HÜ-µç¥}Â,QàÈQaY>O ]¯:­ºD!ÞÖ½ëg¥Gɇ5ê,gÃBeéìT2ꊈßHÀo	PàðÛ¡gÇo”–¹n?X,Ëç¦	¢¬29'¨êéÕ‹2¤ ¤^^ZŸc„–2‘	ø7ß7¥÷ïñeï€^!μ…"Ø…ÞA…-Ð)£Zpd"ð_ÔpøïØ³ç½ p)¯Fö-•›¦¢©2ÝL9-Ÿ—.¥òò†Y8=Y0íÂD$à¨áð¿1g{A¡Ùác¡ÚÄŒÂF/ us
H¡Êɯ5Qh±ª…‰HÀ¿PÃá_ãÍÞö‚ÀYÖxèT¹¬Y<ÌDMíEžÒK‹g¥J…zR|ecŽÌ™œ(£²|×k/‘ð4ðÔpøÍP³£=@v„bGHvh5¢ià	À9´˜fY«Qhá±”‰HÀ÷	PÃáûcÌö‚À‰órò|•ܪnÉ•ið	Ìš,5*ܽ¡h§Ž—ñ#ã¿RÖ@$à1Ôpx=+öµŽë°)¸ou6…
7
¿}xõ©ol•77çJƒ~2‘	ø&
¾9®ìU/>U¦Œùê3b˜ ô:“ç Ôýªéòî¶+rôL¹çšIÀË´µ6Ë™/
H«ªËs%ïÜö)«/…PwÜjÌÓo…ÅUò•ø³´xÀ7¼…6ª½"ºÊ«o‡É3?ý«~÷‰ô@Dx°Üss¶?W!¯oºlœ‡ýøWïÊ‘Sùý+xrëê^ùæçW˼# 41T	ßù[i¨-“Ù«¾æ²õê]7HBÃûf][™/‘1ê¡70ÈÔ‘sj“\ÖÿãfÝï²ÎîØË”=ïþHî|ü%	
í.Û £À1(XYhOš[Z¥¥=X#Gôtê ´¹‚(+?=(u°Ð¾˜<&^Æ‘uÛóäÔù‘ðèYÞ–ú©«oê[‡˜Ëg³T qè.Ùò”DƦÉä…twšËcë~ÿiYõñ_IlRǽñÒñµ2|ÂÍ.Ïï逽¼èøa¥ÿóÎmíW™=Õéê8Wd¸ß-‹Ä-˜‡\%¡¡ArçŠáòÂˈâÀx,Cn}³Á•%礩¡ZÓ'ÊÞu?–¤Ì©’sb½…ÊÌ_’Ò¼cRpi~–Òüã²äžÈ‘­OINg$fL–iK>cÊ8¶ã7i¦8PÞŒeŸ—ý›~!­Í²oÃO%mÄ\™0÷ãRœ{Xæ¬þ;´]£¶žÜû‚\Ø%ª¡=õNVJMÅ9²íiYtÇÌy§öþQ"bR¤(g—ò eŒZ(ùwyDà 
‡o^ì	øÐÞ¢|b }¤ªJ/˜Þ”ä•“»ž—‰ó’àp9¼åTH˜#ñ)c%5k–j8VÍ\°l{ý;›.³W]Ê‹NË1–inª•‹Çß‘êòË2méc’{f³\>ý¾ŒyŸVFO¿[‚UZ×%ÒÚªS,i¦N—N¬“é7}NFMùì~÷I)/<-Í5*˜ê¤\©m¬­È»®<œ€¶à˜'¯fOPg$@$@CžÀ¬›¿j4#'ߦZ†<‰ˆN‘ðÈ3¥’”1EêkŠ¥¢èŒ‡Fª@pPÂ"âTóqÌôûæÝò„¤Ÿeʨ­Ì3Â
l7RÇKLB–jBª4O¼B)ïìV£õH6E²'®6‚MaÎ^—!üØËÉáQÉR[Uà2Ï`à”Ê`ÒeÙ$@$@>KF˜HA!‚éÇÔÒÜ`v‡Dè´` ¤˜¯ÂHºÙ‡ïÖ”r‹üZšë;‹mn®“ÐkhA!aõ¶·õÎ^kKCg½»iƒ7f5$@$@¾OÂG]u‘N…´™©ÑÉ2fÆ=2jêªÅÓ-HC3‚Õt‘j‡ÑÚÒhl>)>y¬äœÜhޡޢËšhA 
ÁŠ”šŠ\c»aUb/ûíñD¢Àá	ꬳG_zp¢üñ§ËdÎsgOJ4û`\ç>nÀÄQ±òÛ'—ÈS?X$™©‘@æMM’?üø&™­1[˜H`@`}´1#A;ai8ìÚ
e^Ô•%¯ýênΔ9kþ^ö®ÿ‰¼ý̃òÆS÷«­Æ{¦ÍmÊ1 )1e‹dŽY"[_û–ì~çÿIT\¦.¯•Šâ³æ´©K>­Æ§—µì»äÝç>%cÔÖ#5{¶™ÂÁJ–µ¿}DÞño4_†©™ìåá{eñ95z„M·'N©¸9+ì
í‹dÙÜ4Ódï±R“ñMdÇÁâÞÁsü„¼•‡GD¸È#÷Œ‘'Ÿ>bz¢>VàgŸL$0ÌRWÕ< ­ùä¯;m+ðÐÇw¤ŒÑ‹äÞ/¼%:×aŒF‡_!Yã–üs„†E›éœw×g_Á‡I³W~µSà˜{Ë7dº®X	Å4L€
K%W
JaëŸ)·<ü¬‰Ö
¦S`”j¥…·ÏìG>{²—×ÖÖ"WtIìâ»þÉ~ŠÛ¶¯µÖmU²"è™ÀþãerîrµŒQ3&$H«:7"VN^¨”#g*L`µ%y4-)\ŠÊdû5ÎÒ(¤H1QÁ²r~º$ŇKye£lÚ] U5ݯï¹E{L²3¢þÜ-£åùJ0æ0'ŽŽ“å:×ÔÜÕjÿþ5#Œ°ñʆ#¸NoÚ‡T@¡Àáùñc\°lE\ŸÑû#YVïkí8““›7JŒç»•ÀŸU£a¥Uã4fx´ù­Ó-~x¬ŒÌìøž‘Ü¡Õ@X›Âð4XµùÅ×–•™Œü㳞{ýœ±àÿ؇F‰£Ó0LÉ!-4g˜¢;x²\Žž­0BJ\´g]§û쀰c$`#@ÛÞG`ÿ‰235RRÞ`´h!Œ‘2Ó"Ì\|
'ÎWÈ…ÜóVû£§ÊÅ+5Æèô_þvŽÜ¶t˜9Ÿ|ŸÀÙœjcσi“ÛnÊìÒaØöÔÕ·HKk‡Ñ¶¶¶åª`—óù…H``	pJe`y²´A ÐÒÒÖeZ¤¤¢Á¨Æÿýw'¤¬òú€Z0,}âçûeÚ¸xùÎg§Ëüiɲv«g\ùÙÞ¾ ó§'Ë謮Ñ:K+;µå(.oTÛ‘¬´(µ©UA„‰H`p	Pà\¾,}ìÓe²“FÇË·›&ë5¢(–=ŽÒ‡ËW
ÿ>ußÙ´³ÀìÇÒv–2ù¬ZZ»åŠÜµrx—Nï=Z*c³cåÛO—Íjû3MW³¤$†›ß
¦X˜H€—ŽÁåËÒ€ÔÝv•÷Û\Ñå°²rAº|æ#ãM
eºüáÌ¡2Ç”ËÃê	Ÿ?¼uÁlóo°~ðÌh¥—Ö_’ªåHÕßI}C‡Ëé77çšßÍMsSåÁ;GëÔJ›`e˳¯t8U²òò“H`pPà®,u	|õG»»˜ƒú/gä×/Ÿ‘¤¸0iÔÕv?‡™riÕó,ߨåeö/•O~sK—U)u*d|釻ʹ‰b«V~õ§SòÔŸOKb\¨”W5u±çð²n±9$às(pøÜú^‡ìF~öÞáA‚¹xg	óõLþCÀq	¬ÕsKذ¾ãÓ'®~7öó¸íûB‚¤¥.W¤)È÷;ÛÖÔÔö#÷µ¬8®±à		€øÎ—o3+•ü¨Ë}îjDxÿ—ŽSàè3~fì/––vá{E)2?	@_	„…òØWv}ÉGÚ}¡Æ<ý"ŸX¦ÚÐP#!!žw9ÞÜÂ%‘ýÐAÌ+í5W¤
kXÝ”°ª©YíBuõÓÕžú{`TÊnê«!¯$ –Ý×L»½²‰l”¯ÀOm‹º‡_%³S$7Ïy€-w÷±XFe'»»ZÖ×y…R_ïþÀ{mú[…«tx+¥áJ<-%V¢£®ÅaéEóy
	€n‚’zÙº¿Hnšƒ¯]Ã'^­,™úG @Ýâo=Àßmÿ(27	t ÀÁ_ €zú}
¬êiÓÌDC‰4sp*‡/Ë祙èÄC©ýl+	x
Þ2>ØŽÜ‚ZÙy¸DVÎOïÜ»Ê.ùؽ¿§PÎH6.Ñý Ëì"	(
Š“…<8nÚU ?ƒeñÌTB!Ÿ"°]§X{忣V& Þ ÀÑ;N<«—ÎçVËAðºja†ÄÅ„ö2O#¡El7îÈ—9S’ddfôÐj<[K"@ÃCà}­Úf]y²AoÀ)‰a2w
W|øÚø²?Î	ì9Z"%êívõ¢,ÐùYÜK$8ø;è7Sþø¹JY³8ÃNëw,€†šÚfY¯Âö”±ñ2~dìj9›Jî%@ý¼}ª¶†ÆVÕjäIvF”LŸèS}cgHàF	‘@¿	?[!gsªeµN3FR ï7Oà(pøÆ8º¥N”Ê•¢z¹EUÆ¡T»…9+º›ZmGVj¤ÌT÷èL$àï(pøû/ ý¯ªi–;óeÚ¸x«š
& Þ8s©JŽž©0+Yb¢úâ»÷5óLð.8¼k<¼®5»KYe“¹YÂ=9	À@ÀB¬dIŠ“ùÓ¸lüÆ	2‡/ Àá£8}(¯j”M;dîÔ$1ŒŽ1‹ôC¯ÔȾc¥Æ1^|,ãùáOÀ¯»Lï‡ßyç·iT×:µ¶‡·P„ng"8­­íêú?_¢"Ôõÿ,ºþ8²,ÉÛ	PàðörcûŠËd³FÅ\<3E†©¡	Àย>;v*–69!|ð*bÉ$à%(pxÉ@x²¿
A£á·=9¬Û	´µéµ§hÕ	Ð'õ´¾6ù7¯œ“£gËeXJ„–6ÈJ½©%Æ…õµ8æ#`7©ð_VÙ(¯lȑ̴HùÓ;åàÉrùþfpM,ŽÜO€ŽÜÏÜ#5î>R"‡O—«fCä…·/Ƚ«²)lxd$X)	tO/÷­Î6×i‹®h9¦nÒwªq)	u8†úö¢ýÍêtè¹×ÎužY\Þ(¯½—Óù$@ÞE`íÖ+’_\ßÙ¨ç^?'Mê*‰†2ÚpåÑëeÛñvô®BIN3Ëï’ÕÛazr„ŒÓËx	€;	œÍ©’<[TZÑ(%å
æséìTÚ[¹sX×€ À1àHY 			€#N©8áw   'à÷«TÞÚxD®T8Xèê8éÁûæIdãTxf¿Ö?¼²[jêš¿"ÖÐoðWž/aaŒ’Ûo˜>P€ß/¾yPJk#$0 Ȇ“]hm,’{nNÇ
|c¿H}Ç…!nm,”Þ9›ÇP,7´Ñï0ˆ—À Jànø½
z­R1èu°ÏT?à¡Q‰žmkï–öò^Ç“üƒm8ücœÙK  ð(
ÅÏÊI€H€HÀ?Pàðqf/I€H€HÀ£(px?+'  ÿ @Ã?Æ™½$   ÀáQü¬œH€H€üƒÿgö’H€H€Qfèo8çäÆë®‰B½¦Îø‹´êÔÈÈÉ·êïôs
ã:qÌ¿Ó/ºün'Ìý¸ç–9«ÿδÓ•'÷¾ z½‡Êè©wª0²Òù­­M2V¯¤­¯~Ëä9¶ã™.åázεPòu
Õ.Äàž‚öDÅgh¶éµ?KÏ[,Çô~.³V~E¢ã3õÞðŠ^³›¤­µÅܧ&Î{Ðôåø®çô’¢÷¡zÿ'3—AD_œmþ/)SÍm\Ò(™Û·MÛðçâ±M¨ýÞ—Ÿ£Û~-•¥ç%:n˜LYü¨Ä$ïÌÃ
èŽ@`wyìzÍMµRtù€9PSqE¦Ob„Þ¤ð
‹’ì	«$4ˆß¿îš¨­È“ïü³ÖñpÆoµ«üŽ¿ÛªÒKz=·êï»ã…°—hBGMùì~÷Is-T—çHµžk¥¢œ}ÒÒ\wÝu€ãQ±é‚vÙ“Õ¼LYô¨œ=ð²¾düP&Ìû„^×A*xº!!¡QR«oLŽ	åâín´j#æè»*€ÌQ¡jºÑzÄ&0„ ¤úšó‰?&ï”ÛõøJI6Yª#æÆ™{æ}sÓOÕrì	Zž½QBý“e8ÝHßìeq›œ”y·­þuô«³}Ú¼ÉàFüîs˜éL¯Ø“•ûÐWÍ ³W}MFê[Þþ?7ocöó·»í›ãÉüN½ €ßo€^HA¸>­kÅ–×\ŸúûÆoÿ'Ìù˜¤ë©7ù!4£+5«Ö½•pvÔÛ»ëùZõú´Úm•ƒO{{p_€¶)00D¯·6Fi–õ¿LÎyC§{+tŸ9lþtÉ«ýl×é#çËâ;ÿIòtšeío6SŸ×rtÝÂõ	-´±Hx9B‚v‡‰zC€Go(Ýà9jËÍ*îxƒi×ùTLÀÎ#M§pŽ«ä˜?"&¥sÕK¼¾Y]Þo4$¸‰]Öùh¼± ……Ç©êö”±n¿¨ö%VByX5c©>ëÊÌ
0R˽ф75Ø…`.<>ul—ì…ª"Æ[ꂽH\ʳ²fܬûeî-ß蜊²g²¾¦XÛ¦œºé›=·I ¿ì×®Oh³Æ-7×gÖøåöR®ê±ÿnqÁ^ÄZ	Ÿ<ÖØ‰à7kS°ÐFâ%£\5¸na/{+ÙËÃ>ä³l¶¬szóY£LÝ@«3bÒšÎky›kÍ”îÐhÄ'1Ó˜°ÁZvßµ]-FÃc¯ÇÎ)Jm6 dÀ¦)çÄÕž&ª3Þ…Û$à’F]¢qu @ðÏ$óætu[wtì0êF\˜¯þòN™ëFí¹ýïÉ›OÝo¦`lµâ£ÿnÞ`ðÖdû€ºÒž?1}’œÙÿsó6f‰\9»EÞ~æA#´àf9GP‘Æë›Ùn]Nã0”¡EHP¹n{ã»fûîϽjŒ\ÃU…•i—Ô¥?y;ß²ÐN=ŽúÏèÜñ›O?`¦iºä×úÖýîQ#hÁmØøÂôM«LMëÔVåÁ®§ë·L-óÂYc—éôÍw\öíºŒÜA.à7Ûq]Ù®Uœ{õ7ŒMû5q×g_‘ŒÑ‹äíg4ÓÐÜ-ÿÈÏ‘áê5šlù»þn¿kl¶` ßüÔ%Ÿ–íj§õÚ¯î2×,ŒDSÕ.*&q¸œ=ôª¹/$¤M¸ªåì¸ÎìåÁ>ªR
ÑqÝwM]Ûs­Ÿ¸4µý£QLâþ©÷‘°Î" 
9¾ó7²GË£â2tzåïÌRÞ“{þ`4;0Lw\Ýcç„{ÇœÕ_—ý~&ÕΚ™Â‹‰zC @¥p›Ò­7Y|ëœG¾ö;©mËèVãàØcXã"3ŇÉÚm¼Å45V›)|GjÖ‡®fÒA‡ºÕ~>ŽÛ¿ÛócÿO}XÝùW	¨\M˜—µ'¬ÝG» ö´—ׂ)ì×ó±:Öó³oþ{Öûƒ.nlmú&„›2ìS°
Õ.Ö XÌ\ñe3Åbkø`Ób¿ùÙ+Æ›!Ž£$W}³çqµÝZ{^þçÉHJRŒ«S¸ˆ¸÷3OIhì—½°ÿîíÛ¸¦ðP¶„|û5Âpí@0UM„y€ë¾îòÛ·æšRM"¦­¦S¬ß5öãvk¦õ`/Ǭò°ýÆÿ~XßõOºl&¾v&{\{˜âèlëÕk'CcÃuLw ~Øi¼÷âWäžÏ¿n걦tq.ÚƒiY\ÏÎ’#'ô¡Ãn¥cZØYk_KÍ9ùÍ¿}Bbc"¬]üôcœRéÃà[½ÑJt±á¸&éC½†=á!l	Øo•ccÿnÏýXR
͆•ððv6pškŽÕ^öá|Ü,òÎoÓ¥y÷YEu~öÔëÆ†˜us²ßL;êÓùë®BPXd¼KayP–½W}ël(7H öß½}×”%l »uMXEáÚÁoÕúcwùí¿Û±³>l–ªZeá׺ýw}(ÛºØËÆ1«¼]bí££°sìyPv—¶^Øq
»¾þÀ«šœÕ‘p_øÿípÇy€ôN¢	‘ V, (J¢¨.J–ÇMîvä‰ËL&qœ‰íLfÇñŒãIâ±'‘¥‘íÄv'–#S6IQ$-‹EI‘bŽ÷ú<äÿ\àx¸wïïÞ{wïþå÷înë·åþÝýwWÖgùLyUs"•‹Ò.ÿfzx„c#z@ƒùŽF.”
æ|âžž•™¯jwêÑõûPÝóG¨H‡/#ኙUõ“âo¥_Â?¬ó´\v>z!ðäŽ@èEžÛÙ.yä¥-âRd…°AP¬6ÈOõ¨=cÃÜDÀªúI̬ôKø‡£;¡6Ü”÷œVcxJÅ'¶Å˜`L€	@€Žà±S&À˜`LÀ8Œqb[L€	0&À˜@Xà;eL€	0&ÀŒ`Ã'¶Å˜`L€	@€W© <Ú°Š6Ðaã|^ϘóÁ)Ð%09é9L×"¿;É1®aÏEÀõÇ'ŸÛ­íÓGºÛ(_,‹JϸÖç…þa/¤&GCFjäfD®pU)ɳÛ9[’=²
Ï~hÌžCb›ˆÍ3"£ž)¨kÅ]CÑ}ÁÂX9ÏQÑÅàûܨy"cg%àú¿šoºÑöz§àÝ+]ÐÒ>+§AÙŠ8p¬	»¯êšὫݰ´(Ö­ÊÒõ‡_2&\—oôÂíæAQ7£££àØÙvÀTmÈ
nÀì;8Â=XAŽŽMÂ;ïuÀàð8l*_9³[ŒKC†-’ÂTX_Ê‚‡äÂW&
tÄÀá“-“•8Gð—uó‰û‹ >.rF;BÁ•ð7×O©Ø;{ŒÅ®»wN^è€ìmÇžQjŠÿ!L4è?5n¯®¿Yð0Æ›m1@Ã~q| 2_j¿¨^æ¢ ²÷H=ìØ”yùà35#¾w&ápf¾‰X×6ÀÅë½µ ¶®ËØXß½!õ‡:Ù²WU²GÓ‰ñѰymŽ)¯®Þêƒ[عxt{!ÄÄj)&àL,p8$ß<ÔÏ8ß}ƒã°G ŠòSLÅܨÀ!=­GÁã,_™À¼	nÕ¾·a
…y³zUf<ìðÀÁãͰ»ªGXà°O^hÆ„š“¨:…o«Öç@zZ¼¦=Í
Ò?)x,.Háa]	…¯LÀ ÆÖ!8u©žØQ‰	-G§ÕgTiÚ“Vž±aN#À‡Ms¬¾eÎãòÕ©q°m}nÀÚêó8$ŠÏY\jË‚‡$ÂW& Oà$ŽHzƽp?*~ZiÞÑÇÎÞQصµ7,ä)+Ù²_Á%ÀGpùšöýÂûݸrd§L’Q‡"˲%PC&¤{lï^fÁCòà+PŸðÂþ£MPŽ£ËŠÓÔ¯-¹oï…·O·Âc;
qC?ÿ«Ò,	”=a`#@€V8§ª{CÝ}X»*S,QµÂ_¥V	ÒO)x£.Iެ\*¹ðÕÝB)P»Aº!kîÉš`ãîÜäÔ[M€«‰šðo×ãÓF]x6-kÍJž2˜Õ‡L&ÍQŸÁ<6–Y7"#ýç+p
šêèè…‡·…vª#XS7NáÎñt8ÂWÍíÃB"9)AsV&3’„`	2l)xÐTPîÀsË’_#)s¾+Hh)yÙòð(sJåTÚß#!>0åÔHÏ/N_ø°ÀBöWpßõP›$>ÊtvB¨L°™ŽÆ6ñ¸ÔE¸ü¦ZXðdø‰zûq¹ê‰fx¤jdà]á4´üö÷lD%óX”;¿å·áŒ?‡ùXàrÓtÉ©‹ÐJ^÷¢ÙŠÅ‚¢¶÷¡8dè,xH|T5·z¡¶qöà5»¤ó-T&MNŒ…Ík²í%ŽXàRA™€çÚa—Åm]›
Ù™‰A
ɘ·¡8d¬šÚ†á4îC@m≅¯&0}ðZ+,̈·ín¼7n÷C
îPúøý…ãûÈgGÝXà°8ÓÚºFðÛ…[ÇàAj9„=
;˜p	2í¤·B#=,xH"|u"êHìÇ•!´·F®ÍU#¥tª÷nÎ{‡Ç‰yÍq¶ž1½V×WkûÅ)4”Jý#I·À!ã(šc®,gÉ…¯ö'@ç™\x¿žÜYq:%Ú)%4óæ‰ÈÏI2|î’âÏq‰,,pŸ¤NÓ­#°ji:”â»»’	Ä® ;	*Q@cåRI†¯v$pôÝ6!dÐòu'šK×{ §7I¹Õn!'òä8Ïóà62Šú¸]éc™gó¡UJ¢Ý‰I‚¯v$0†‡&þ§P6—g›>0ÑnééîƒÃÕ-âÔY:2
58LïÄM}H!‡S·oÈ…ÜGÃ)Æ®‡ä×Ò1­ã‘#4%Å#’_ÃE€žiS-šB	ôàµp¥A.­š;€Û®¯,IÇÿáY1§Žß»‡òúfÃ\Æ=4âN [Öe;RëÛî‡Ì¥àQ‰½Jþ•døJ§.tÀ0îkñ@e~(ƒ
YXgp:s`hÜ™éHÈ8|à"ýŒ³5]ÐÜ6‚ç¤B9žWàdãC2níê‹bºÊŽJ¸2ž|,w^[½<=l{愊(	÷'ÎuÀ¸tÖ.«éB•v'<XàPq§9[:ß„–”ÑN™‘²cŸÓ™-,xH|
6š2ýéVxü><5Å:¡£Ò$ÚºÅ)ÁFÌþ»œw
@OÿÎ×v¢îˆóM"M©Ê©‡¬Ÿ´¨‡¹IA—G<$¾ZEàüûÝÐÖ9
»«B{ðšUñÔŸc¸
'7£mÑÙ0`p½Àq»yÎãÚúŒ´xQÙœ²¾ÞlpºÀ!ÓK«U£`˜“•[pWÖñdø:4uz^+ÆÞ==àfS‹ûŒP[è¤}FÜœ_NL»+ÚçüÕ¨o‚ÅxÂãúÒ,'æ©8GŠÀ!Í‚‡$Á×ùè˜>xm×ÖÈB…p6NÚI•óËy\%pxð\ZæÖ‹
ÍúU™(l¤:/ÇæãH8$†vñ ©°ì¬<³&‡G<$¾ê¸ZÛ'Nn&}˜˜ÐÚ¬)›¼”gÅd¥ÇÃÆ²…6‰G#¸Bàèôˆºp`¶áNá>F:'RÉ’I‚¯þ9Ù"Úþ˜ê“²ëi¸ú±æ·v&Ñ‘~îJ7¤¡Æy*CÅãjn5‘.pÈ|mï…“¸Ê(;GÜ5ô¾Š< 
æØø'@£Á¤ãB[¢»±£æŸÛ0C ":7àVà8™tcYïZ‰%Â-‡,üBðÀé³ì<$7_ëšὫݸçDÄÇñqífÊ)ÖRûQR˜
eËÝ­Xk†ÛK bÚ°çä…N óÊWfÀ²¢´¹©uñ·	2«;pÄãÜËpádãnTJÁƒV.ñÒÈ»ù8ùî4L6È“=I £Ógw .&`„€#ŽÉÉ) F¤½kJ—ñ±ÊF2Vm‡5‘é{<¨áÌ\@;Íæ²à¡ÉOùèõÐgSCëh›ŸÄ)–¯=yg†hcxtBL(1Ò³çÚaÔãçgäâvÖlŒ ]¿ü­“0‚ÇlÓÖFÔ‹ßó`1|äñ¥Æ=q‰M)xÐÒ¿*•àA+^¸ìÙ¯ \ÀóOzÆáþMyB™ñHu<º½"í$û‘¿;F#ØNï{»I(eÓa—g¯tÁÈè$lߘ{·E¾s=X; %k?ü¯«ðݯVˆ9uRð;u±%æhÑø''Ù"šv@e*¤Ôõî¤ø›Cõ€{žA|t<¾£È”n±œ™O£0FZø4?-I\ø·?8_y¾ÊïÉtG¤óÕ7ëáfÃvH&a
ó郻—°NNrŽŽ¶ÿà#Kà-WGOÝ´eÜì)jŸ>ù\%äfóV¡È—ؼ

1ŠðlÆØH?,.ÌXàØ{è"î#žoë´†*rSžnx¨j弎ú¦n8Rݱ	æ…•P¥ÏáŒöš҂ ‡_ƒš:ÄÄ&Ø!ÙaƒgtÒÓ’l+pœ»Ô‡«Û >!5¬œìø”§žx°ŒŽe”Øø+.!
bãø`´É‰qK°ÓËqIü‘$˜Þ¨¡€˜Æa¹Œg–ú½£úï-|KmE\|²…>:Ó+¯wÒö‹K†®;ºùä…Ý÷üÒZ¬Ãa-Oö	0&À˜Ð À‡~Ę`L€	XK€ky²oL€	0&À˜€84 ð#&À˜`LÀZ,pXË“}cL€	0&À4°À¡…1&À˜`Ö`ÃZžì`L€	0& A€
(üˆ	0&À˜°–Öòdߘ`L€	0
¶8ú»ê µ®Z#Êæ5ß:=
æF¸‹¦Ga¨¯ÅP*™á4&ïä8\?÷ª!fVXbîæ)ši;z¡ùæ	ó°ÓÌä‹iÏ5pÝÑ€æG†ŽC¿ø~üOéF•>\SóÜîwrÂÃíwùáè ¿»þ®gFoÔþQA¿zêF‡ÜÞ•“ÿgÿ‹n¸ê4éZÖx96ÒžÑÙm|=£ýP½ÿÛ=mïÃþŸ|¦¦¼.§Ù‘a°Ë%¥\]®ëß?
øŸŒ‘|MüQç³¹MŽ‘r¥N¯Q¿¥=-÷²í0~TT4œ~ã;@þ¸É»îèåq6¾ÙüP‡éäºc6íN±oHà(ÛúÈÌ]©›¦ƒ?ÿ<Ž"4êÚñõ²£ñ=8ú5ózl¸ZoŸ‚â•Î<3óCíùÓxý0î±ç	®‹–ß%÷>®›Dušt-k¼¼xô%¸ñÞofÞ4\{óôHNË…ÔŒ"X»ãO__ÆŽƒ].‰…º\ß¾rŠW=$0É7_<}=Wç³¹ûŠ»ú¹‘r¥N¯Ú÷j÷ʶÃHø©‹ ÿ7ß<æ/¨ˆz캣—/ÒHøf«ÃtrÝ1›v§Ø‡·ù‹,UÆä¹ØƒËïüR3Å0dV>å~ÿáì‘…Éñ1x÷Ð÷ oÉ&(Ýüq1¢ÐR[
‰)™pï¶ÏBFÎ
¸}å
îÞŽ›Øsl†åëž…åpéÄ+0Üß&„ŽU›>£C]ž½’R³EÔF;áÒñW ¯ë¤¦/‚{«>i™ÅpêÀ?ÙOÏ^
Ý­W¡öò>XUñÑ9þåoÖÎÆóP°l›¿ä†ü}_çM1ú@<Ïü.,,,‡úš7ñßxXÿÀWPˆ™“¦ÔŒB¸xì%D!/« ÖlÿˆÅCµ´Üw5_Fî4ú]-W`û3߆6±*XºU¤u|lnׄÂ;4óhÉêG %½Àv͖˲­Ÿ†–[ïà”ȯa§FJÊ…¥åOвg¤\¯ÚôQèh¼ÿ¥àæ/ߨŒú*»µ—öa<°bݳ¯c¯}ëÍ'æä³ÝË®^eñW®´ê~váͶC«\kÕeÛAmŠ¿rMñ§zЂӷRÔKS¤¼³ªîÌ·M÷>^î«îDJ»)eÉL:|wi¾t6_ê9Œ{† îÊ~¡±æ¾Ĩõ”W¬ÿ€ø˜-[»+í.¸væWBÿbÓî¯Á‚¬%p|ïßàpýö6áGòe Fe	6ögÞü'ÑѱЋŒ“Óó…Ê×ö†sâ·VÉ‘!»14[.pŠîäþo	!µ´òy8÷‡  Zc¸\÷wÝÆr7‰‚×tÙô—ozew §Ð?iÚëßÅSX“Wveüµ®þÊ•VÝ÷Õvh•k­º®l;ü…/Ë5Õ£¡^{¶
Z\­xfUÝÑj/´òU™/áëÕHi÷­ÈG§ùaHàP&ŠzÑ•ü5PÏ‹F3h¤‚F/H€¦]Ò2‹€”u¨Qn«?1±	؃l£äÏ’²G`9
&ôŸ†ð=(È,XX‚½ó$È^TŽ‚GŒôAâÑá6èm¿ëv~YLÐÐÿ`o#†ÛªŒÖÌï˜Øø9þÑKjœ(N0ú3Á¶¤ì1Lk32¼;M“£‚	å
#&$¥ãÈÅ噤©Ý'¥æ@br&
]y8¢t¯°çAÆI)Ó#H3ïüPçÑÈ@‡xcg†FÊ%õb“RâHY­šS²PHæfÄ=	·tÜ7i5w³e—zuêº@áØ™»_ÏÔåJ«îëµj¾êz¡n;ÔñP‡/Ëu"Öƒ!‡´
ê4Yqo¤ìëÕ5W­|U¶éê8k…o¶îh•
'Rꎚ™SïM$$DEE‰ôÆ`©¥(:1>Ñø‘ŒÁ)j +v}U|ÉQ4
&ÒPK=Fµ¡øäø¨xêŽ/®J†zù¢¾n݉Ðv_É+R›8|ˆŽ¦9·¦ÄˆÇÔäÐéiä•Tâ.ΗSñnt¨¼è†Ü'¥å̬ZIA
2®ýA¸¯¯9„½õ,œN(€„Ät±Â‚´“ëP?D
Ké=ÁU0$í:Õ(Ó””š‹Ãï)"=Ä—ô2r–ë&ÜÓJ ) Ñ¨‡ze®øÒ‰Õå’zZE÷ìå²håN¿eBé>Ë%é1)Wûè1Ó-»8*Õƒ#w$°Ó<8ésQæ3Õ2Nä."nà:½4ZjUÛa xa…ê“Û£é4kOYö)_ÌÔu¾*Ût#ñЭ;.j÷°r’CIÿÓ½¼â¿C«îŒv.ß.t-Níÿ¡$Úß}~÷Òáw/NøÎ´awÖý´_Qbª„„Š×þí)hÂÕ$Yù«æüèãH#¨§pöÐ?ÃÞŸ…‹Ç_†Í}CôÐWbõÒ‰ÃÞ=^ï8†1í7­¾PúçõN ŽDðw&îvú¡à2ËzºÇ,{^Ê4µÖ¾¨C:0û~ü<¼Žœ¥@æË=)ÆÕá
‹ßþûd5,J¡•w§"_q¡w2ìÈp6­Šø‹Hk—Ë‚¥Û„Òð¾Ÿ%½PèõvÜ P(ƒèÏŸèÕzrgP¯ì..Ý…
ÔÝ¢Ì×^Þgpn]°#w‘HC\¬#ÁmnzIÁ\«í˜ÍëY¾ä²^¨ÛÊú'ŒðéÕjsÜdfy*mºÏºãƒ«~¾Pb¸¢¾h‡¯Ww"¦ÝwS»“Ö¨/~ã¿§Zû3f¦,´Ðȃœ·VþÓ)3DÏt1høŒŒPÊÂ^Z*†’½k¼§¡i2j¿üT…â©ùŸxá%ˆK[eØ¡’‹2
äòžxKFôŽ˜Ðþñ	©¢g¬¶¯¾S"$Ä¡îí«r
÷xú¯
&2e\”î2$7Jãm‚o~il(/V>6ôûÐÑøáÏ/@|Ê´r«Ú‘Œ3=WþöW.©ÌÐm<Ž2PÃgÆ=­– ÞÛÆ]~WyV†?Ç?e—òâAõCé^™ÏF¸
µÂ—ž_î,i	ÖŸ¯}û5¸Ñ’€£kɆƒéòU®È#ez¥Çê¶Cú#ß+ï•îé¹²íö|…OÝë?zªžþ{È)Z/½÷{Á•t{væÁÏo÷k7^üÙQØw¬W	fi/¹ÐKåo³uÇWò×H¾ø?Øí¾w¤¾õ»`õ=6A&`h„C
åod¯ŽÞ‘Ò–6èž&)lнèùÝ¥Ã1«|G~‘¾
cùÚg€v”†>¤©64’!õ;Ôq“þ5Ýx–¯ö®¸ªý	ç½’‹2
'å½’½#&¤J@i”öé™òž8‘°A&oI%
*i3J“Òž2.dO>·#C7e<é·¿rIe&!9Cð#ûd”~é¹_±á9\~\¸Q²RºŸãŸ²Kù'ë‡Ò½2ŸíÈ]$Þà™.%«9|u_z«n;¤?ò½ò^É‹ž+ÛiÏWø¸Ì™¦Í2N¾J.”åo½²OvÕuÇW²k$_Èž^ø‘ÜîSÚÝff¿ø6KyùöÏO÷ -ˆõF©b°™%@ùÑOÿÔ0f8͎沟|áfAùs7ØLÛ‘»x#ìþÄËæa¦	˜ÉÓžk8ຣ%Ìlý¶JH°ÊŸ0ç•åÁ›ábƮ嵙‡¡dʰl†9 è˜áfÆn@‘bdž;8V â|µ‚¢µ~ØZà°6©ì`L€	0&.,p„‹<‡Ë˜`LÀEXàpQfsR™`L€	„‹á"Ïá2&À˜p8\”ÙœT&À˜`á" –ÅŽ
wÁøÍºÂ;„KÇi[a¼Þ)ÜŽºÍ
¯ï‡w|( 4Œ¤—YêAœðXSnõÂ鶴˜ðÈ[×^Çǰ\GMŸlWžÑ>Ü|‹v`fã‹€w|Ø×+~±ŸùÐfhíàD²-_µHþœ÷õ+Ÿ¹FF¹¢K€Å‹2åOS×ÒùðÅOTšrãVËe+ƒ¿SâÇöl„ÆÖ^·"ž“îÒåö8¬Z	‹òçn”8'ü
r™S¨ŠÁÿ§DdR„Lx¦IEND®B`‚loki-ecmwf-0.3.6/example/gfx/intent_inout_map-crop.png0000664000175000017500000011333515167130205023211 0ustar  alastairalastair‰PNG


IHDR^|­µ&'iCCPkCGColorSpaceAdobeRGB1998(‘c``RH,(Èa``ÈÍ+)
rwRˆˆŒR`ÆÀÁÀà ÎÀÌÀ˜\\ààÃ0|»ÆÀ¢/ë‚̔ǸRR‹“ô ÎN.(*a``̲•ËK
@ì [$)Ì^bdo±Ó!ì`5ö°š g ûÍ—f3ìâK‡°@l¨½  蘒Ÿ”ªò½†¡¥¥…&‰~ JR+J@´s~AeQfzF‰‚#0¤R<ó’õtŒŒŒ@áQý9žŒbgb€›#ÁÀà¿”åB̤—aÿT„˜š!ƒ€>þ9É¥EePc™ŒñÛJ3gŽFQ8eXIfMM*‡i ^ |™µìÀ@IDATxì|\Õ±ÿG½÷^lIîEî½wã`j)†GZꃼ—¤A’ò^BB:Áü©„@cÜÀ½ÉÝ’e¹É²¬Þ{×~g}WW«]i%­´mŽ>ÒÞrê÷\íÎΙ3ãÑɉ$	! €@Mm#9_à=‘.8#)ã)4$À».}v#Þn4VªN 7¿œžþãòsðžJ÷@sc5ýâ±õ4ub²£uMú#ºÁ«9BÀÞü‚É+ ÁÞÝöŒ€µ;Y¥»îJÀÓ].ãB@! „ÀpÁk¸‰K{B@! „€ÛÁËm§^.„€B@7¼†›¸´'„€B@¸-¼ÜvêeàB@! „ÀpÁk¸‰K{B@! „€ÛÁËm§^.„Àpèho¥'þ5lÍÖVæSÁ¥ÃÖžÖPiþ)jiªÕN¯uUùT]vÅx.BÀ	ˆà厳.c.LàôÞ¿X%ldz•Žï|¶WÍUfˆ^énš–Ï;¿‹®ñ/’5í몲ú°¾º:;>­<<<éèG¿ ö¶«Ë÷•±²ø<}øò=ÔÙÙa1뱿&ä3MW³¶ÑÅ“Ã'xš¶/çBÀˆàå³ }BÀfªJ.PCmqŸõ%Ž^L©“?Õk¾3{7² ðN¯yz»iZþjÖV1~¥*bMû½ÕméÞ¶× hº‚Ã)ˆ.í³”½ß׃Óiê’¯„:IB@ôŸ€x®ï?3)!„€“¸|æºröCòöñ£+¾E¡Q)ÆžW—]RÚ¬Èø	”±í—•”Nyç¶“§—/M_þM*/Ȥ¢«GùܛʳhÑí?§¦ú
:³o#Õ±`™0‰¦,zPÕ‘yð
ŽHRš6Ô7mé×(/{g·òó×?I¥ù§iÖêï¨>ôÕ~HÄ¥¹ÊÎø;]9LžÞ¾4*ýÜVð˜¶P{{™v‡ªkߦ'T½™_¢öÖf:¶ãŠK™M“æ‘ÒæSaîa£À‡¹YQ]ÕuJ_x¿*éÔ{ÔÑѦ„©k9»¨£½’Æ,¡	s6PCM1e|™¼¼ý”kÞÍ?¢«ç¶©ûå…™”}ôo*OhTªâæ®ê»¬Ã¯2ûJ_ôEÄŽU×µ?
µ%=Xzûj·åU¸,ùÊâ²S+îM±û°”8kõ·iò‚ûÉóûÛ„§šrƒ½QYÁYÊ>ü÷° àOX®ŒK™Eá1c(6y0÷‘§§7íÿO3¹ÎÊ’ÊäåÊÖ–zd>d-Ó5š²ø!Ê¿°›®å|Ò£<Úëìl§À851}µLº ä@ÔÆÚ¹#=ÍÂO·•GµåW\’wŒÚZhÌô;• 8jêm,h­R÷ÑßúªîǃÃ)çØ?¸L£Ê“sì-
Š¢¦†J¿¡G“5å¹7Æ·•Z›ëhÂÜ{¹L•\;¡ÊAxK›|3A¨„€vñä&cŸ
/dÁîA
äöm~ŠÇÞi¼‡s,»e!ࢺ¿¹è eXB@¸P¾þ¡lÌþOt¤ÞÒŒ•*-Qê¤u¬
* €àòŒ`Á!Ž¢&Sc])aZ™Òü“ªNh|pmÎÚÇ)vÄUG}uÏò-M5\&\	FæúaÚ>ò\ÜGãgž¢'ÓÈ	«•Xœ—a®¸ºAÑÃÓ‹µKã($Â,Ú?(šêkŠº•‰JLWcƒpA¶h	£ÐĹ÷°¯FiÃ|ýƒyÌeªœ—·?Í^ûk¹³VÌÃXWÊĵäÏ[EQ6ùèò#´†1ÉÓh2kÝꪯSsC…±\o,™ä@¸(Yjtщ•a	w'€¥±Õž§óoÒö7¦¹ë¾OÉc—ZÄ¢Ù,yzùÓõ™¡éAÂÒòƧÌUÚ\ù&xY(HÓ0¡Œi2×~+k±||ƒŒY½xÉÔ`8ïÑ«q»±´·5û¦]G_!ÈÁØiòØejÛþú€ZŽŒŸD¤SPa¹Õ‹—:MÓé½ÏSÁåýª.,Oêiãñb^HzWo,Ufù#\˜€h¼\xrehBÀ	(w,`@“”2q
•]?ÕoÂ`‹„|X"„M£§ÝNiéëy)rt¯uv/ÃBPs¿vI†GQ¶bZÐ,ñAûæÆK”v•íµ`ï¥%OO¥Ó”CŸMÓÈ	k¨ˆm¿P~$ó©c-–0¡¹¯Þv-ju^9¤4ry	©>a‰ÂX>/»BÓçi¼=–ÆÂr œœ€h¼œ|¥ûB@˜¸¡}Â2Ù'o?ª„”Ö–Z|û/ºgTKf†e3h€4
^{ÃôýüPÙmÝöÕM4kÍw)cû¯èÔžçØæ©ž&λ—âFÎ&.ÝU7vû©ºI¶kåoýÊ»j鳪ô¢Z’4äé½ýôEÐ÷Hï=w«„`L;r&…DŽ ‹§6Ѧ?ßBqãÕ¹QÕ‡¤Ñ‹hß{OPò˜¥CøêÒK?±«7Ž°Ñ ,z/Vò’àt®¿¢“¦Ò–—6_ 2¦7dÅètãSg†óÔÉëèÄÎßRÛƒùñ²,–v‘Àß»º™mã¼X˜{ÌÀ÷shÏ̱7ónC“òW¸0þV¤S(»ðHehB@8<Óçòé'¿ßM^û¤tذ;ÏK}øã¸¥‘m«Ã
üº
•F‡ßþ°Í–Ó´¤?WËb¬ñÒvÜá-6Q¾~Á\ÆGÑçWK:AN_»'Qfæªÿ2h”¬h
@ÈÃ2#üµ„~´±@éãÔ£ÿÐöyû4P<ÿiZxëO•p¥•Õ^Ñ×NþÑ׋M	  ˆi×õãCYý9–O=<¼,˜“ž'ìÅÀM«GÏõ˜c‰ëIíùôä#ËhêÄ?;iWÊþ¥Æþ“üB@84|ÈköV8Æ—¦ÍÒw\i¶XèBÒ]¦ç0Òׄ.ÜCÝ0º×„.Óü<ôíéË™ñiv9±ETäEê­}܇p¥	/8GB?pÉ´¼¯ˆÊ_Æî+°I-s	í›Ö‹eLÔ­¿nZ¿þKŒÐ`¡Œéx ÀéëÑ3G̱4×O¹&\‰€^®4›2! šv®è­aë#–%×Üû°µ'
	!Ð7¼úf$9„€6# ×†Ù¬Ò^*îözéŠÜB€	ˆà%B@! †‰€^ÃZšB@XC.#LƒZ—æŸê—
´SQtNÙ“Á^’ŽC@/Ç™é‰B€ö½ûX ßÇvüZÅI´ÂíåzàK^â%	!à8ºö&;NŸ¤'B@!0ˆ=‰p>Û#IÇ" ‚—c͇ôF'"€ Ñgö¿@Ö?¥zðD!1”8jÚýGª(>OaQi*\|Šeyƒ
¯f¸ûË*7¼ÙŸÙ·‘ʳ”3ÖöŸe.]=JY‡_UâÓ=ÄñÇ’¹:«Ë®ÐuÖx±oØz6;}E°í"nדÝ>ŒJ¿…h¯PA­3¾¬¥Vr?—Ýý;œýfþ™ë‹\B`àd©qà줤nN µ¹Žfw…"ª.¿Bõ`K|Õå¹´è¶Ÿ+ïõÀ”“ñ¢‚Phd
ígôp š}ôoT’wœ¦.ùŠZb4Äbì	­Ó>¨âCÚü”*k®NøìBìöB?vÆ]JèB¦.ù*¥Mþùèi»µ¥žr³¶*ál‡ü¹pâm³ýëÙ¹"„À`ˆà5zRV!`†@Ph â=çe(-SCM5Õ—+¡kH˜æÜô;CíŒ`Ù0&yMžÿEŽ«xCýTp€êžu*§¤AQÊqj8kÅ
.îSñ£'«`Ö±É3}@òòö§ÙÎ'iÌb‚`‡ЦýSå6# ‚—ÍPJEB@¸#ÎŽŽÃŽOKoù)\9H[_½CUBë`©Ï‹+žµêÛ*Æ"®ûpX$åý½Gm†š?.¯Á¨UÈ uê«hmå°BÜ[K=¤iÕà^ç‘,õO+'¯B@؆€ØxÙ†£Ô"„€€…x„õÕ…*¶!–C“©¶"Oi§bGÌ 6~šj+¯){®Nví0jêmJÀªgÂ…ð²#¿’ò²wP{{‹Y’X.„V,?ç%¸!<᛫S_Axô®w'kµ–ªŽpW1fÚú,êØšºz’B@ô›€^ýF&„€¡qJ`ÚúêÉã…'ò
µÛ-h·Â¢G)›+ÓøàG´yã]*¦apX"-ÿìïh"ÛWí}÷{ôÞŸo¥°˜Ñ2RÕ¡gMXCm	½û§›9ö¡ÍáåAhÀ,Õ‰²(ƒ”¾è:Àödï=w+‡,tÝA%#|Ž®¨òàOou3Ƀ&àÁêêÎA×"! l@àô¹|úÉïw“W@²
j¾*Z›ë9v€±AEm-
¼ë°°ZŸZù:K@Æ×¸§-bÉOõA¨q_»ívëOã¾ij)A«¹ZB±Ì¨/«Õ«åÁ«i]ú{Ž|ÜÞ˜OO>²Œ¦Nt®gÇ‘™J߆†€h¼††«Ô*„€ðñ벡҆
É\Òì¹ô÷ Ò®›
]ȧ]óe­š¹¤•Õîy°VÌ4™ë£V¯>¯i]ú{r,„Àà	t}|]RƒB@! „@/DðêŽÜB@! „€-	ˆàeKšR—B@! z! ‚W/pä–B ?ÚZ›¨øªÁ9i]U¾Ú9ØŸòŽž»"á*;,%	!00b\?0nRJ7!ÐÑÞjëÃþ¯ x$³?,xš×Ü5è1ÔWÐÁ?Ew|}3]ÍÚFMì]~Öêï|gey]9RNœBÓWü§Ú©ØÜXÅÞâcõUô8F¯»#{d‚¦mo{íËÊ<±%wüŸÑðš—*…€K—KN«J[8°ùÇ”ÏA§Çͼ[ýæ_ÜC-ì‰ÞÚÁmÿ{ß瘔ïõpãPšRùï꫞3{7ÒÅ“ïô•mHî›¶½äÓ¿¢µÿñ}êË£êÒKPûà´+•
W& /Wž]›ƒ"PQtŽŠ®¡OÝÿ!þ"RÒ˜%êõâÉwéZÎ.%DáÚ„9ÔuÓ?ðŸ…pþA8lm©§Ü¬Uè¢Éî§‹'Þ¡#[Nãç|AyºÏ<ð’êÏþ÷ Æ2sõ·©²$‡2ÑŸeÖhÊ⇔vï‡%2m[s¼
MÝñÏrø¡2Î3Û8N9BÀ:"xYÇIr	!à†Z›ëXcÔå‘^`âÜ{TœFL¾þÁJÑß×Ï_ÿcš±üŽ>¢^û’Ò2A«o÷Ѭ	ƒwû”‰k•©¢(›|nÔÃ!„"¡‰¢&Saîa
`MSuùÿñË2USpØ:{Í÷XÓ5ÂãÆ©@	ióiÄ„UTDZ$ëJ©ŠIäƒðäFå…]eç¬}œ )ƒ0[5Ó¶µñä[Å“¼é¾—Tß´ëò*„€ud©Ñ:N’K7$Äñk+󨽭™—ýŒ`·µýõ‡xI0…—æ&qÌã-³Ý“:y%[NÛXðº~q7/ýumszïóãq?œ°Z-_š«K–ÐÕ—õâÀÝ*äÙ‘j;vÄL%$ZÈ"—…€è…€h¼z#·„€poñ©sÕRÝÉÝRvZˆcx>ãïTQœ£2h‰R&®Q»-‘‚û…Â+‡T<Æv¶‰‚m–/k›£¬ZfÔ	n¦m«eM­	v(㩽1Ÿž|dMèxÏŽ#ð‘>8Ñx9Î\HO„€p`šÐ£ï"Ô‘<<ºÞJ!t®As¤Õìn4M¬4¡÷ôKŒÄ‚“–¼Ù%„>AøÓ€¸§µ«ŽY@Ô'MèÂ5c0Ø×'}YMØÓî›¶­„Bí¦¼
!Ðo]t¿‹I! „€B@þÁ«¿Ä$¿B@! H@¯‚“bB@! „€è/¼úKLò! „€B`€º[`°)&„€°ŽÞ
È»ö$	þPÏM
H^!`'"xÙ	¼4+„@OÁAþáMí¥=oºù•Ö?b‡äEmäíÙìæ4zß+ЛðüHŽN@üx9úIÿ„€pku
­´ëpMH¥q©a”“[MÙ—khå¼x4Øß—$! œŠ€^N5]ÒY! ܉Àñ¬r*(i ›'‘w—Ink[mÛ_@	14sR”;!‘±
§' ‚—ÓO¡@W#P]ÛB;ÒŒ‰‘”–bqxWòkéĹ
Z9?ÂCº<ê[, 7„€°;¼ì>Ò! „@çK©¢º…Ö.L$//ëû®,ÝŽÚÛ;iûÁ%xÍŸÓížœ!àxDðr¼9‘	!à†Ê«šéã#E4j4%Çõ›@~q=:UF+æÆST¸_¿ËK! †‡€^ÃÃYZB@˜%€ ÕûO”RcS­â%COϾµ\f+â‹ÊßßדÍŒ5ɶ”_®!0üDð~æÒ¢B@(.o¤½Å´xVÅGØŒŠª÷X	-fáË–õÚ¬ƒR‘pc"x¹ñäËÐ…€°h¦ö+¦Nö»|nÜh¦ IÛÍBÚX6'nPš4ûP’V…€kÁË5çUF%„€ƒÐl±–³01ô?k;æ ¥[BÀi	ˆàå´S'BÀ™`÷!Œça…¥ÅáNû—Pcsß'Xµ[r¸û'í	w! ‚—»Ì´ŒS»€¿­ŒÌrZÃ."ìéoþÁ¶,d§«‘4ªÿ`v%
7  ‚—L²Qû€‡ù]ì5’Ý;ÌI¶O'Ì´š‘YFeÍ´jAB7øf²Ê%! lL@/•ê„€ “[Cgr*é¦E‰S1 ·( É£Ãi|Z˜LšÃD@¯a-Í!àš[ÚU¸ŸÄØ@š>!Òá}ê|å5(bþ~^ß_é pv"x9ûJÿ…€p™«èÂÕ¥å
ð÷v˜~õÕ8oýˆƒnBéc#úÊ.÷…€¼OŠ
! @ ØrÁ`}Ò˜p§…’u©Š.æÕÒjö à<‚£Ó—Ž»%¼ÜrÚeÐB@ØŠÀÉs”WT¯´\~¾Î¿T×ÂK¥Ð~%ÅòîÇ([a’z„€¸A@/y„€ P[ßÊqi§M	@
Ž]ä"/™žå¥Ó•ó(4ØÇ±;+½ND@/'š,骎AîŠËš”–ËÛÛÓ1:5½hcwØùéïPî0†`¨R¥6"x
jiHg'PUÓ¢´\X‚KM
vöáXÝÿ«u”q¶œVΧˆP?«ËIF! zÁ«'¹"„€èAàЩR‚ç÷ÕÝ2ä{ïà
Ált¿pFl>rAëˆàe'É%„€›(«l¢OŽÓ‚i1ÊàÜM1‡]PÒ@N–Ò²Ùqj	ÒxC„€°Š€^Va’LB@¸ÎÎNÚÇ¥[Z;8°t;>{’+B w"xõÎGî
!àâÚÚ;èãÃE$»õl2ÓØýYS×ÊŽWãÉ•}œÙ–Tâ–DðrËi—A!—¯ÕÒqù³fA……ø
€àµã`MÁ·]Ï«¿0I5nJ@/7x¶pgØ©ˆ ÖÑ‘~4{r´;£Ò±Ï*§¢²FtÛ×âX),©Ümˆàå6S-BÎ_©¦ÌKU´va"J¡~*ÛTØ¡ñi¡4qTøP7'õ‡' ‚—ÃO‘tP[hjn§¬åÈK`‘¶¨Rêè³*)÷z­šŸ@þÞý()Y…€kÁ˵æSF#„€øÐ¿Äö\7-J"?/39äÒp€ð‹ Û©‰A"üpiÃ!	ˆàåÓ"®A ­­®^¯°Û`ðAèt)¥$Ñè¡”6Rì¹ì6º†±Ü›Í¿Ð~a¹·°¤š[t9ìwB!Aþö뀴ìòDðrù)–
û()«¡ûþë5
öN´µuûC%oOiCT_WMþõ›ÃÞiÐ<lpØÎÚ¯¸hzéTXÖÄÁÇí«lll o=°”Ö,™h¾ÓrUØ€€,´Û¢T!„€eAAÁä0Òr†!ºcê·åüµ$Õ„€¯'­_–L—ók©¬ª™<ýÉÓÛ¾Žk}:‹2)#úE€cHB@!`£’Cćš}ÐK«v" ‚—ÀK³B@!` À+Á’„€ÛÁËm¦Z*„€B@Ø›€^öži_! „€p"x¹ÍTË@…€ãøæ†	ôæ3K9lO”±s3'Fªkß=ÖxM„Àö|ÿêÓ‹hãS()6ÐdNzýí—Khæ$qŠk„"M@/‡žéœpmN–§§ݵ6Å8PãÚÁ“¥Ækr àïïÃC}é‹·6»oþÅ«$!àÄ„3Ì’ôQ¸(ãYÊ£üè!4m|µ³ã­±)¡Ê¹æ™UìC‹fÄR\”?•T4Ñ¥TUkp´äM+æÆST¸?UV7Ó®#ETS×ꢤdXzÓY+:u\ΩÔ_6zÓ‚i1”¨ž‰Sç
‘Œä@Ø‘€^v„/M!@ôöG¹ôøƒSèö•#”à&omÍ¥è?zúÑ™JÃÑÁ´`w®Iß}&C}˜þü‘™ ±ð_òôò w¶ç	R'pèT)Íg¡êÞ[GÑc¿9Öc´xn~öŸ3X ïò	öÙu©ôÜ›çiwFqürA7¼†›¸´'„@7šÖ+}l„ºž}¹šÎ²¶ëá»Ç)¡kã[9ôÉÑ"Ú°>nY>‚Vr˜™}ÇJ”ÐuäL=ór&ä@­ì©^’ëÀœCО0*Œ–ÍŽ#xÀ×§»Ö¤(¡ëÝyJ€Ÿ2.\	ö÷° &‚—ž”Û‹€,ŠÛ‹¼´+„€‘ÀÛ¬áÒÒ[¬C="X½ŽâeÈû?=†R“ç	Ñ-WcS›2ʇ¾7k»
KU~ùãú^{ÿuvvÒç>•Fð€¯OXªFŒNhR±t}2»’Î^¬RÂZ/]Kö&Ðý‰µwo¤}! Ü’Àñsjɰ¬²Ii»ÆÔHIqÊVÇ‹…«s—«èJ~Òrüâ…³”{½Ž–²Öãÿ{­[œ¨òË×'p1¯VÙûa9qÝ’¤n†í_Ccµµs Î©½½C	j,«Iv' KvŸé€ € Öú岪&µdô»¿ž£ŠjƒA½žTö•jzüÙã4el8ýà+Siî”hÚº¯@ŸEŽ]˜Àß·\¡¹S£	!‡ô©œã>âZÛz•VrHV/$ÇQ¯HÖ³@&IØ›€^öži_³Že–ÓÄQáôÄCShûå. ?P÷Ü0þÒ£iס"u݃cÎÀ_’ûÀ.×­{¯Ó­+FttÆÙr32”¾ÿðTÚͶSx÷cL¤¿zn°ô(IØ›€^öži_EË@ú¥ -{®³‰Z1/žüÌ8•§‚ÝFlc!KIXмï†?'|ÿíßW„¤О
Øvié_Û¯Ò<ÖzÅòsÒØÔ®.oÞ¯ž›%³ciÃ-£xɱƒ°òåw/jÅäUØ•€?Ä]O±]»"!àjJÊjèk?|‡¼Óúäñndª•ÀRQT˜5óî5S?]°ñig[Í·Wo4WŸ§÷^|¸·,rÏN~ìïTÞE^Þ]. Ìu†ô¦»‘Ï–õÉ‹ÝD†ùReMK7{/}Óã–†"úÆ=SiÍ’‰¦·ä\ØŒ€h¼l†R*B`0ôÆÐúzð
[s	ö<’܇€9¡£7ºp
¼¥ç÷%	{]ö"/í
! „€nG@/·›r°B@!`/²Ôh/òÒ®pUl_ÓÜÒNbJêÓ=è!ÚÝäXÓAÏ¡TÐ7¼úf$9„€}ÇŠ©¬²Žý:©½þÂj°m‘°Þ
·mÛšÔÖ#é<ó*™ØÇ÷§Š~ç…Œ›1öDbô~Ïû;x·¬<'ý†)úE@v5ö—dB /%å´‡c).šC	1}e—ûBÀ®àоáäyµë4¸Uã"x¹ÕtË`…ÀЀSåÜ”µ^쯦’„€“ØËZhÀVÌg÷òì:É´9e7EðrÊi“NÇ"p½¸²“Êåsâ(:Âß±:'½V€ƒÞÑ?52!ÈÊR’Mô€^ýã%¹…€Ð€óÒ’¿Ÿ7-ž«»#‡BÀy	ÀÓ}um­Z@Þ^²ùßygÒ1{.‚—c΋ôJ8<Üëu„xŠ«æ'Px¨¯Ã÷W:(úC ¦®…v,¤i"iôˆî¸ûSä¦Dð2%"çB@ôJ ­­ƒv*¤H¶æNé5¯ÜÎNàxV9•5Òê‰ÆÝÎ>&é¿}	ˆàe_þÒºp*®ÖЙ•´†?„B‚|œªïÒY!0Põ”};gŸFF…
´)'¼äAB O-ìu;/»$ÆÐŒ‰Q}æ—BÀ	œÉ©$,±¯Y˜Èv^®8DÓ0Ák KBÀ™	d]ª¢œÜZ»(‘ýÅç²3Ï¥ô}ðššÛik¿R“‚i금ÁW(5¸¼ÜnÊeÀBÀ:M¼¼ÂZ®QÉÁ”>V>`¬£&¹Ü…@ö•jʾ\­´_Aò…Ä]æÝãÁË¥!àbN¯ ¼‚z¥åòó•%›^ŽÀáꎃ@³&˼°º|5"x¹üË…€õêê[ÕŽÅI£Ãh\ª[ONrº3Ë×jé$YYÍ®UBƒÅµŠ;?ÖŒ]/k(I!à2ΖâÖÁpØÇ[œFºÁ”ËmH ­½ƒvòÒ|Xˆ/ÍŸ&nVlˆÖåªÁËå¦T$úG Š=tï:\H3ØQdZ²8Šì=É-º¸VTOGN—©˜‘á~ÝoÊ™`"xÉc ܘÀÁ“¥TËË‹ð>ïå%ÝøQ¡ÛÆïâ˜~¾ž´dVœ
k–ª\€^®0‹2!ÐOåUøHÍŸMÉñ¸Ÿø$»°Š@Qi#í;QBKgÅR,àK  ‚—<BÀtvvÒÞc%„°?ËçÆ“§§h¹Ühúe¨v €ÿ¹=ÅÔÑA´lNœüÏÙa­I¼mF¤?B`ˆ Þܾã%´xf,ÅGË·ï!Â,Õ
³Ê*›è“£Å´€
ï“âÍæ‘‹îA@/÷˜g¥›€Á|8ï®B‚½É'¼¬èÍ;—𒇇‡h¹Üìqá:¼ôˆØ+çuÙUÂK ;a
´MÔvEöŒ!\©Z؃ÀáÓ¥ôë—3	KùÅõôÎŽ<š:>‚–ÎŽ¡Ë"m
…3biNz4½·+OÅ}Ä­¿¼•£l.uÙäÐ…	ˆÆË…'W†æ~àIû¿þ÷•V6«%ÅÙéQ´pz¬û' päLæÀÛÛöPhýþûs•æË	º.]Ñx
žŽFàý¯)¡ý:ÊQÇŒu´.J„€¸A`Ò¨0:|ªTÕðrã[å
7  ‚—L²Ñ=ÀxwÓÎ<ã`[yçâ;Û¯Ïå@Ç"°‰—!piië¾ëÊ<@;—W×$ K®9¯2*7$pñj
ùÅÞ²£#ü)"ÔWŒuÝð9!;ööNª¨n¦²ª&*gpþ¿Má\ƒÞö‹€^ýÂ%™…€B@!0pÞ/*%…Àðøçæãôö–“Ãß°´èô‚ƒ|éÅ_Ýëôã!àÜDðrîùs»Þ•ÕRc{ù„»ÝØeÀƒ#PZž3¸
¤´ÝœÉ¾NÍ-mvk_¶-)’8Ž¥ûŠî;rÛ>GRÛp` ž²/d8‘»B[â8ÖygñÇÏü›<}BœwÒs#ÆúJÖ<ßC	qaÆkîv ‚—»Í¸ŒW!àl8Ø‚W@¢³õZúk†@@g‹™«îuIÔî5ß2Z! „€BÀŽDð²#|iZ! „€p/"x¹×|Ëh…€B@;ÁËŽð¥i! „€BÀ½ˆàå^ó-£µÒüSÔÒTÛïšÚZ›¨øjF¿ËÙ¢À@ûl‹¶‡£Žšò«ÔPS<MIB@AÁkPø¤°£€póÉÛߢֿzc×r޽E—Ïl6ž÷÷ ¹±ª› ulǯ©²ø|«¡úê:øï§úU®½­…jKúUÆ\æöÙ\]¹–uèU:¾óÙµªLvÆßèê¹mVå•LB@{ÁËžô¥m›èìh£Òü“ÔÑÞx¶¦"êªòÜÖ™½éâÉw\~01–½ï~o0U8DÙÄÑ‹)uò§¢/Ò	! „€=	ˆ/{Ò—¶‡ÀÅ“ïÒµœ],˜µQÒ˜%4a΂–ìÔî?Rk±Â¢Òhîºïû•›¹•Š®%O/o*/Ì¢E·ÿ\Ý+¹vœ²¾Á×}iúòoRHÄ¥™:³o#ÕUæSdÂ$š²èAòö
4Ö¥?(/ÌäòSËc¡Q©ªxã¿|æºröCòöñ£ñ³7ÐÙ/ª<¾ÆÏþÅŽ˜¡ªÉÍúˆ…É딾ð~u~éÔ{ÔÁBgdü³õjm£Ì™ý/ЂõO©Kç3Þ¤€9~^>HNü“ÚYhMt¥¥¯×Š©Wkúuž1/‚cÈ?0Biq¢&ïÍXù(e³Y3´Žë*àzJ©ªä‚ÒpayÐ/ Œ5d™Æü¦Ðæ@SSQ”M>7ú!B´N(=Š 
óö
`­P:ßë
™Åçè´A•Å9Jã“0jÒ™ÖkÚ¶¹óÂÜÃJpƒ V[yûIåÝûoMŸ¡ùƒði‰'Eüblq#gQgg‡â€>¥LZK£§Þ¦~=<<©±¶”,õ«äÚ	ŠakÄø•4zÚ4bÜJsÃ’kB`ÈàK[ÿßWºpâ_ÔÙÑ>\ÍI;CH@¯!„+UÛÀ¸™w³–è?Ôoxì8ÕØ}mý!µœ×ÔPEÔiè_|ê\ZxËO©àÊAÚúê}JóÔWÏ! yzù¨7C,W"yû/?Ƨ̥‰sÿC]3÷çôÞçéÈGOS}M¡Ò¡3^Þ~´zÃóäííOÛßx˜òY´”wpä„Õ”w~/~LÉc—©òæê5­£³£Ãô’† ÍCßñ;ž5RñiÝ…Psu›ëso<‘KªÛÿú}ôÚ—	ó¤	”zYqíl·Ø¯¶–FÅZˆÄîÔH¸ß+6º|øò=Jˆ·4úÁnR1Wÿ°µ¦}Kýêíºé¦žÜ,ƒÙCoe†ë^}u¡ƒ€-‚× àIQç"PÇ»
k+óhÎÚÇY3´ÆøFe¶˜äi´ôÎ_ò›I›ÒøèG!;¡±”C┆- 8š50·+û¨ð˜Ñ–²Sá•C¼¬÷yÎî5
ÊETЦ¡e×O)Á®‰ƒÊb‰®³ó†¤x£Ö‘Ö¨oܰuÉù‘ÌÕ{#»zÐÞ8±á $︺ŽåT,B€Cÿ“Ç-cm_´¾¨ÙºÍõ¹7žhã€]Üœ5ߣ±,xõ–,õ+4*…ùœQvcK)kÀ$¹'`¶í›ºä«¤}2Ga°›TLË7󷢫GXÓº‚¬iß\Ÿúºfº©g$kw¯œÝÒW±a¹¿íõø}rà–†¥“܈×;ðäH×B€£é"ÝxQ‡,ÌàBpxÛM¥-/m`£î@¥!Âý‚Ë”A:4=X‹¸¡!Ã=$Ø}íÿà‡¬ÚM·}u¿Á{ßäñfs/o_𵿻ÊîèÔžç”;‹‰óîUC-øËånt,uò::±ó·”uðòã%C,1bÙó“·UËŒ°oZ|û/¸Ï‰ªîM¾…æÞô8DËÕAø@›Ùn-&yººn®^ÜÐúÃw,Ïm}õ‹äÇmqýèWBÚÂRå–—7¨%Lh”–}æY^êLQõ⹺ÍõÙÏbÞ €v°„	áy ,áeÙËÿT÷Ñ’“å~¡¿qãiów+^A¡	ª”*,ÜŠìaWˆÍ2øbºI#*!½Ç&Øfyƒ¿Læg2‚m*¿¬l(3¶ý’¢’Ò)ïÜvãÆ¯å›êËÕÿ¾œÀœÖ>Ža;‘D—¨Í.°'Å—·¢Ü£¼;ú_ü§FÙ%NšÿEõtdëÓʶ2,:M™\ÉÜBQñ“zlê‰O¯ê6\Ø+Ø®ô¢úB…÷|ù&&
S¹}ü?ŸÝÿÿ”ݦ›,Lb›Q|Ù4–es…†Úbeª€¥ûªÒ‹„Bø26z*¾HÞllö—í­ÍtlÇ3ÊÜ_³y³L³ôä÷ÁQé·ð{Ì
c~9èIÀƒ¿}vÿÝ3\Cà¯ì¦m‡*•p`©SÐa¢–4M•öo&t:y)K3–ocA§ƒí'´e/­¬öª–Yã…]ЦõëÏñï„%_Þ­‡7[Ó¤Ï['¼©«|\·Û¢agb¿1û†ó=ƒBv-͵j‰Î´>Üëäm¸o¶^&ðsÛ1-imaI,|ÙÆši2[·™>›òTsÀl°\‚˜«ïÙ¨ªÆ›8Úž¹ê¿xµµS1À
='œ[ê—aœÃð6¦eLSsõyzïŇM/˹¸ó¡ä2ÞlOaOøñ[Ð_ßL™^¢s,PAkŒtò“?Ðú‡þIy,ÁFjÞÍ?¢ÐȺ|ú}þ÷Ó¬Õߥ¼ìíjÉþæûÿNñüOáyÄcü?οùI~f7u+]¹nÞö?ÊžQk}Ùö×ûÕØþð”6.’¿$l{ýAš½æ;ê žûÄQ‹hÊâ‡hË‹Ÿ§97=¡!Ølb×ñ’;ÿ›ß¨÷¹´)ë	ö¥øß{÷Ÿ¢[þ—Ñ.cTcæÖ³Vý·³¿FIì¾e4o<9¼å§<Æïð®ë‘J;Ž/W×`žÀ¿7}ñÕ®²¼¡ÇÛ;@õwÕžS†bGÌT4ʰgÕ„²]o~CÕ?‘¿~¢Ï9kS¦G·ý­üÜŸø‹‘ÁÄC+§½¶7äÒr%Ä…i—ÜîµëÓÉí†.vUz¡c4ý0†áºáz×ãoÉíƒÊÈ`ø®%ÓúõçV`ôn)éóÂ̘XC‚­>A Ṅ{¦â‘Ùzu‚(êññ2WÒ®A{g)™­ÛLŸMyª9àŽâ[vÖá¿*›/¶ekgÛ8¸èÐîkíê9áúd®_]ã0¥ Õ$¯îF@Û¤qŸÚý'jáeAý&\‡ÆæÅyJëÔPSDÐb!ig ‘=ÍZe6•8™ð{5k›âjIðÒÚuç×®Ow¦ cB`XÀšì…¦¯7!uX:$¸s›4L	í–ÅÔf’6Xõmµd|Ú5<Ÿ–vâ¾4˜K(¯i‹½nÔ
0„&-y±Ÿ>hØUbMÓÊk÷ͽB›ÞÖf¾MmÌž^†/pÆ1p?Q7v(ïÝô˜Î4s­
­,ÎÁšghü.±VÑ&°,;9K©µµçØ,q³T‡»]ãzw›q¯°3|(†ÄŠÐeçyp§æ!Dé7©@kÓÉËï£à¾„7“Ä¥Î1k 12-§Ãý	å;–àtš/%ײw*íê÷óS;#±œÇÈZB›úM=Í
ª,œ
÷7A³»1Ø›A‹¥OÅyÇX‹V¤Ú‚“è0Þ„ÝšcÙÏÞìµßã~÷ܸâé飾µe;¦Z.[Ö-uY& K–ÙÈ''—2K±Æ&Ï`×}T_U@ûßÿòÜ>}ù7”/Ÿ¬C¯uåv:A
»F°P€„­à=e1â'îf!èå<ñЇ?S§!øøÊ+´V„
Ñ-`á	ž­OHûÌ™¿þIåìðâÉMê~Û‘Áøuòü/qé%ðÔpÀj$xâž¶ôJàƒ …z‘Ð/¸e€}FYÁ:Çþ±ô)'ãÊiâlö¨§‘ûÙK¦	dúÂY3O‡6?¥ò€|ùÌdÇŠ•%9”ÉKŸH–®£ÿÙ‡ß`¡ðµdrzï_º5ïâ°'Ó‚qc>à	Kª³¸ÉîWŒ,õ;ûèßTˆ£©K¾¢–µå`­NyB@8¼œa–¤"`ê„N5'†Øî'††P6]Õ›:ÄÍ!"vûhsSï0‚¡NOá+GKø&mεó;Õ%„ïH™°V9PÄÎ¤Š¢lòñfÁ©Ë‘ጨÝO¤OçÞ£â+BhóÕ•A¿à•aAFM¹UÅ.Ô—ëÍI¤>´npl:™—c몯s¸TUrA1êàp¶¼0S	yæ®kuiŽ'á±»Ž…*klXEšC´-™¥~#®äÞâŽðHðôí	ŸC’„€NF@/'›0éîÀ	ôêİ—j•ñªÚ²
ƒ§ŠÊ#«*Œ,Œç-áñióºÕ2râj»ï ñ‰OË¡8žç­ãO«ÐؽÄÆ2ݼÂ߸
;¯í¯?¤–ãšx—‘.»Ú)¥Îb9¡ƒôIï$Ë™z'‘ú|Úr‰×
Oúm-'èÆŸ2—ƒyÿ‡,1]«K«˰æ´Q=ƒŒ£ß«7<ÏáJüiûÐ(îQ¡Q4ç–ú~c<ÚÒÆmØ8¯µ.¯B@ç  6^Î1OÒËÐ;!„ÃK§ßSöSø@×;1ÔW¯9‰©¿Üíš±K¼L˜\¿¸‡Ž6ÞÃÚלDBP“Dð0MX…)Ÿë„ú€—m†MD’5Sæ®›Ögîõ¶4Õ¨ð%ˆ‘	íU{±‡üA[›´²ë§,ö;„—Kag‡ ßyÙ;ض®Å\SrM!àÐDðrèé‘Î
–€Þ	á­_yGÅÛòÒ%€@0ÑœêÛÑ;œ0÷^åÑxÿ†CÄ„´„€³[^Þ <°Ã?βÏ<Ûc'dꤵJË¥ÅhK¼ŽNìü-e±­–ÇtÄ2’Òà n-)
›‡Üo
}öñÔ	wJàË<øÝú4†Å«eGcš.È’“H­	¼¢]Ì¿û§›y+¸ǘ{LÙhÍb»0l48µç9ÖÖÕÓÄy÷òÒìÝd麾ÿz
¡Öv–B`ÚÊAˆýxÌAj9ÕC-µ~òö£j™>•ßþµMÝœsˉ<{ßý½÷ç[•wmÿ@Ä´½—ÆX^…€pâ@Õ9æIzyƒÀ@¨ê¢KNõõÎõ
Õ+Ú²b	Ó—í“ |˜KX.Ôkš ¹òà ØêïLÄn%}¨ÃÔ)"ŒÐ!¤A[¤ùʆý@_µÀß(kZ—©“HäÑ’–Ú(ØŒi²Yu@IDATuã>ñ•¦‹ƒëêûoîºVi½Ú¹ö
!Nµ„þc-5JԸ⾹~£mmÉÑ´M­NK¯â@ÕÇ¿~çƒÉ# Åñ;*=ì“@GÓuzî¸°>óºjÑx¹êÌʸŒ`À­O¦K‚ú{Ú±æhçz‡†¦[ºá5½/Ïéz¡õu³åbIßΕ¢“ã4ÁÊãë_V’´{(‡dZ—fe¸Ûý¯–WÓ¼éïB4ÄÚÜu­­¼é¹v]ï¬R»†q`s‚i2×o´­]·Ô†i=rîüfOIE¥5Î?^FÐÞáIÍm~ÔÖáEÞ^màÝÄï½pÒ[Þ^aäïßÓäÁI‡3 nw½‹¨¸B@!0´~ðŸë†¶;Öž{½ŽÎäTRXˆ/-˜C;Òâ™±´'£˜ü|½Ô1^%¹¼\g.e$B@!à°d~*»’ò
ë)%)ˆn]ÑålÝô¡›—&S]C«Ä|}˜¿Ÿ`N0½}vQ¯>I! „€ƒ'ÐÒÚA‡N•RUM͘IÓ'ö\b×·¢	`õmôñ‘"òñöT0Àô”œïX/ç›3鱋€Ñ9{|‡SVgNuål_GÁaI=ìËœy\Òw!`+Õµ-t.ŽÀrb8/+ö'xÓ§–$Q—æÁX‰ÖŠŽ“W/Ç™鉋ÀÎGì‰í6ÂzöÑuðßOÑ_ßÜíú@Oê«U¦ZŸ5å2¼DW2gÑŸ…¯zZrç/•.kÊJ!àêò‹êéĹ

	ò¡•sãÉw¶Z"€¹Æ£sä’QG#€°;ð?5ÔiÛë°³Õü¡n¦[ýˆi¹þÁ·hÝ—þªWgý{·ûr"Ü‘ŒåßßuJ*šè–åÉ´ÜB—ž£&€-œKŸð䎃ÔÔ,AãõŒùX4^Ž<;Ò7§"Zg÷ÿ?,ÛÇ?„&Í»½¿ÇÐÙ/ª€Ø¾ÆÏþÇ‹œav\V˜…íì÷+uÒM”–¾^ÅHDp膚bvΚʱ¿©¼Ë_>ó]9M“ÍXñ-Ê9þµ·6Ó±Ϩ¥ËIwQKˆyvÿ‹Wò2/&Òä…÷«˜ˆW³>R¸«J/±Gù=ívKåàÿÒ©MÊoØøYŸWá²¼¡B ùEp@ë/+ͼÛk	q+гµSynE ­í·N—QEU3MN·­ìn0?0y	r/A6°
0//Z2K– ‡‚µ-ë—-iJ]nM ©¾ByšŸwó(†½ÍßùvÁ²á'k"bð–o.ÕVäÑ¡¦³	s6Љÿ 8ÆN›|3Í_ÿ¤¾.r˜"8S=¾óYöTÿm€îg¬ž4fúʾ
¨Üž–°{jß{O¨<¨
pߦ'X êà@Ö×é̾(:i
¥° OõpZ
AìÈÖŸ«0GçÞÇç)'ãTÄÁg³GûPݳÿý*«hžï¡é‚Ðïô’„€;¨«o¥m
Ôï¤QaJàJKVš¶h†A¶ûÓØ„X°’‘€h¼qV¤ONI 4*…Ãú³–ê¬ê?4Mp®
M¼µG'¦[W!5AQ¬•º¢òÀ¡hyA&ùÀ²Õ¯°
±{*rt&°
š&$ØvEÄŽãóduŽ?
µÅTUr–ßý;åxtê’¯Ò&T_]¤ò¤pH£Ñ,¬!Úý'j¬-¥â¼Ç*
®Ç§Î¡]o~ƒíÇâÔ=8 mà¸MõåJ°,Í?E¹YÒâ;þ"ã' ˆ$!àò
KèXf9ú{ÓRÑ2iX]»Ù˜§'kÀØ'X÷Q’ãÙpœ¹ž89J{7=¦–µx‰Ö	apàÁÝË˰Ûiü¬ÏQìÈ™*ÎcÁåý4rÂj
ˆù¨x«7ªâ–]?¥/OOÖˆ•RHäHcAlÓÁèZÎÇ„%ϼs;8P$V	Æþ™„E¥Ñ•3›ÕNLZEW¡Cím„¥Lô¿ž5^Z(¤°èѪ¦õȹpíítäLË7Ò¤ÑÃc¿evæ0xÅÇuIö# ôíÇ^Zv1‰£ñòß;´ù…»)P§1Šˆ«„ŸM¾…æÞô8%[~cäЋ¾-'¤- „QhËËT|Ķ–FZö™g)uò::±ó·”uðÈ"Ô#–0?yûQ¥q‚MÖâÛ¡êKâöaÏ•ÄꆾiÒÚ6½Žo¸z¡÷•ÀÖ¥ðê&â¾·o ^º%3×P·‡ƒ]Ý!'BÀœÜʾRM1þJø{WL0¶¿‰mÀ €í=V¢†¨– Ù=…¤¡' ”‡ž±´ „€J ££“2xwbQi#M
¡ÛLV;h·mÒ-`k%*ç«X‚„ö~À°;RÒкCÇVjB@%Oï°ßªoh£9éÑ4—m¸Ü5!æ£`Ã7û"x
kiI! ìL œ=Ë>]JÞ^l¿Å«ƒ9Ž¢$nرb¶%5hÀ¢H’íMÛ±”š„€BÀA	\Îç^ª(2ÌÖ.L$oo	Übiª”ÆŒ Ü#˜%L¾.‚×€ÑIA! „€pd°Y:q®‚ò‹(-9Ø­ì·l1/æ0ኖp0|Eð=);ì°ƒ®³¥œ:<뇽miй	´´´8÷¤÷Vhiaû­S¥TS×J3'Eòo”Õe%cOzl/kÀàP~ÀDëÉÊš+"xYCIò8»×O§eóÇ8L¤#ÎCÀ“…vI®M ª¦…±À…´pF…Bp¹ö¨‡otÀÖÜX‚làÜEð8;)i±Ñ¡„_I]êÙ»vÆÙ2š75†ðÆ(I¸¼‚::u¾’ÂB|iåüòõû­¡|4¬™5‹ÀÚÚX†%HÙ¨`v¼¬Â$™„€c8q®œ®±ýÊòÙq×cF†PúØÇì¬ôJؘÀ©ótõz=L¢[–'£=ظ™a­Î™ô²~¾^´zA"iX+`KDëóy‘A}"’BÀñÔÖ·ÒLJ‹8`oIéÒž»\E®ÖÒjþÖ/NoÞ¤Gƒ'ÐÚÖ¡–+«[hÚ„JI|¥Rl§v.TËyÒ¥~uà €•æH0ËèDð²ÌFî‡$‘YFÅeMt{œ6·%†Åí/ ¤¸@1*vÈ”N
„å²ÃÓvö-µ€ýoE„ú
¤‡.äWñ'gNxÚáˆ0,A†Èd·éÁ«9ŽK ²¦™v*¢ÙéQV}Ë¿˜W£ü­œ—ÀFƲýÛqgVzÖ‚’:Æ!}°ƒóXÞrÕeÿ‰RZ1/Þ%†¨	`-­í*‘`†iÁË%o„«ÀN­êÚµáéi½H[{mcíW4ýuç(®þ|¸âø2/VÑ¥¼ZJŒ
PšÛþ<÷ÎÊKáUÙ×¼´y€¶÷x‰²[<3Îí¿Šà¥=ò*@Ye}r´X…6IŒ
p±ëëèÙrõMž»%	G$€/
GN—QYe3MN£G„8b7‡¬O
¼CùXV¹Ò
Y#v¬X0|¼ìøJÓBÀxÜÞÇß[Z;hÅÜx²Å·}Ä]Ûq¨‚9îÚ±–š–ëB`Ø	À%ʆç}ÞÔh¥¡öN8@ƒu¼in1±]”+'¼¯íc7M¬	sG
˜^®ütËØœ’@aiƒ²óXÊ."b#ým>ØÌ8YJCU¿Í;,º,âòF¥‰
`ÿsXàïÞŽjêZ(óbµÚ<ಓ®Xw,ÖmÞŠà¥{äPØ“4R»yYÑ“}?µ4j)"/¶ƒ†PL’„ÀpȾRM9Wj(.ÚŸæ¤GÛD£;\}ÊvàylæO‹Êf®n%€g
Û¸a¤«GÁËáAé;È+¬§£gʆÝ«´¢‰vóöuhcCæŽs&cî|±8ÂÏxIyOåß°þUà¹+ªšéâµZ·Ý£À±9"¸bÁËgUÆä4`LG¨Av¶»BؼéÙÊžÌi&@::䛨~‹—¶ñ:gJ4ÅEy›ÎÚ6ÓäòF˜Ù“£u6é7üá=©±É s5L/›<&R‰è?KüÍöĹ
Z»~¶ìÿÍ®¿~Âú?b)áN D抈›ˆÍør!©w%¬Î/ªÇÇ70AÛÇŽXXhdžƒpÑ€‰àÕûÿÜ6'ÍÒ.Þ]åïo°ð/á+Ù‰£9Ïø6"º8îÍbñè?šË;½½$`µµ\TÖHø>!ÒÚ"n‘ÏÕ0¼ÜⱕA:
Îf]ªâp?I­€àµý@MÑ-¤£p”~8ØogÿS%4šµOîXtðÞ@ûýò»©5ÞŽ/è‹à5¼Ï´æ¦°[g'k¹F$ÑÔqNCáĹr*,mTA·}]8T‹ÓLˆ;z9¿–ã%–Ò=·ŒR=Ásóº†Vš59J6hp~°Ëø»Ï#l²A
ñ¡?ü`žK‡G *U*üÖ±¶˜—±ÃC}	>О~áýøkÓÈŸÝ“8rÁË‘gGúæÎäT¾Ñ®[œäðoæ€Ã›ö6Ö~K
¥I£E“aŽ‘;\«âUO<{œêYÈzú[3é4?×pGû-‰Á7ø'dû×sª¢/Ý1šn^š<øJ]¼%€1·º†6*bÿ‡Û*÷8ßÜ0Á¡Gîø:9‡Æ'–	@þþÇ×”êûŽU#RèÂèÙ(ýÇrÒæO®©Ýi–G-w\‘>àžy9“ÊÙÝASK½õÑU7ô&þ2!B—mf¾»b(2ÌW±µM­®]lP—s\Kh½•i»ÇÙuØp쨣—m&Ž:3Ò/§&`°wiPZ.ìêr…”>–í½F†ª Û)IAlÿ%À®0¯ÖŒá…^ œÜcÖœÜjvºk<•@X°;W¤VÞ|ãvJ6²ÍªØºï:1ìªî¥w.ª÷ª‘lÚáˆI–qV¤ONK!?v±_®)lÇåÊ~ñÁ{îr5­šŸ@Á>N;_Òñ¾	4s<=,+6³=´]8ÇïâYCÒªï¹nŽv6¬Çìíßgó{ì
I÷e *ÜFÄ‹àÕ?’’[8xå†ï¢µ‹Ýâ;ˆ¶í/ øèeXídÓ%ÝB@Ø…€h¼ì‚]u%ÕÍÊû<¼r;ªj{(y_ánp»’µ_®âàp(yIÝB@¸7¼Ü{þeôƒ$€HØÒ¼š…Øh¸kÂÉöƒJðr·¿î:ç2n! F@¯q“RnN¡=°{fÑìD
ts]ÃG¸“C&1ac!I! ºÁ«;9½€£C\íìZÂ…‡líêÁn'°ÁÀÏדs|5aÔ‘\BÀ	ˆà寓/Cï‚’:À^º—ÍŽ£˜HÿþvÃ܈9ïÒ¾`€/ip¿L-¼›P’ÀôÉÉ2ð窵µ»,8uÍ= A—®–R~a•®&9´D`δÁ˹.n€çã#E\Kx½$ë	hBøØY6'έíଧf>çm÷ÿ…¼|%rè´4UÓÿ>~M™d–W«jèß|…üÅp5Ö—Ñ;f‡É¾VÐëžå™çwÐ'‡óÉÇW¾v'Óý¬¥±Šþô?w“8PíÎE΄@7Wê(ãl9ïØ‹§ˆP±YêÇŠ,3.cÏÒðxþÎŽ<š?5š’Ô·ŽñkOOOò	L°k¥q/OÛhþüýÉ[˜¦µ½nÀÓËäåI>a®Ã
zy´ªaº†Kmw˜1ã°@ˆø¨B€è»Ö¦ˆÐ5Hú0´ÿsDàí÷» %	! Ü‘€h¼ÜqÖe̽¸xµ†N¯TñÒBƒÅ+{¯°úy•lùÝy4kR$¥%‡ô³É.„€pn¢ñrîù“ÞÛ—·ìɧچV¥å¡Ë†puUÁÉ*´_X~ܺ÷:Á¾$! „€»—»Ì´Œ³Wç.WQö庉ÃýÈ¿E¯°ltsvz4Õ±ûÁ'×(}L8Kû¡•j„€p`¢ñràÉ‘®
=Ʀ6ÚÌü°9ºsõHº†y·`ûÓ«S¨‘ÜnÞotÛ-‡œ! \‹€|µw­ù”ÑôƒÀéó”{½^µö÷óêGIÉjkÓÆGÒ¸”Púhÿu="„ÒÇFغ	©O!àDðrˆip¿N—ÖÐ#?~›ÚØGÖp'ø–jf{.o/Oòöö¤M[üè¥gþc¸»!í™ð÷¦ÛWޤ¬KUôþ®k´zAýò¹èÌùB“œÃŠ}m5Í™–:üK‹B@¸¼\j:g0­míÔÖéMž)vét`PW³•9]'rdw“F‡ÓÖzmc·ç/—ù§°ãUûî.mk,¢†Æ»³‘!àüDðrþ9têH?§ž¾!뼯¯ݲ|ýýV5‘‡Ä{2ÒR±ÃM@Œë‡›¸´'„€Õ¦I’c¨)Ï¥¢ÜÃ6éä…ÿ¢ÎÛx¥·I‡ ’ŽöVkSÁåT[yÍÚì.›Ï–Ï¥5úÃ]ÞÕ¬!*y„€N ²ø<}øò=ÔÙiÙ/Z{[5Ô–x$æÊŸÞû<ÕTä
¨ÎfŽ]×ÒTk,››µ•Š®5ž;ÚAÖ¡WéøÎg{í–é˜zÍlæf}ua7á3ïü.ºÆ¿HÖ´#ûÈfj¶Ï%kžKôÌtÜýém_ÏåŽ7æ/GúSeŸyMÛìw¼úÄ+„€ŽO 8<™¦.ù*/ËZ~[/Í?I{ßýÞ€cZ¾¹¡Š¥#4bÜŠÕyfïFºxòcÙ‘ãWÒ•³[ŒçŽv8z1¥NþT¯Ý2S¯™ÍÜÜöú¬±Ê7Þ¹ÊÂèæ‚dMû˜‹ü»©µ¥ÁX‡=¬y.Ñ?Óq÷§Ï}=—“æ‰"bÇõ§Ê>óš¶ÙîbãÕ'^É „€p|­ÍutõÜ6J³„®f}DM
•TUz‰5	4zÚí•Ng¼H
5ÅJø?û4EiG
¯&ÿ š¼àË3†2¶ý’¢’Ò)ïÜvòôò¥éË¿ÉWòMõå=Š‚£ ¢Ü£,Hý‹šk(vÄtš4ÿ‹äåíGG¶>Mh/,:*вéJ把Ÿ¤´[ž^ÞT^˜E‹nÿ9ŧΧ̃¯8,ìê²KJC?Á,£ò‚Ìc*f&Nü“}¶Rꤛ(-}½šŒ38"‰
. Ô7mé×èø®ßR{k3ÛñÅ¥ÌffŸ§ÒüÓ4kõw“¾Ú‰AAa	Keù§(aÔ»³Ô?—xö¬÷„¹÷˜}.­}®MŸË‚Kû(04–ÚۚͶïéåCuetvÿ‹T]~™‚ÃiòÂû	<Í=»ãg}¾ÇÿBìˆVs·üÕÈîÓ%pgßÜ0Þ|f)ÍžeÄ0sb¤ºöðÝc×ä@LH¥WŸ^DŸZ@I±F sÒ£èo¿\B39&¤;¤Ö–z*¹vB
µ®ê:Ù÷‚¬RøÃ>cû¯ÈÛ/ˆFŽ_E¾þ¡4qÞ}JÀÊÉø‡²Ï𽿻™Bûßÿ!/UvRYÁYÊ>üM˜syûøÓé½Q•iùê²Ë§Ú¬ååÆýïÿ@ig¦/ÿÁæ%ëÐkê^ÙõÓ,°Ô¨ãf˹þ¸”Yª±É3X@»w®z³Ð¯>ñÁ鈩Ž5Q5åWc2ÃÈtLõUtèß)¡sœ
tâã?°ày޵Qõ”›õ¡²Åš²ø!¥¡º–ó	™~'ºÞ4jêmÌq·u•ç£?Ð
Œûj_cOu,p;BÒ?—ÖŽÛÒsiís­.ÁÏ3´³–ÚÇ3¿ï½'ÈÃÓS	¹>þ!´oÓjÙÞܳ‹/¦ÿhÇZî"x–$‡#pàd	¿{¨˜‰Zçîâø~¸vðd©vI^…Áû=|€…‡úÒom$âÃ>Úà§
¯î˜R&­¥ÑüŽ_,?¶ðOhT*yûPtb:`!J8‡zq^†ÒL5Ô‘&ôÌXù¨Òº¤NZGu,@xyûö(ßÜXMþ7´]´ åJ™¸–µk“iÜÌÏRq/öZÁ1äÁVq*?æÈÛ'€||ƒ¨ž5#ÎL™Ž©7E±åвüƒ"YèÌTCóö
¤9kgf3gh&¡môðôRËb!ÉJXõW˜9¦íky `.1Y3nü¶yKR6yU¥
8¬háÚsãI?Vý10˜·¥çÒÀ¨‹›aÏçZÿ\‚A×\™oš®Y«¿MÇwü†ÞÿËtfÿ4wÝ÷Õ¯¥g×ô¡?Ü=Xµ;üÁòãi^Ø‘@~a%=ú“÷É+0µ×^À þñ‡¦¨ZD/üó½ð“äÇžÏ÷/¡-{òér~]¯m´TŸ§M/>Ük¹i_þöëTÓÏK5–µ0Ø€ñ½Òé÷¯Ÿ£¢²Fúù£3¨¢º…þñáúú&Я_ɤçËè™ïΦØHºÿ‡û©­Ýð¶÷ïNUK“ýøU÷²ÝÚPHßúÒLZ6œ}@p«w<¸‘|C{×b@“ãlµ¬ÇoíXªBÒ®ãKZ-͵¼¤ŽS•”ëÎïÃøHúü¦çúòÈ÷ÁÆOÓ‚[~¢l•ä->Áïú…9}ýXžá*?vSbésæÊoé‹ö8îhʧ'YNS&$õ¸gí…ªšºÿ;o’wp—]`_eõ\õc@9ý¹~L¸‡±c9Ë——q! éó«eF 
lñ°d¦˜ðó?sÕu›W}y}}ØYzô£_Ð-ÿ“Û2ϪÁ>þ´Ö^ 7~ò’sÓ¯þ²ƒöj0.S›–×÷UÜÛ¸Q‡és©çû¦uiÏ5®ëŸK}>ý±iû‡°	D[nGH–ž]”×Ú´†{Gc=óý›ÉúY1´/…À°8~®B-–U6)¡Ø)).@ÙêxñRâ¹ËUt…¬–ÖúÅg)÷z-e­Çÿþ÷,Z·8Qå—?®Oàb^­²÷Ãrâº%Ý?”aû×ÐØfº@£½½ƒ?Ì:ù×5Ø@èBRZªn6^]V%ÆôBòû°¤	]8×êÁ1’þ\_×GO½®_ÜkÈÈá>ÂTèÂMhÊ4}}¸†ü˜‡‚ËûiÌ´;u9Úž«~è§þ\“ÖŒKšÐešLõB–!t!™ñiv9±_[Óþõ‹{hôô;ºÕ§
Ûñžþ¸·q£»¦Ï¥~ü¸oZ—ö\ãºþ¹ÔçÓ›¶ù1ºÐŽ¥gWÿ¿Ðî]ÿ¨]’p@mmÔÊ¿Z*«jRKF¿ûë9¥ÙЮk¯0ÀüÙã4el8ýà+Siî”hÚº¯@»-¯.Nàï[®ÐÜ©Ñ4е¢úT^Õ¬®Å°­Wie3°%Çñ{ÍfLÒÀ¤/zÀ ‰Xqc)|èÝúð;F-ñ†›ÀÞiýCoYMš1½guAËh«çÒZ,ýá.‚—µT%ŸÃ8–YNG…Ó¼¹ý@r€¥Ç=7¤¿tçhÚu¨èÿ·wÀqWžÈ9çÀ1€™`¢(’
VÖz(g[>§²Ï'{kî®Î>ÙU^׆ò–ƒ\’“¼>e‹¤HŠ
Ì9 ÅDÎ9cpï5ØÀ‡ÁD`Â73ÿf_êøëîoÞ÷úu·º/Z}1À‡2Ëu÷*zdË´q…>YÚDEÓ“é»_\B2,½˜g?fñУ´™57yžú¡
ÜDîðuÇïÄ”‚ëŽ/Y¸“¯àjgAY2í|¿Š—‘ˆ£-kré±¹iæe#ö°&CI2ù©;ë9Éð½Y”\P¨ºmÍU_Ú{‹Ö°Ö+›ÛIOïÈ2
¯Rµ›M+³iÇóyÈÑB2ò7¯Ü1\Pð^>€Œ$¦Fà?>>NðÃh1¢þå«”‘C}l×e\§Kf3ŠÏûÓk{M-mf§ÊšÔìV±ïÓ®›…­¯={\
'ÊP¢8yþ‹?@Ïýõ
¥§DSK;<ß1²ñ¿  à}¼¼Ï)L‘€½GùA[[NìyàB‡€Qè2–Z]Æ{2¬h¯Ýýá@¼A³½Aq‚€€€€
¼l@Á-ß„A³oH#³ÀP£Yj"„ò!krÉò½½Uá÷’b)¿W‚ÈmÃUd‘µ|ädìÛ~EóæÚ2+V\o×ȉÿK¾úüŸ?"MÚ2à™½W‡hLWa1'‹ÆêµÇ¦O0‡Õí‚W0ײÉÊ&³ÎœªçMp-ôÉÇæÑ–U٦ȡìõgN?xæ!êéñ̬;%´p[•-ˆ¢£Âi9ïž äd%»…Çýn\=‡š[»=o F‘A©)ãWÆw·ÑQ‘´´8ßGøðvQQ¼CÄä>p–Üi›`é¨FFdSbBaË G”ðÌcd+ÙÂgSI¶šÒﱈx‘@m·Û3h·^DŒ¨A ä@ð
¹*÷meØæÝãµÉÃ6"t·ÌðmNLŽ€hje‘U™!¹yU§@C:9’   x¡x@emå͉·¬ÎUëjy-!D>  ¶‰ïž¨£µK3ÕVC>HI€!^AX©þ.’¬¾ÿX-¯ I뗙ÎËßL~ð8ÌC²·ã½kòØ&Ú¯à©Y”|C‚—o8‡L*7*;èì¥fÚº6R’¢C¦Ü(hhÞ>RC%3hfAbh¥˜^S‡ÀšÀÏTÜÇ?DYé1´ra¦¾#5¥ÔÈ»'l[—§6eê¢p !ÁË#C;’ÊÛ¨ìzm_Ÿ§6¨m(}¨èì ½üѱ°(•æÎôï’¡Æå@$Á+kÍ$yîíb-W5MÏK %óÒM’+düCàìåfª¬ë¦í¬ýЉŽðO&*€€é	@ð2}™3ƒ¥W[èÆíN¥åŠ‹Å:¼æ¬%äÊ×zziÏáj*šž¬4`¾N队/óב©rØÍ³¹ö²–ë®ÉT<'ÕTyCf@À,Ê®µÒµŠÚÆÃïñø01Kµ  `
¼LQ
‘‰3—š¨ª¾‡îã¡”h¥F¥!—~#Ð×?¤l¿
³ãio; B‚ÚSíôöÑZ|W*±¦@ÀuWoµSéÕV5ó1)!Êõ€ð	 ” xeµz®PÇÏ7Ps[¿úÑmà@Ü' ÃËÌÇŒÔZ½Ë­¸O!@ x@ð
žºôhIZÚûhÿÑZZ¹(ƒfäcHÂEd!KàfU'ºØ¤NMÆÃ!ÛPð&Á+¤«ßvá®§nž%«ÏcC`ÛŒp&K`hh˜·Ôª¡„8ÞRk9¶Ôš,G„@%Á+PkÎùnhî¥÷NÖñþŠY”ÏÁp Þ#PÅk~9×@÷¬Ê¡Ì´Xï%„˜ALE‚—©ªÃ?™V×°…h3ÿ@ËåŸz@ª¡GÀbá¾w¢ŽûÑÝ+s(,›n‡^+@‰C¯P«q«òÖ4tÓ¡3
ê¥ŸŽ¯n+<¸Ÿ¨oê¡÷OÕÓ†åY”—m³O #ð^~ïd+ëºè0Yy`&É—öþcµNKrü‘¤	 `EàÀ©:êë·Ð½kr•æùµý´|~:MÇ+R¸À%€½^·îÜÎùo_¹N¥×Z(?+Žêšzi¿ÜÓSbÜŽ@¼C`5·õÑ+û*¨ 'žþ¼ë&½ÜBÿû+K½“ bð9,ÌäsäþIðø…F:¥…5]DÚYNo¡Ë?UTAÀ!ùzbÛtÕOyäEÞ~è(áÃ^ÁQK1À‹7þþµë£~Zúèµw*F¯q `.»VQMCÏh¦~ÿúuêç-ˆà@Ÿl¼¿–@¾–ßçe"2ÓbÔ´õL^=;73ŽfOKr@|OàZE;Uó¾¨M­}ÔØÒ«ŽWdÃÓ÷UAÀã xy)"Û0Ôh›î‚€€€€Ç	„ü¬Æ7ß¾@Uµm‹ýC œ ÜñÄ*ŠÃ>xþ©Ϧºs)UÖ´z6RÄæ7Ò??öØJJLÀlj¿Uö;¼þò÷³ÔÔGáa~¯d`ê†úêé±û—@ðš:JSÄð×7ÏRC{,¯i…þiŠ
™b&,ýõôÈöż¦ÈÁ›@È^R}±q©Ø5‰Ü+CíH°5…˜øTŠ@ÿŠj
ÃèBPT$
1%°ñš>×	@ðr|‚€€€À”@ðš>×	@ðr|‚€€€À”@ðš>×	@ðr|‚€€€À”@ðš>×	@ðr|‚€€€À”@ðš>×	@ðr•Ï|Z†èê™—|–^GK%U_?ì³ô2ôÏ@®=äüO‚—“:(;ú;:ýö¿9ôÕ×ÓJý½ý8zØÕVCÖ¡Q/ì§Ûü_œ+étãĘfXX8xëÇ44ØïFÞñ:U–ÞÉb5+WúÇTÛ”±¯ôÏÉ¿ëÌÚŽ/ð%^NhçÏÙH3>èÐ×…Ïѵ³/;ôãèáž?O¢uÒîVÙnš6ï^uéJú:œ;Gcš‰©ù”Àÿ«¯t'
¯ø*K¯d
‘š–€+ýcªmÊØWúçäßu¦mHÈø6Év»­ñºÒf¥çΧ“{~B‹¨âÒ^ÞT;š–Ýó5jª¾Hµ·Nðu$5ՔцǞ¥º›'x¨ðo4ÄC†3‹ï§Y‹¢îö:ºxä·”˜V †õ$¾¥w™Nïÿwè£Sû~J93VÒ¼•£†ÊóT²í[*gÎÒOJ›¦´e—Oþ‰jËQxd4Í^ô0n[¨¼t'ç¡ŸŠ–>®â:øêwT¼¼0.Íⵟ¦¼Yk©ææ±QOè<'¤æqžQö´åìo=]<üŽþ‰þ‰^¾ —Êòrno*W¾«Kéò±?²`ñ
±tþÀ/YX*a¡¢ˆ²—SñÚOQWk5ÝõC >®3ïüŒšk/Ñ@Ý,ÛÅš­Û´xãÓTyõ=VÞ¥¢eO(¡mö’GYèÙÊiÝ¢áá!ŠOÊQi:K_<‰ÐuëÒZ²éK4‹µsÇßúµÔ]á´*¨ƒãÓ®¾â
tOHSž'$窼k¿rÔy–¡–…ë>G×μLÇw?KóV}œÂÃ#X{Ayïína!ò«J”¯ö¦›*lyé›ÔÓQOËïù:¥©ìب¯§M
Ý–l{FÅ>Ö­Y†…EÐÁ×¾CâGѨØ$áqxØbÌ&ÎC˜€³þaݦÂÃ#éÐëßSí}·Á–ú+t‘Í	t[GÿDÿáî„¢ûˆÀدžôd–ßû
¥™šYüu²%Øø4ŠOΡŒ¼…Jk—Ášr%dÅ&¤+­˜”[´=«îûg¥9íVW[µÚÂXˆI˞˚œBþòn§˜¸T%ŒÙbe¾ø©¾vPiÊ2òÒôùÛ”XWqÒVpuOEcšr36!“ºÚk'„‘<¯ÜþO¬éZC©9sYXú¬ÒŽM›¿•:Y ·`õS*ß­U›H=꾄-Ùþm’|ÍYòs(Uktl²ÒÆÄ¥ŒÓ^Y³ìÖú«´tóW™Ï]J°ìl­dnó©ÄŸ'`Ý?¬ÛTOgƒjSÒ6*Ïr_KQýý3ä;ø„/71‹!º¸ðˆ¨qñ:šÁÖEREÊÿy%¥\ZÄIذ°0ua'¼D‡=g+ýÖbEE'Œ‰à!¼cý0—µCCƒ½£yˆOŒy–ri
Uxxq"<¼8@{_|Z
öò° 
…–Yó¿ˆÈÚ¶ãWK{ÿøEÖü½?Àêl°¿›ÓPš<ŠàaTq¢„[lõ£¿Á^u§úgîŒÕüáðIuÏØ^Ñ?Ñ?íç àI¼<@S„°nR“!0Ñ&ÉpZá]›iÎÒǨpîfÖŠe:LE„ùæ!Æ,ž]اlN2Ö.¦ó0QÞ¸0¸GŒmJ†ðå#EÚºôO±¿LÍšã(8H¡j@F–蟚
Ž à×;ã¥4T#Z*ÑVé/jã×±¦zãûÊnë‘ÿö
åÍ^G;³C
Aö÷ÐæËr–ÿ:ÑœÝÑ~ÌÙ l™
‹î¦Õ~Ÿ‡ë’©µáš’ñã8ýE>O‡_ÿ>½ö‹G”à#ÆôÙÓWPRú4ºvîUzõçSZÎ<5¬"ùgLs͇þ'µ5\§ôÜ£Ù9Ÿçñåçx8ÿb\ŸY°„v¾°ƒ¢bâ•FKG"Ú±²£¿¥»D	)y<ìø-5ùî_¿¡ò2À­ýX{WG#ËG¿ô*Ûv=C§÷ý+e[9Ñ®~à»v‡aÇE„‹Ð àfÿTmЇ¿Oîý:÷þ/H>¬ùåL_‰þ‰þ}¥ô;0Ö²‡üžŸgàÓßüuYòÔС­Ä•!7#’}™µ'³µ3^«!Ö숈8YK†Ê¢Y뤇þÕP A“uÀ"£ãÔ0¥Ìž”/Ë[¿9¢Ar!}IS~Dd˜Q†µ“ê•|DÅ$LÈ¿NSü¾ñ«'iý#? ¬Âe:¨:ól±rY"ÆÊÃ×:-Ñò‰À(Àr¯­ñ½ó—¯Óc_~]i¹Dû¦ÄÓßölñ©×D¥«5K)ÈíÛX:.ëãP×
úå>LYIÖp€>ûÌ‹Ô>˜ËÂ<´mÃM¶J›’õ½¢cGû¾±­£z©v—ÓÏø$åd%Û¨MÜÐ 0ñW/4Êír)•f‹….qF¡ËúZf9j¡KžÉÙˆ`1¦å2†AÎ(tDóŒ=-Ä-R-ß ñ¸š¾øáJÇ!×âDè“ûâŒé˵N³‘—¯#dk¡Ë:ŒÄ­…HõÌ à‰`%Ï&¦~GÓ&!Fœø‘IÆòëgr´f)ñ7£_œ‡6Wû‡­6%“bäG;cÿ@ÿDÿÔíGð4h¼œh¼<
ÜÕøäKÞž`âjîøótz¢1e&&c7æN¾­ýBãeM$°¯i¼üU:O÷gåðtz~ëŸÐx9«j<Ðx™´’})t	O§'_]&­Jd+	xº¿8CäéôÐ?Çsð^Þc‹˜A@@@`^ãpà@@@¼G‚—÷Ø"fG‚×8¸ï€àå=¶ˆ@@@Æ[isÜíкèínæY}#ku…VɃ¯´–þ¾à+Tˆ—¨·«™×‡Cÿ†f€þµˆ2L•@È^Ÿ|r%ÕÖ·O•£i÷tX¨©ÍBíÝJŒ§ÔÄ0ÊN
Þ±°°i”cÚú@ÆÜ#ðÔÜ?ëÚÜ`¾k›‡¨«g˜,¼‡ÈìüŠ[t9ÀŠâ4»Ò?“cúƒf!¿€j0V®…ßà§Êš¨¦¾‡Š¦'QqQ*í>XEl, ›Utör3Í*L¤¥óÒƒ±ø(¡¡azëPÍž–Dóg¥PwÏ í¥&EKrß‘JMŽ¢•3ÇÝÇ€@p€àõØÜÚGGÏ7¨!Š
üµœ˜0¶ÿœ.žµà¥ïCÓ$p߸t£•Ê+;•:Üΰ╛mtõVÝ¿1Ÿ"#0Ê75ƒT@À7 xù†³WR)¯ì W[)=9šÖ.Í¢ÈHû/h{‚—ÎØ¨–ϰЀi.8‚€'	ì?*Ú¬hZQœá4ÚöÎ5¹eu.e¦Á.Ê)0x!Á+@*Jgsxx˜Î\j¦ÊÚnšÉvZKæ¦éGÎ/øVu§ŠF~-_àüÇA‡Ã@À>î^¶ßz¿Š6•dSN†ëö[Òßß:TMùlóµ6™öã	^RYýýl¿u®Úø+xk¤
sÜʹ«‚—Ž´‚°Ó,àAÓDpÉ™ÓeÍôÐæBŠŽ²¯•v{éÕª¬ë¦ûÖçóÒÁ;ëÑ<`!ÁËä5ÙÚÑOGÙ`žgšÓúeY”beˆëjöݼt¼Z›ž—àÒðˆ‡#€Ñá³õd±m\‘=e-í}$†÷ÛYø²6ÈŸräˆ@Àg xùµ{	UÔtÑ9^ö!91ŠÖ-Ëžô—²Nu²‚—/ù9ÍKT@ÓDpû-<´X©–l‘™Ãžr²TŒ,91gZ"-˜ê©h€€	@ðò!lW’:ÿA3¯µÕÅC‰ñlc•Î+ê{fXaª‚—ÎûíÚ.:u˜æ#X¨oî¥÷OÔÒƒwRBœwÖ¨>s©‰šZûiëÚ\½#¬Ëkï€àå®nÅ*_ÇÇØ~«¹­Ÿ
hÓÈ“_È:#ž¼t|Z›Æ¶f%a„¯¹àÚdqâ¼¶­Ëóº@$Þ{"àñÂȶ–	íš@éAÀ¼ xù±n:ºÔ‚§ƒCµDzŠ÷¶ºñ´à¥±U²ì$kÀD[Qì9
ŽG2¸‡gÊP¼ìá+7ÈmoòæBN³hz²¯’E: S Ák
ð&´º¾[ÙKÅó0Äz¶ßŠñþÞ‰Þ¼4-€Éi	¯Qä©!R?Ž `Víýjɇ­kóÈ›OŽÊüB#õð’›Wå:ò†g & Áˇ•Pv­•®UtP^vœN|9-ÜÛ‚—ÆXYǰÒ&*ÌaŒ‡ !€i28#+7ÛIV™pS!EDxÆs²œäƒN¶ûÐÝëÛ²Éæ
á@Æ@ðcá•3F”¯Q±ûðçp€¯/
˜&c°x—í«Yk½r‘yöT”õþþÎC¢užÁ»PÀ˜//ÕIWÏ >SO}l¿µ$Óï[~øZðÒX«xÑÇ¥TÀ°•Ѐi,80ÙŒþÍ÷*y™—,ÊÏÛˆÞLE:tºžµÍ¼öïÝ
 `.¼<\uM=,h4QltmXže•¿¿/W†ADóLÁ1	ˆ&÷øùFµ
}÷q3»Ñóyè1Úäy53Gä
h¦º¦^Ú¾.“[‚¦VQ3IÁKŒKÏ]n¡Šš.šžŸ@Ëæ¿i°^ºAÓ$pœ,™Èqø6•6ò“™Ûï¯
Šå3ŒåÂ9˜‰@H	^ý¼oâÑs
$v
Ëæ¥±Ð:›È›à¥;Q=kÀdˆ83=†÷ÄÌ‚LƒÁÑ!“©½s€î]“çÐ_(>ÔÆÊ¢Ðsg&‡"”¼J $¯öÎ~µà)+ºhÛp¤&G{ª#VÁK³†¦IàèˆÀÐЈ=SÑô$µÛ„#¿¡þì$ï,ÑÑ5H[Öä†:
”G„ŠG·L£>8ËõHBħÀd·õV˜ÌDÛ3_C8]ÖDÒÆ7±±·¬ãöþÉ:5´ˆe|[WR»xÖãÝlK—Í{©Š	ÈààpHî{é[òH-P	˜ba±ÇøÏÿºL?y¦DÙ܈!ôñ¬ÆW?‚x‘N®y%ðº;l, —÷U¼Ptx=¸©pr‘y¨Ì´Xz„…Ò&ž+Sçµ6Äÿ¯Ÿ¡¯í˜O‹îJr
Uzï4µÞlm<ÓjDÐÚÎ+uÇÇF(Â3±,"€ÝÏÚÂóWZTaÛx_¿ÿxñYXç/ï­Í„ÿâÏWF¯qâ¿{íºú€Ö©¿´ç–>Å@À@À¯CÏ¿t•.—·fçfu]»ÕNE3°1ë(œûØ£%T›êv®Ü¨£7ö–òªûX;Ë1¼0zdÛ"š7'DZ·)>Ýýn•~P=ÅX‚'øæ5E´jÙLSèØ™r:püº)óf¦LÉûé“O®¢ìL,ad¦z	ö¼Dî~ï2Ýnˆ ð¬hÞ×ÓNÓÒ¦,x½¾ïïÏ“ìmÇ¥ò
÷7Ó½ëçNJðª¨j¦ýÇ*)2Æ}¡Í¥Ì‰§ÞVZVGÑ1‰“/d„îo¢m)†àum¦"ª5£b’(2
P
x¤nÂyOĨ8ÓÖ5%¦QÜ.£ÁÒ1CK¯ãç|*èxƘQY,C¦ÏxTT<Å ï8¬'u8|އ à
°ñòUÄ	   6@ð²·@@@@À xyƒ*â xÙ€‚[    à
¼¼Aq‚€€€€
¼l@Á-ð^Þ Š8A@@@À^6 à€€€xƒ/oPEœ    `ƒ€i¯ö¦›T{ó˜,»«úÆaêh¹í~À QuíuµÕ¸TJ0Ád «g^r‰™'<»ûÝywt´TRõõÃî'‚np§^ÜŽÜFôPpË\¼öýñ‹,w˜aùžä6CƒýÔÝQ?.þó~EíÍãî¹zaŸtøËÇÿèjpŸû+;ú;:ýö¿9L׺L=ÛxØ×ÓJý½cÛcô÷¶Ó±]ÏRXxµÔ}@»~ó
[l„¹eF†Þn—Rrëv]ñÁ~ºÍÿŹRoÊ£¬ëÙŒÜ]-Ž+íʺ¼®Æ­ýÙ
¯ß®¤N'Þú1I<¡ä¼ÝwÕ‹pv%}wëÃ:Í@î;î–þ‹€K‚WñÚÏPZö\‡%ÛóâçY«Téн‡
•géÀ+ÿ4ú¸¯»•jo§is·ŒÞsçÄ:>‰§òê{4ÐßíN4>ó›?g#Í\ø Ãô¬Ëäг‡y(/iÉG ¸å÷~C±Yü—µšŽ/ÓÐ`¯b"u!êõ˜¸Öd]-šuø¸Ä,ŠOcá3‡5Œ•¿~f—0¢Q
xçĺŽz:Ô33t¥]ŠV#.!ƒ5§åêã!6!…õn®„!?&.U}dX3“kkîî¶]ùʷ¹Kþ\uÖíÊVßwôî°ækÝ/¬ßÖù²N_·ëXî]òn°.“'®]iûŽúŽ5W[õj|§[çÙVúîö[mAÒ	–¾cÍ×MÀmÁK„¥°°0Uêþ¡°eP?8ÐCá,,DðP™üP•l}F	(œ4íä‡F4ÖN42C½êö ÛeÉ‚hÄI§Âq>„+wþªøFòm¼oÆsmk¥øØ˜°0Rb&qŠqîŒÕ´`õ'G‹â,¼x”aG©'[Î^™™¡Ëí’Û–´Kù?¯ä£”;kBàJxù¸°ÇL"±æî°íRp¶][íIß³×®ôs9:zwXó5†ÓçÆw‡¾§öÒ—ýNÓ~CéèJÛWõb§ïØãjdè¨^l¥ï°ïé{ßÈçÁMÀmÁËŽðp“o`AhXiÀ†‡I†ÅŽ+gæ*þ¡²T=ëíj!‡‘ðqIY£³ئK„­ÛWÞQá+.ícíM:³åQLlŠš‘'³Yn²ý˜v’–1>¹ßó&åë'P±Lq‰Ù<,• Ê#|ÅN)5kŽÃ¢Ix™9ªUÑ‚YÏ$u?D†ÖíR¾¼ïÚ¬ÚeáÜÍNÛ„1|<·K±s4ÎuÄÌaÛe-ekråÃEìdÄÞKœ±ž¥/ˆDî*ã.ü±.¯hÏ=õîp!yåEúA ¿\-§»þŒm_êžc]¯Æwº+ùpØwBè½ï
+ø	<.	^ò58òµÉGþ7êdÜíWÁœ
Êëø®ÿ«ŒéÛ›oÑߟûúû¯ÿ‘NìþñHåw,üH\ajQ„«Wþ0UñìÃôÜ$6"$ˆ¦«„í˜NïûWzý—Ó…C¿¦Õ|Wilæ²Æ¢ôðôú¯'‹e€Ó‰[fëã³Xنꦊw4ïf:1pc=¢AÑ_âÆ2Õ–¡¶Ÿ¹/ì 7˜³Lí…â›<#ïµ_<ʬ™Å|jãI#ÎP¯†¼È3]Gfd8VVCþU¦m·Ë¼YëÔ䊿١–ÏØóûÏñð³,câZø„”e‹ØÚpMR‘
’?wN9érçŽVØQÛ>+O4iVm¾üâ®;á‰}ÁŒÜU!]úcàj`¥)nË+ql½;ÆêzŒ¯ÄcìÖï©ù§œôå™ôyç„’ãi`$ì¼Óíö;\׋Ô!§«ú‹íôõ yï‡RƒCYÇûòwÿßpm{êèPÞ¸§w.D%CSâŒçj˜q´‘Òˆ­–¨•Å)ãUþjbzqJÛÂ×2d#Î:®þ¾54)÷ßxîIZ÷ðÿQ¶dâW¾þGllRärÔ‰¶Khö⫽yB­Õóðÿ6úã8ᤛm™>õØ,úÈÃ%†»îŸ>ñôs•4Ïå€F.Æ2HÆká­É3a"ësEÇ$*M‰µëk5T(Â,ۆɺlÇyý¢G¾ø’b¢Ó1æÅÞU†Æè,½Uô½¯l¢å‹¦o»t¾ïÀ%úÏÏStÂÈ$ë@:Ïrßxî¬]J›‘¡ŒhÖ:É€;áev|ͯØúÍqíÙ˜þ„øì´]©?ɇôcxc=»Â½¯«–¾²c1Ý¿¹X•Å[¾ýì«t­&†µ­ñ.'¡Ëe¯]IDÆòꈭß:ýÜxm/÷ïíÏ^ú"ؾñ«'iý#? ¬Âe:z§Çžyýèæzzǧ~ýáá—8@;6ò¬òt›Ék.òÐxînß±ÇUâu¥^œ¥ïí÷¾¥ç6ýðl¥wåIVà@À'\Òxi¡Krd<JåË31nÕB—\ËZ]r­4ãl¼ÆV³¸ÄLœ¤1gÉc$+«k'?bDníD³¥í¿¬ó¦ã«ºö>ÍYöø¸¼ZÇãÏk#c$OÆk##y&LÄh^íŒþåžñZ8‰Ð%.gÆ*Ø’F˵?c^ÄŸ¾oF†:oÆ|ʹ³v)m&&>UñÿâŒq9
_´üI^Jå
cde?!>;mWêO÷cxc=›‘»*¼‹t¹Œ¬&ð1ô}­õ»CÇ£Ÿ¯¼ä¾ñÝ¡ýÙK¿‘—‘aww„.‡@>j.Rã¹£¶/~­ûŽ=®âוzŽÒæ÷¾”.4	ŒI>&+ÿ¢
ŸÑ(x _¢Üy¡ÝÿéߺÌG؉­ËCOÿe¤—ÏÀÝ}Àî¼;²§¯ íŸøµû‰ „ÛÜ©·#·}ÇÜ2SK#ž–<)j̃™p‡‹;~=˜ESFåK¾LË”°'™)w¸¹ãw’ÙA°;|ÉÚ—i¡‚AÀ¦¼Ü)ü‚€€€€Ù	@ð2{
!   AC‚WÐT%
   `v¼Ì^CÈ€€€@Ѐà4U‰‚€€€˜€ZN¢¯»‰î,zjö{3}‰Þbæm^ê<W GbèšRúz;iÈ–Ž ö{¦Ý:JC?“wÅ`‡¾Ùã@·ë°S—¿¿·1•=àì°tÛ{„û à5‘ŸùÇÕTÛ€©&¼h^¾>ôñkŸ¹›zzñÂÓ§å§éS·Žó‹réËŸXåV˜Põ\<×û+oüÑTYÛªˆ'”{þó
^[ÖÏ¥ü܉NO(nP^68¡ø–Àÿ¸†â§ãIEND®B`‚loki-ecmwf-0.3.6/example/05_argument_intent_linter.ipynb0000664000175000017500000040166715167130205023541 0ustar  alastairalastair{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "29b239f0",
   "metadata": {},
   "source": [
    "# Argument Intent Linter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e6ef992",
   "metadata": {},
   "source": [
    "The previous notebooks have all focused on using Loki to perform code transformations. Loki can also be used as a Fortran linter. In this notebook, we will use Loki to check whether the declared intent of a subroutine argument is consistent with how that variable is used.\n",
    "\n",
    "For brevity, only the core functionality of a subroutine dummy argument intent-linter is developed here.\n",
    "\n",
    "Let us start by first examining the sample subroutine that we will use to illustrate this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "65da8ec9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "MODULE kernel_mod\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  CONTAINS\n",
      "  SUBROUTINE some_kernel (n, vout, var_out, var_in, var_inout, b, l, h, y)\n",
      "    \n",
      "    INTEGER(KIND=jpim), INTENT(IN) :: n, l, b\n",
      "    INTEGER(KIND=jpim), INTENT(IN) :: h\n",
      "    REAL(KIND=jprb), INTENT(IN) :: var_in(n)\n",
      "    REAL(KIND=jprb), INTENT(INOUT) :: var_out(n)\n",
      "    REAL(KIND=jprb), INTENT(INOUT) :: var_inout(n)\n",
      "    REAL(KIND=jprb), INTENT(INOUT) :: vout(n)\n",
      "    REAL(KIND=jprb), INTENT(INOUT) :: y(:)\n",
      "    \n",
      "  END SUBROUTINE some_kernel\n",
      "END MODULE kernel_mod\n",
      "\n",
      "SUBROUTINE intent_test (m, n, var_in, var_out, var_inout, tendency_loc)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  USE kernel_mod, ONLY: some_kernel\n",
      "  USE yoecldp, ONLY: nclv\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: m, n\n",
      "  INTEGER(KIND=jpim) :: i, j, k, h, l\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), TARGET, INTENT(OUT) :: var_out(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(INOUT) :: var_inout(n, n, n)\n",
      "  REAL(KIND=jprb), ALLOCATABLE :: x(:), y(:)\n",
      "  REAL(KIND=jprb), POINTER :: vout(n)\n",
      "  TYPE(state_type), INTENT(OUT) :: tendency_loc\n",
      "  \n",
      "  ALLOCATE (x(n))\n",
      "  ASSOCIATE (mtmp=>m)\n",
      "  ALLOCATE (y(mtmp))\n",
      "  END ASSOCIATE\n",
      "  \n",
      "  ASSOCIATE (mtmp=>n)\n",
      "  DO k=1,mtmp\n",
      "    DO j=1,mtmp\n",
      "      DO i=1,mtmp\n",
      "        var_out(i, j, k) = 2._jprb\n",
      "      END DO\n",
      "      \n",
      "      ASSOCIATE (mbuf=>mtmp)\n",
      "      var_out(m:mbuf, j, k) = var_in(m:mbuf, j, k) + var_inout(m:mbuf, j, k) + var_out(m:mbuf, j, k)\n",
      "      END ASSOCIATE\n",
      "      \n",
      "      vout => var_out(:, j, k)\n",
      "      \n",
      "      ASSOCIATE (vin=>mtmp)\n",
      "      CALL some_kernel(vin, vout, vout, var_in(:, j, k), var_inout(:, j, k), 1, h=vin, l=5, y=y)\n",
      "      END ASSOCIATE\n",
      "      \n",
      "      NULLIFY (vout)\n",
      "      \n",
      "      ASSOCIATE (vout=>tendency_loc%cld(:, j, k))\n",
      "      \n",
      "      ASSOCIATE (vin=>var_in(:, j, k))\n",
      "      CALL some_kernel(mtmp, vout, var_out(:, j, k), vin, var_inout(:, j, k), 1, h=mtmp, l=5, y=y)\n",
      "      END ASSOCIATE\n",
      "      \n",
      "      END ASSOCIATE\n",
      "      \n",
      "      DO i=1,mtmp\n",
      "        var_inout(i, j, k) = var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  END ASSOCIATE\n",
      "  \n",
      "  DEALLOCATE (x)\n",
      "  DEALLOCATE (y)\n",
      "  \n",
      "END SUBROUTINE intent_test\n"
     ]
    }
   ],
   "source": [
    "from loki import Sourcefile\n",
    "\n",
    "source = Sourcefile.from_file('src/intent_test.F90')\n",
    "print(source.to_fortran())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b365ec21",
   "metadata": {},
   "source": [
    "## Retrieving variable intent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71b72f2e",
   "metadata": {},
   "source": [
    "We can access all the variables declared in subroutine `intent_test` using the `variables` property of the [_Subroutine_](https://sites.ecmwf.int/docs/loki/main/loki.subroutine.html#loki.subroutine.Subroutine) object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0566471b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vars:: m, n, i, j, k, h, l, var_in(n, n, n), var_out(n, n, n), var_inout(n, n, n), x(:), y(:), vout(n), tendency_loc\n"
     ]
    }
   ],
   "source": [
    "routine = source['intent_test']\n",
    "print('vars::', ', '.join([str(v) for v in routine.variables]))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6244a0f",
   "metadata": {},
   "source": [
    "In the Loki IR, variables are stored as symbols with base `class` [_MetaSymbol_](https://sites.ecmwf.int/docs/loki/main/loki.expression.symbols.html#loki.expression.symbols.MetaSymbol) and the `intent` of a variable is stored in the `property` [`MetaSymbol.type`](https://sites.ecmwf.int/docs/loki/main/loki.expression.symbols.html#loki.expression.symbols.MetaSymbol.type). To retrieve all variables with declared intent, we only need to look through the subroutine `arguments`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3b251eb0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "in:: m, n, var_in(n, n, n) out:: var_out(n, n, n), tendency_loc inout:: var_inout(n, n, n)\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "intent_vars = defaultdict(list)\n",
    "for var in routine.arguments:\n",
    "    intent_vars[var.type.intent].append(var)\n",
    "\n",
    "in_vars = intent_vars['in']\n",
    "out_vars = intent_vars['out']\n",
    "inout_vars = intent_vars['inout']\n",
    "\n",
    "print('in::', ', '.join([str(v) for v in in_vars]), 'out::', ', '.join([str(v) for v in out_vars]), 'inout::', ','.join([str(v) for v in inout_vars]))\n",
    "assert all([len(in_vars) == 3, len(out_vars) == 2, len(inout_vars) == 1])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1732e23",
   "metadata": {},
   "source": [
    "## Separating variables from dimensions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82041e9c",
   "metadata": {},
   "source": [
    "In Loki, the most general way of retrieving the variables used in an expression or node is the [_FindVariables_](https://sites.ecmwf.int/docs/loki/main/loki.expression.expr_visitors.html#loki.expression.expr_visitors.FindVariables) visitor. In the IR nodes in the body of a subroutine, the `FindVariables` visitor will return variables that appear in their own right, as well as any variables used for array indexing. An example of this is seen when `FindVariables` is applied to an [_Allocation_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Allocation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "56c5c4e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x(n), n\n"
     ]
    }
   ],
   "source": [
    "from loki import FindNodes, FindVariables, Allocation\n",
    "\n",
    "alloc = FindNodes(Allocation).visit(routine.body)[0]\n",
    "alloc_vars = FindVariables().visit(alloc.variables)\n",
    "print(', '.join([str(v) for v in alloc_vars]))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2ffad04",
   "metadata": {},
   "source": [
    "Utilities to distingiush between variables and their dimensions can be constructed by wrapping small functions around `FindVariables`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2a1de40b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import Array, flatten\n",
    "\n",
    "def findvarsnotdims(o, return_vars=True):\n",
    "    \"\"\"Return list of variables excluding any array dimensions.\"\"\"\n",
    "\n",
    "    dims = flatten([FindVariables().visit(var.dimensions) for var in FindVariables().visit(o) if isinstance(var, Array)])\n",
    "\n",
    "#   remove duplicates from dims\n",
    "    dims = list(set(dims))\n",
    "\n",
    "    if return_vars:\n",
    "        return [var for var in FindVariables().visit(o) if not var in dims]\n",
    "\n",
    "    return [var.name for var in FindVariables().visit(o) if not var in dims]\n",
    "\n",
    "def finddimsnotvars(o, return_vars=True):\n",
    "    \"\"\"Return list of all array dimensions.\"\"\"\n",
    "\n",
    "    dims = flatten([FindVariables().visit(var.dimensions) for var in FindVariables().visit(o) if isinstance(var, Array)])\n",
    "\n",
    "#   remove duplicates from dims\n",
    "    dims = list(set(dims))\n",
    "\n",
    "    if return_vars:\n",
    "        return dims\n",
    "\n",
    "    return [var.name for var in dims]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "808792c4",
   "metadata": {},
   "source": [
    "A quick test reveals:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "09018be6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vars:['x']\n",
      "dims:['n']\n"
     ]
    }
   ],
   "source": [
    "print(f'vars:{findvarsnotdims(alloc.variables, return_vars=False)}')\n",
    "print(f'dims:{finddimsnotvars(alloc.variables, return_vars=False)}')\n",
    "\n",
    "assert len(findvarsnotdims(alloc.variables)) == 1\n",
    "assert len(finddimsnotvars(alloc.variables)) == 1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c242a047",
   "metadata": {},
   "source": [
    "## Resolving associations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bba91785",
   "metadata": {},
   "source": [
    "You may have noticed that `intent_test` contains several nested associations. The simplest way of dealing with these is to resolve all the associations before we begin linting the program:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "97e3e472",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "ALLOCATE (x(n))\n",
      "ALLOCATE (y(m))\n",
      "\n",
      "DO k=1,n\n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._jprb\n",
      "    END DO\n",
      "    \n",
      "    var_out(m:n, j, k) = var_in(m:n, j, k) + var_inout(m:n, j, k) + var_out(m:n, j, k)\n",
      "    \n",
      "    vout => var_out(:, j, k)\n",
      "    \n",
      "    CALL some_kernel(n, vout, vout, var_in(:, j, k), var_inout(:, j, k), 1, h=n, l=5, y=y)\n",
      "    \n",
      "    NULLIFY (vout)\n",
      "    \n",
      "    \n",
      "    CALL some_kernel(n, tendency_loc%cld(:, j, k), var_out(:, j, k), var_in(:, j, k), var_inout(:, j, k), 1, h=n, l=5, y=y)\n",
      "    \n",
      "    \n",
      "    DO i=1,n\n",
      "      var_inout(i, j, k) = var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END DO\n",
      "\n",
      "DEALLOCATE (x)\n",
      "DEALLOCATE (y)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from loki import fgen, SubstituteExpressions, Associate, Transformer\n",
    "\n",
    "assoc_map = {}\n",
    "for assoc in FindNodes(Associate).visit(routine.body):\n",
    "    vmap = {}\n",
    "    for rexpr, lexpr in assoc.associations:\n",
    "        vmap.update({var: rexpr for var in FindVariables().visit(assoc.body) if lexpr == var})\n",
    "    assoc_map[assoc] = SubstituteExpressions(vmap).visit(assoc.body)\n",
    "routine.body = Transformer(assoc_map).visit(routine.body)\n",
    "print(fgen(routine.body))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4797ecc6",
   "metadata": {},
   "source": [
    "In Loki, an [_Associate_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Associate) statement is a `ScopedNode`. A `ScopedNode` is a mix-in that attaches to an [_InternalNode_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.InternalNode). It declares a new scope that sits within the `Subroutine` scope, but also defines a few of its own symbols. This means that the new variables declared in an `Associate` statement are only in scope in the body of that particular node. Therefore to resolve the associations we can simply apply the [_SubstituteExpressions_](https://sites.ecmwf.int/docs/loki/main/loki.expression.expr_visitors.html#loki.expression.expr_visitors.SubstituteExpressions) visitor to the `Associate`'s body, as shown in the previous code-cell.\n",
    "\n",
    "Resolving pointer associations that use the `=>` operator is a little more involved. Firstly, pointers and targets must be declared in the `Subroutine` specification, unlike an `Associate` statement which declares new symbols. It would thus be wrong to think of a pointer association as having its own localised scope. Secondly, an associated pointer can be disassociated either by exiting the encompassing scope (i.e. exiting the `Subroutine`), by using the [_Nullify_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Nullify) intrinsic or by assigning to `NULL()`.\n",
    "\n",
    "Therefore before we can apply the `SubstituteExpressions` visitor to resolve pointer associations, we must first determine the *range* of nodes over which the pointer is valid. The best way to do so is to develop a bespoke visitor: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d8b139a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Comment:: \n",
      "Call:: some_kernel\n",
      "Comment:: \n"
     ]
    }
   ],
   "source": [
    "from loki import Assignment\n",
    "\n",
    "class FindPointerRange(FindNodes):\n",
    "    \"\"\"Visitor to find range of nodes over which pointer associations apply.\"\"\"\n",
    "\n",
    "\n",
    "    def __init__(self, match, greedy=False):\n",
    "\n",
    "        super().__init__(match, mode='type', greedy=greedy)\n",
    "        self.rule = lambda match, o: o == match\n",
    "        self.stat = False\n",
    "\n",
    "    def visit_Assignment(self, o, **kwargs):\n",
    "        \"\"\"\n",
    "        Check for pointer assignment (=>). Also check if pointer is disassociated,\n",
    "        else add the node to the returned list.\n",
    "        \"\"\"\n",
    "\n",
    "        ret = kwargs.pop('ret', self.default_retval())\n",
    "        if self.rule(self.match, o):\n",
    "            assert not self.stat # we should only visit the pointer assignment node once\n",
    "            self.stat = True\n",
    "        elif self.match.lhs in findvarsnotdims(o.lhs) and 'null' in [v.name.lower for v in findvarsnotdims(o.rhs)]:\n",
    "            assert self.stat\n",
    "            self.stat = False\n",
    "            ret.append(o)\n",
    "        elif self.stat:\n",
    "            ret.append(o)\n",
    "        return ret or self.default_retval()\n",
    "\n",
    "    def visit_Nullify(self, o, **kwargs):\n",
    "        \"\"\"\n",
    "        Check if pointer is disassociated, else add the node to the returned list.\n",
    "        \"\"\"\n",
    "\n",
    "        ret = kwargs.pop('ret', self.default_retval())\n",
    "        if self.match.lhs in findvarsnotdims(o.variables):\n",
    "            assert self.stat\n",
    "            self.stat = False\n",
    "            ret.append(o)\n",
    "        elif self.stat:\n",
    "            ret.append(o)\n",
    "        return ret or self.default_retval()\n",
    "\n",
    "    def visit_Node(self, o, **kwargs):\n",
    "        \"\"\"\n",
    "        Add the node to the returned list if stat is True and visit\n",
    "        all children.\n",
    "        \"\"\"\n",
    "\n",
    "        ret = kwargs.pop('ret', self.default_retval())\n",
    "        if self.stat:\n",
    "            ret.append(o)\n",
    "            if self.greedy:\n",
    "                return ret\n",
    "        for i in o.children:\n",
    "            ret = self.visit(i, ret=ret, **kwargs)\n",
    "        return ret or self.default_retval()\n",
    "\n",
    "for assign in [a for a in FindNodes(Assignment).visit(routine.body) if a.ptr]:\n",
    "    nodes = FindPointerRange(assign).visit(routine.body)\n",
    "    for node in nodes[:-1]:\n",
    "        print(node)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fceef8d",
   "metadata": {},
   "source": [
    "As the above output shows, our new visitor correctly identifies all three nodes over which the pointer association applies. We can now finally proceed to resolve the pointer association:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "aef2e705",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "ALLOCATE (x(n))\n",
      "ALLOCATE (y(m))\n",
      "\n",
      "DO k=1,n\n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._jprb\n",
      "    END DO\n",
      "    \n",
      "    var_out(m:n, j, k) = var_in(m:n, j, k) + var_inout(m:n, j, k) + var_out(m:n, j, k)\n",
      "    \n",
      "    \n",
      "    CALL some_kernel(n, var_out(:, j, k), var_out(:, j, k), var_in(:, j, k), var_inout(:, j, k), 1, h=n, l=5, y=y)\n",
      "    \n",
      "    \n",
      "    \n",
      "    CALL some_kernel(n, tendency_loc%cld(:, j, k), var_out(:, j, k), var_in(:, j, k), var_inout(:, j, k), 1, h=n, l=5, y=y)\n",
      "    \n",
      "    \n",
      "    DO i=1,n\n",
      "      var_inout(i, j, k) = var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END DO\n",
      "\n",
      "DEALLOCATE (x)\n",
      "DEALLOCATE (y)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pointer_map = {}\n",
    "for assign in [a for a in FindNodes(Assignment).visit(routine.body) if a.ptr]:\n",
    "    nodes = FindPointerRange(assign).visit(routine.body)\n",
    "    pointer_map[assign] = None\n",
    "    for node in nodes[:-1]:\n",
    "        vmap = {var: assign.rhs for var in FindVariables().visit(node) if assign.lhs == var}\n",
    "        pointer_map[node] = SubstituteExpressions(vmap).visit(node)\n",
    "    pointer_map[nodes[-1]] = None\n",
    "routine.body = Transformer(pointer_map).visit(routine.body)\n",
    "print(fgen(routine.body))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "feacf206",
   "metadata": {},
   "source": [
    "## Modifying variable values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "852fd85d",
   "metadata": {},
   "source": [
    "Putting aside function or subroutine calls for the moment (and also ignoring I/O), there are only two mechanisms for modifying the value of a variable. The obvious one is an [_Assignment_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Assignment) statement, where the `rhs` value is assigned to the `lhs` value.\n",
    "\n",
    "Values can also be assigned to a variable by using it as the induction variable of a loop. Although an extremely unusual practice, Fortran compilers do allow dummy arguments of kind `intent(out)` or `intent(inout)`  to be used as the induction variables of a loop. This is however (in my humble opinion) bad coding practice; for ease of readability, local variables rather than dummy arguments should be used as loop induction variables. Therefore in our linter rules, we will forbid the use of variables with declared `intent` as loop induction variables.\n",
    "\n",
    "To enable us to check all our linter rules in just one pass of the subroutine's IR, we will create a new visitor: `IntentLinterVisitor`. The next section creates the `IntentLinterVisitor` visitor and defines linter rules for checking `intent` consistency within the `Subroutine` body. We will later examine how this can be extended to subroutine calls."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4db06cc",
   "metadata": {},
   "source": [
    "## Checking `intent` in `Subroutine` body"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fec94f1c",
   "metadata": {},
   "source": [
    "Let us first define an initialization method for an instance of `IntentLinterVisitor`, and a method to check for rule violations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b9a024c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import Visitor\n",
    "\n",
    "class IntentLinterVisitor(Visitor):\n",
    "    \"\"\"Visitor to check for dummy argument intent violations.\"\"\"\n",
    "\n",
    "    def __init__(self, in_vars, out_vars, inout_vars):  # pylint: disable=redefined-outer-name\n",
    "        \"\"\"Initialise an instance of the intent linter visitor.\"\"\"\n",
    "\n",
    "        super().__init__()\n",
    "        self.in_vars = in_vars\n",
    "        self.out_vars = out_vars\n",
    "        self.inout_vars = inout_vars\n",
    "        self.var_check = {var: True for var in (in_vars + out_vars + inout_vars)}\n",
    "\n",
    "        self.vars_read = set(in_vars + inout_vars)\n",
    "        self.vars_written = set()\n",
    "        self.alloc_vars = set() # set of variables that are allocated\n",
    "\n",
    "    def rule_check(self):\n",
    "        \"\"\"Check rule-status for all variables with declared intent.\"\"\"\n",
    "\n",
    "        for v, s in self.var_check.items():\n",
    "            assert s, f'intent({v.type.intent}) rule broken for {v.name}'\n",
    "        print('All rules satisfied')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c679257",
   "metadata": {},
   "source": [
    "You may have noticed that in defining our new visitor, we also introduced an attribute called `alloc_vars`. This is intended to accumulate the variables allocated in a subroutine. The reason for including this attribute will be clear later on.\n",
    "\n",
    "We can now proceed to define the rules for our linter. The rule to check whether variables with declared intent are used as loop induction variables can be implemented as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f4a62f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visit_Loop(self, o, **kwargs):\n",
    "    \"\"\"\n",
    "    Check if loop induction variable has declared intent, update vars_read/vars_written\n",
    "    if variables with declared intent are used in loop bounds and visit any nodes in loop body.\n",
    "    \"\"\"\n",
    "\n",
    "    if o.variable.type.intent:\n",
    "        self.var_check[o.variable] = False\n",
    "        print(f'intent({o.variable.type.intent}) {o.variable.name} used as loop induction variable.')\n",
    "    for v in [v for v in FindVariables().visit(o.bounds) if v.type.intent]:\n",
    "        if v not in self.vars_read | self.vars_written:\n",
    "            print(f'undefined intent({v.type.intent}) variable {v.name} used for loop bounds.')\n",
    "            self.var_check[v] = False\n",
    "        self.vars_read.add(v)\n",
    "        self.vars_written.discard(v)\n",
    "    self.visit(o.body, **kwargs)\n",
    "\n",
    "IntentLinterVisitor.visit_Loop = visit_Loop\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42f4e4d1",
   "metadata": {},
   "source": [
    "For `intent(in)` variables, we don't want their value to be reassigned in the `Subroutine`. Therefore the rule for checking `intent(in)` variables is the simplest: variables of kind `intent(in)` should not appear in the `lhs` of an `Assignment`. The rule for `intent(out)` variables is that upon entry to a subroutine, a value must be assigned to them before the variable can be used: `intent(out)` variables must be written to before they can be read. For `Assignment` statements, the above rules can be implemented as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b28b2e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visit_Assignment(self, o):\n",
    "    \"\"\"Check intent rules for assignment statements.\"\"\"\n",
    "\n",
    "    if o.lhs.type.intent == 'in':\n",
    "        print(f'value of intent(in) var {o.lhs.name} modified')\n",
    "        self.var_check[o.lhs] = False\n",
    "\n",
    "    self.vars_written.add(o.lhs)\n",
    "    self.vars_read.discard(o.lhs)\n",
    "\n",
    "    for v in FindVariables().visit(o.rhs):\n",
    "        if v.type.intent == 'out' and v not in self.vars_read | self.vars_written:\n",
    "            print('intent(out) var read from before being written to.')\n",
    "            self.var_check[v] = False\n",
    "        elif v.type.intent:\n",
    "            self.vars_read.add(v)\n",
    "            self.vars_written.discard(v)\n",
    "\n",
    "IntentLinterVisitor.visit_Assignment = visit_Assignment\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccb50ad2",
   "metadata": {},
   "source": [
    "In principle we could also build similar checks for `intent(inout)` variables. However, the way in which some Fortran compilers treat `allocatable` variables prevents us from doing so.\n",
    "\n",
    "If an `allocatable` array is passed to a subroutine as a dummy argument of kind `intent(out)`, some Fortran compilers will deallocate that array upon exiting the subroutine. This is why in the IFS, data arrays are sometimes declared as `intent(inout)` even if their true intent is `intent(out)`. An example can be seen in the 'cloudsc-dwarf' in `src/cloudsc_driver_mod.F90`: the `REAL` array `PCOVPTOT` is declared `intent(inout)` even though it's value entering the subroutine is never used.\n",
    "\n",
    "An allocatable array passed as a dummy argument to a subroutine could thus belong to two possible categories:\n",
    "1. A variable that is truly of type `intent(inout)`\n",
    "2. A variable that is strictly of type `intent(out)`, but has been declared `intent(inout)` to avoid deallocation\n",
    "\n",
    "It would be very difficult to discern between the two options from a static analysis of the source code. As such, we will not impose any rules related to `Assignment` expressions for `intent(inout)` variables.\n",
    "\n",
    "We can however impose the rule that the dummy argument corresponding to an `allocatable` array must be of kind `intent(inout)` or `intent(in)`. To enable a `visit_CallStatement` method to perform this check, we first need a `visit_Allocation` method that updates the `alloc_vars` set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6b6b5179",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visit_Allocation(self, o):\n",
    "    \"\"\"\n",
    "    Update set of allocated variables and read/written sets for variables used to define\n",
    "    allocation size.\n",
    "    \"\"\"\n",
    "\n",
    "    self.alloc_vars.update(o.variables)\n",
    "    for v in [v for v in finddimsnotvars(o.variables) if v.type.intent]:\n",
    "        if v not in self.vars_read | self.vars_written:\n",
    "            print(f'undefined intent({v.type.intent}) variable {v.name} used to set allocation size.')\n",
    "            self.var_check[v] = False\n",
    "        self.vars_read.add(v)\n",
    "        self.vars_written.discard(v)\n",
    "\n",
    "IntentLinterVisitor.visit_Allocation = visit_Allocation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "feaa4c96",
   "metadata": {},
   "source": [
    "We are now ready to implement a `visit_CallSatement` method."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27f1275c",
   "metadata": {},
   "source": [
    "## Building `intent` map between function caller and callee"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df378590",
   "metadata": {},
   "source": [
    "We have already defined one linter rule for `CallStatement`s; dummy arguments corresponding to `allocatable` variables must of type `intent(in)` or `intent(inout)`. Another very important check we have to perform is to ensure that the declared `intent` of a dummy argument is consistent with the `intent` of the argument in the calling (parent) subroutine.\n",
    "\n",
    "For example, `var_in` is a variable of kind `intent(in)` in `Subroutine` `intent_test`. Therefore `var_in` must not be modified within `intent_test` or any subroutines called by `intent_test`. Hence in `some_kernel`, `var_in` must also be of kind `intent(in)`.\n",
    "\n",
    "The mapping for `intent(out)` variables is a little more complicated and depends on whether before the [_CallStatement_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.CallStatement), the variable in question has ever been written to, and if so, whether the value last assigned to it has been read at least once. The procedure for building the mapping is best illustrated using the flowchart below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d795ea6a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhwAAAF8CAYAAACE4mK7AAABJ2lDQ1BrQ0dDb2xvclNwYWNlQWRvYmVSR0IxOTk4AAAokWNgYFJILCjIYRJgYMjNKykKcndSiIiMUmB/xsDBwMMgzsDIIJ2YXFzgGBDgwwAEMBoVfLsGVAcEl3VBZmHK4wVcKanFyUD6DxBnJxcUlTAwMGYA2crlJQUgdg+QLZKUDWYvALGLgA4EsreA2OkQ9gmwGgj7DlhNSJAzkP0ByOZLArOZQHbxpUPYAiA21F4QEHRMyU9KVQD5XsPQ0tJCk0Q/EAQlqRUlINo5v6CyKDM9o0TBERhSqQqeecl6OgpGBkZGDAygcIeo/hwIDk9GsTMIMQRAiM2RYGDwX8rAwPIHIWbSy8CwQIeBgX8qQkzNkIFBQJ+BYd+c5NKiMqgxjEzGDAyE+AD210pB9M4YjwAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAACHKADAAQAAAABAAABfAAAAAAQzxnzAABAAElEQVR4AeydB5wcV5H/a3POWbta5ZxztIIl2cbZGBNszuYMNhmO4w5MOri7P+YOOOAS3NmHDQZjMDgnWcmWlXPOebXanHPef/3eqle9o5nd1YaZ2Znf00c7Pd390vdNd1fXq1cV0K5JmEiABEjACwhUVdfLkVN5XtASNmEoEpg2YZjExkQMxab7RZuD/aKX7CQJkMCQIHAxt1Se/K8NEhYRNyTay0Z6D4HG+kr50TfvkOmTsrynUWxJFwIUOLrg4BcSIAFPEwiPiJagiAxPN4P1DzECAdI6xFrsf80N9L8us8ckQAIkQAIkQALuJkCBw93EWR8JkAAJkAAJ+CEBChx+OOjsMgmQAAmQAAm4mwAFDncTZ30kQAIkQAIk4IcEKHD44aCzyyRAAiRAAiTgbgIUONxNnPWRAAmQAAmQgB8SoMDhh4POLpMACZAACZCAuwlQ4HA3cdZHAiRAAiRAAn5IgAKHHw46u0wCJEACJEAC7iZAgcPdxFkfCZAACZAACfghAQocfjjo7DIJkAAJkAAJuJsABQ53E2d9JEACJEACJOCHBChw+OGgs8skQAIkQAIk4G4CFDjcTZz1kQAJ+CyBttZmOXPgpQHpX9757VJdfnlAyrqRQopzD0lTQ/V1WVztv+7EQdiB9pQVnBiEklmkOwlQ4HAnbdZFAiTQLwItzQ2y+S9fk+bG2h7L2fD841JwcXe359VW5kt7W9/Dmjvmzzm1SS7r/76k1pYmqasu6sxaVXpRTu5+vvP7QGyUF56Sd559SNrb21wWt2/DvwnOc0yu9jueNxjfS/OOyt71PxmMolmmGwlQ4HAjbFZFAiTQPwLtbS1SdPmAQJPQU5q88FOSkDq+29PW/f7TqkXI7fac7g465r90fK0Mn3Bzd1lcHivOPShbXvlG5/Hh41dK7pnN0txU17mvvxvR8Vky/abPSUAAb/39Zcn8N04g+MazMAcJkAAJeJ4AtB2HNv+XlOnbeFzSKJl/27e7NCrv3FaJjE2V1pZGObbjNxKdkCl557ZLYvpEmbHs87J/0y+ktblR9m34qaSNmCuTFz4i+ed36JTIX6RVBZqRk2+VUVPvkEvH35WGunKpKD4ntZV5MmbGPTJi0i3mjduef8Lcj0tx7mGZs/rvTDvqa0rk6LZfS2XpeYmOGyZTFj8qMQnD5cLRt7X8Jhk7415z3tZXvyUT539Sjm7/tdRVFRqhY8LcT0jq8FkSGZMqJTrFkTF6UWff9q77sQyfuErSsueYfTvf+idt072qWdloph1CwmNk8oKHJSVrhml7bVWBtuugJKRNNHVeOrFOMsfeJKX5x+Tknj+YOmOTRsrMFV+SsIh4U2bBpT1yfNdvJTgkQqYueUwFt3Gd9WMDmpgjW5+SGhXWEjMmy7Qln5Hg0MjOc5z1EVxCwqKuGzNXZVWVXtI6
/leaGmsE7WMa+gQo5g79MWQPSMAvCeDBWanTDkvu/qFkjFp4HYMSVcM31lWohqBWLh5/x9hDTFv6mNEaXD79voydeZ8EBgXL6Ol3q1ZilVSX5cjOd/5Z8LCfOO9BOfDef5oHeE3FFX3wPS3JmdNkhAohUO1D6+CYHw/I9vZWFRLS9LNdtr72LQkIDDQCCIQACBaYyqguz5FqPddKRTn7JCQ0QrK1DaHhsTJJhYX4lLHmcGRsutSokGNPwXru+cNvmF0VRWcEth5BIWESHZ8pC27/nqRkTpf9G39mjqPtx3Y8K1njlsswFVrAAhoiJBwbNeV2WXjH943Qcfbgq2Y//kDwmrr4MyqwpcvON39g+tN5UDe2vf4didJjs1d/XcqLTsuxnb+1H3bax5bmOnE2Zs7KAr8db31fIqJTZNL8v5Liq23uUgm/DDkCFDiG3JCxwSRAAiCAB15lyXm5cuYDSR+5oFsoePued8sTRmsAbQY0FXioBwQGmWmXmIQsyb+4SyKiklSIuWCEk/CoRCnNO2bKHTH5Fhmjggn+Yzqivrr4uvxNDVVGQwAhpq66UCAMzFj+RaMdwDRGTUWu1lvgtJ2BQSHmLR7CRPKwqSp4xJjzIqKTVRjomid74hoVCLZLiwo9l0+/Z7QViWkTjNAEHkjQrlgJwhQ0IMkqiNgTtDTh2t+ygpMSEh7dJQ+0HdCQTFGtT03lFRXcyjqz1tcUm76BKTQnYRFxRlvSeUI3G45j5qqshtoygQ0LBMT0kfPMNFA3xfLQECFAgWOIDBSbSQIk0JVA+sj5svjOf5K8Cztk7W8flsb6yq4n2L5BSAgICDB7gvTh7sxQtKW5XgIDgyUoKNT8nzDnY5I+qkOQCVTBxEoQDqDJcEwQXlAGEoSBgIAgnZIIN9+DgkPNZ0e+AKPpMDt6+INpI5GOdlunYkooQqdaoNmAwDFi0hojGL373CP64D+udXe0wTof2hNn6fCW/5Xd7z4ptVX5ahPToqe0d55m2XgE6ZQKEjQOVupok5jpFrBKHzHfaCGs4x2fzvvoOGZ1KrghYerGXpbFMegqP6s9HWXz71AlQIFjqI4c200Cfk4AUyB4C192349VgGgxWokbRRIYGKJv9sXmgQqNB4QWTD/ATiNr/HJV6Sd3W6Q9f2RMirEXwRLOKLXZgJABgQAp58QGCY9MVK1MhtEIlKv2A0IP7ENgz4EEQaahttw8/K0HfL3aSjhrA4SMI9v+TwWcBmPLUZizV5LSJxvblPjUjukYU2g3f/Iv7NTpo4+rsPBJ88C3n4qpDwghuTr1BLsOaHushCmjkNAo0y5wgp1LfMoY67D5hNbDWR8dxwz2Nc7KAksIIZjagYHwpRPru5TPL0OTAI1Gh+a4sdUk4KcErr7t6wfe8GH0iDfjuOTR161IgUaj481YP+1aAqzQuKrtyByzxNhaZI1dJvM/9F1jnPn2sw+qcJBgNAXLP/Lzq+de0zJ0lNXx3TE/bDAqis+aqZs5at+wf8PP5KDagkD7AaNWTLdkq8Hn2UOvyqu/vFMNOScYAQRaDBhmQkjB/vm3PiHDxi69Oq3w+HVjnT1xtTGEHTf7I6bsYdqPMwdeljeffkDtLtKunW/6ea3tqMdiMXLKbXJg4y/k+I7fSJj2F21HAjcYcr7y37erxidIp6K+aThaPNHGOWv+3tiyHPrgV2aJ8qQFn5Txsx/orNdVH/POb+syZkkZU1yWNW3p47Lz7X9WgSRSknSaCW1nGtoEAlSSvqYrG9p9YetJgASGOIHDJ3LlH/9jswRFZLnsCd688eBGwtRFm2oKLJsHeyb7efZtM53SKYyIcXIF2wlMpyDBHwbKDdW3dDxkjc8KvU1CaECyl4Xv0GhY+bGCBJqK2au+hkNGc9Jh2xFnvlt/cNtFHVi1YS8PbWtqrDZahYKLe2TPuz+SOx//i3ngW3mtT+RDm9BGpDbV8mA6BSywjf44tt2cZ+enU0CY+kGbVeViyrPag3bD
TsPiYu03lekf9KGxvkJCw6I78lsHrn666qOzMXNVFsZCa1JBLKwLJ4eqzNfW+lz5/leWy/RJrn87zvJxn/sIcErFfaxZEwmQwAAQsIQNFIUHojNhA8fs59m3Ox7S1259yG89VJEPb/BhkfGdD3Jj/3FV2HAsF9/t+cfO+rAuvd2G3SZBGMD0gmPCfggbSI5ts5amXjn7gYyZea9TYcPKZwkb5rsKGBYLqz+Obbfy4RPJ2E5of1GOJVBZ7YHGwyoH51r7sY2EPNAEGWGlY1eXv6766GzMXJWFsYCwgeRYf5fK+GVIEOCUypAYJjaSBEhgKBCAHcgdj704IE2FloTGkgOCkoV4CYFrYr6XNIjNIAESIIGhTGCghISBKmcos2TbfYsABQ7fGk/2hgRIgARIgAS8kgAFDq8cFjaKBEigvwSwZLTw0t4bKgZGilgu6snIpPZ2w1lYZcmFG+qDt58MHxvoU7O6LGfyLwK04fCv8WZvScBvCMCb6I63fiD3fuHNXvcZq0Kq9SEPnxaJ6ZN6na+nE+FLAkt44VYcK0ewDHfsrPuN4aVjXnu7Lx1fp3Fcyox7dOQ9ufv3xldI8rBpMnPll83KDawUQcyV7hKi2uIcyzC0u3MH+pi9bng13frqN43jMvjkmKZxWrC0l8k/CFDD4R/jzF6SAAn0gkDJlSMya+VXO5e19iJLr07Z/uY/mBgu8FWB/7m6AqWpG8+ojoVCYNn22rc1ZswiI3wE6uoNLFN1jDDrmM/67hjV1trvjk973TGJ2fKhv35e1jz0tCy8/ftyWAPAOfP66o52sQ73E6CGw/3MWSMJkICbCECbgEix+er+HKHqZyz/glkK6ixCKZxYNdSXy+EPfqneM+/UGB7znUZ7hXfQLhFYNQhcd5FTMT1TcGG3fOjR5038F3Qd0VqRzh58Rb2RbjLCA/YhaJyz1NxYa9ymwwEWXJvjf01F3nURZrGM1DECLALP2aPauoqKi0i1YBUVn2GW9iJabcaoxXJMo9jCxfislV8xAeKcsYMvkt5G5LX6B+GjTb2swo9KkG3ZsXWcn75HgBoO3xtT9ogESOAqATiZguvymSu+LCV5R+TErt+ZI84ilCL6a3BwuIyadqekauh3V9FeHSOwOivLPgBw8R2rD1cELnNMCHs/Y9kXjUtyPLARsMxZgl+QrHErZPNLfyvHNTIrBBC4PHeMMOssAqxjVFtXUXGtqLqYApmy6FE5q55Ld6/9oUyY9wn1xxGkgsczpmnO+mvlrS6/bAKu5Z7ZrILU9RF5rb6dP/KG7Fn3Lyp4LTN+T6z9/PRtAhQ4fHt82TsS8GsCcDI1Z/XfmQiso6fdJZgycRWhFO7RA9R5VkLKOGXW3m20VysCa1RcRo+RU2EcaQVBcxyMSfMfUk+lVSZUfKhDxFbHcxfe8Q8ya8VX5KJqWN597lPGw2ls0kjj5dSKMOssAuyNRMUFr7lrvqGajQUSnzZeBY+/1u2FMlzdsdeoIOKKHdqKvD1F5MV5mB7at+HfBIIQ+sTkPwQ4peI/Y82ekoDfETCeNtUjJhI8VrbrP3u0UxxHtNNIB+0DNCOuo72KxvfoiKLam7IQyK26PMcEdrO8ZqI9ePCu//1jGpZ+hE6RTIaM021CWxH/JGv8ClmnAseVs5t1iqOrG29EgEW8EsRagY2Hs0LtUXFRIaLipmbPNu2x84KX0YDAjndSBKmD6/Pu+mvP6yoiL+qDUIeyM8cs7eLJFMeYfJsANRy+Pb7sHQn4NQFMPUC1j1gdcBUenzxGV2v0HO20u2ivdqC9KQu2IJiSOLj5v40dRrMKM6f2viBlhaeNIAKtAKK/mrgn9sJt25gGwXJd9KNVl/u2qrCCWC+OEWZdRYC1R7XtS1Rcqym96a91rvVprxv7YCQKL6o0FrUI+c8nNRz+M9bsKQn4GQGNY6K2D8d2PCN71j4pkXHpZnqlu2inRhmif4LVSNJVtFdVfShHS2vSc+RURDtdcveTsnfD
T+TV/74DQUiMQeqoaXdLcuZ0efuZBzWuSmRnzBCUbUV0teqCoLF/0y/MyhYEZhs+fqUMV5sOaCvQHyvCrKsIsI5RbTNGL5LrouLa60UPtZ3QWiBhG21xxS4te+61NpsMmg95NNnrXnD793RZb7mJNItIuRB+mPyHAKPF+s9Ys6ck4PUEehMt9kY6gYczHppYReEYRM1ZhFJnEVEdo706i8DqrCxn7TRTNTqdgAe3lRp1eSwCpbW3t3ZOMVjtcKyrsa7C2GzYp2bsEWZRJoQQxwiw2G+PaovvjlFxsc+q12wbdl2j0VrB3Jz1157XaC9sAouzuu0MUF9/E6PF9pfg4OenhmPwGbMGEiABDxGwHpCOwgaaY0UotTfNMSIpznHMa976O17eO7M6K6vzoG0DhpWOySo/IODa7dhqh2Nd0Ng4JjjzsiLM4hgiwHYmDT1vJSuSrPUdD3zHh75VL86x2Fnn27876689r6ODMWd1W+Xy038I0IbDf8aaPSUBEiABEiABjxGgwOEx9KyYBEiABEiABPyHAAUO/xlr9pQESIAESIAEPEaAAofH0LNiEiABEiABEvAfAhQ4/Ges2VMSIAESIAES8BgBChweQ8+KSYAESIAESMB/CFDg8J+xZk9JgARIgARIwGMEKHB4DD0rJgESIAESIAH/IUCBw3/Gmj0lARIgARIgAY8RoMDhMfSsmARIgARIgAT8hwAFDv8Za/aUBEiABEiABDxG4Jrzfo81gRWTAAmQwDUCbW1tGoGs6doObpFALwiY300vzuMpniNAgcNz7FkzCZCAA4HoqHBJTQiW1rZihyP82twWJm2ikWalRYIDGwnEgUBQZLDg98PkvQQYnt57x4YtIwESIAGpqWuWTbsKZOKoWBk/Mk5OX6yUk+er5OYF6fqADSEhEhgyBChwDJmhYkNJgAT8jcD+46WSV1Qnty7NlJDgayZ3zS1tsm5bnmSkRMjsyUn+hoX9HaIEKHAM0YFjs0mABHyXQGV1k2zcmS+zJiXKqKwYlx29kFstB06Uyc0LMyQ+JtTleTxAAt5AgAKHN4wC20ACJEACVwnsOlwsZZVNcsviYRIUFNAjl9bWdlm/I88IHAtnpPR4Pk8gAU8RoMDhKfKslwRIgARsBEorGuW93QWycHqyZKVH2Y70bjO3sFZ2HiqRlfPTJSk+rHeZeBYJuJEABQ43wmZVJEACJOBIoL29XbYdKJb6hhZZpVMjgYE9azUcy7C+t7W1GwPT8NBAWTI7VQIC+l6WVSY/SWCgCFDgGCiSLIcESIAEbpBAYWm9bNlbKEvnpEl6csQN5nZ9uil3X5EsVaFjIMt1XSOPkEDPBChw9MyIZ5AACZDAgBKAJuKDfYXSrj7OVsxPGxRNBDQnm1WYQR3L56X1S3MyoJ1nYX5LgAKH3w49O04CJOAJApatxQoVApITBt9RVX9tQzzBiHX6JgEKHL45ruwVCZCAlxHAahIYhcK+AlMo7k7b9hdJfWOLGpVm9Gr1i7vbx/p8nwAFDt8fY/aQBEjAwwTgL2PvsVJZo0tdPekvA/491u/IV2dhiTK6G/8eHsbF6n2UAAUOHx1YdosESMDzBOARdJM68ErUZarzpiZ7vkFXW7D3WImUlDXKqkUZXTyYek0D2RCfJECBwyeHlZ0iARLwNIHTF6vkyOlyuXXJMK+MeYIYLeu258mUMfEyYVScp3Gxfj8gQIHDDwaZXSQBEnAfgcamVuOWfFhqpMycmOi+ivtY06FTZZJbUGd8gISHBfWxFGYjgZ4JUODomRHPIAESIIFeETh2tkLOXKoyWo2I8OBe5fGGk+B07F0NBjc2O0amjkvwhiaxDT5IgAKHDw4qu0QCJOBeAnX6wIatBgwxJ4+Nd2/lA1jb8XMVcjanWlarx9PIiKEjMA0gAhY1iAQocAwiXBZNAiTg+wQOarTWnIJao9UICx36UxJNOiUEbUdmWqSuZkny/QFkD91GgAKH21CzIhIgAV8iUF3b
rHFL8mWyGl2OGxHrS10zfTmrU0NHdYro5gUZEhsd4nP9Y4fcT4ACh/uZs0YSIIEhTgDLSgtLGoxWIzg4cIj3xnXzW3RZL1aypCSGe9WyXtct5hFvJkCBw5tHh20jARLwKgIVVU1Gq4GphpGZ0V7VtsFszKW8Gtl7tFRuXpguCbFhg1kVy/ZhAhQ4fHhw2TUSIIGBI7DzULHAU+fqRcP80jU4As5tUMPYaDUmXTwrdeDAsiS/IUCBw2+Gmh0lARLoC4GS8gZ5f0+hLJqRYgwp+1KGL+XJK6qT7QeLZfncNDPV4kt9Y18GlwAFjsHly9JJgASGKAGEd9+qAc+amts04Fk6w7vbxhFsEIguKDBAlqngERAQYDvKTRJwToACh3Mu3EsCJODHBAqK62XrgSJZNidVUpMi/JhE910vKmuQD/YWyuKZKQLPqkwk0B0BChzd0eExEiABvyIAO4XNOn0SqAtPls9L96u+96ezW/YVUhPUH4B+kpcCh58MNLtJAiTQPYHL6rxr1+ESuVmnTxDdlenGCJRXNaq31QJdPpsk2cP8ZwXPjVHy77MpcPj3+LP3JOD3BFpa2+S9XQUSxdUXA/JbwGqeqppmdRiWLr7so2RAYPlZIRQ4/GzA2V0SIIFrBM5frpb96pp8zaIMiYsJvXaAW/0iAIFjw448mT4hQQPC+Z4X1n7B8ePMFDj8ePDZdRLwVwJYeYJga8mJYTJ3SrK/Yhj0fu8/XioFJfUmGFyoD8SZGXRgPl4BBQ4fH2B2jwRIoCuBUxcq5ZhGRb1l8TCJjmSMkK50Bv5bXX2LcY8+YVSsTBo9dCPpDjwZ/yuRAof/jTl7TAJ+SaChsVU2qlZjeHqkqvoT/ZKBJzt99Ey5XLxSI6sWZkhEeLAnm8K6PUSAAoeHwLNaEiAB9xHAw+6c2mvcuiRTwsOGfgh595Eb2Jog9CEY3MhhURT6BhbtkCiNAseQGCY2kgSGJoGWlla5dKXMY43HA27n4WIZkRElY4bHyqhs2mt4bDBsFWNa66T+h7YD01r5RZVSV99kO8Nzm6nJMRITFe65BvhwzRQ4fHhw2TUS8DSBopIqefhrz0l0dJzbm4LQ6urHS0I0fDw8b9fWVMo7v/uS29vBCp0TgOHuetV2pCWHyzPPb5D8kgYNiudZ7VN9fZ38zaeXyZqbJjlvNPf2iwAn0vqFj5lJgAR6IhAVFS2BEdk9nTbgxx0XuYY2nRrwOlhg3wmEhgTKHcuz5HxutZRUNEpg+DAJDPasw7WQ9oK+d4g5eySgDnyZSIAESIAESMAzBEZnxdAHimfQu71WChxuR84KSYAESIAE7AQYa9ZOw3e3KXD47tiyZyRAAiRAAiTgNQQocHjNULAhJEACJEACJOC7BChw+O7Ysmck4PUEvvTgRPnjT5epe/GkzrbOnpRo9j3+wLjOfdwggYnqqfS3Ty6Rp36wSDJTIzuBIDrtH358k8yeTGdunVC8dIMCh5cODJtFAv5AYPvBIgkMDJD7bxnR2V1sY9+Og8Wd+7hBAvDXAQ+l8bGh8sg9YzqBYNkzotLik8m7CXBZrHePD1tHAj5NYP/xMuMBdMzwGJmhkUVb1XHGuBGxxinUkTMVEhsdIktmpUpaUrgUlTXI9gPFUlHd4SAqJipYVs5Pl6T4cCmvbJRNuwtMWHSfBsbOGQIzVQs2fXyCHD5d7pRIdGSwLJqRIplpkeY3cehUh6dZpydzp9sIUOBwG2pWRAIk4IzAn9+9KE98Zprcc/NwI3DgnBfXXpTkhDB58quzzRttmwoi0Hrctypb/v6ne81D5IdfmS3pKRFGAInX0PKBQQHy8vocZ1Vwnw8R2HmoWBaqMPHJu0bLN3+277qe4Xfzz1+epYLoNZ8eH71tpPzqj6dk897C687nDvcRoMDhPtasiQRIwAkBS8sxdVyCOXryfKUcVe3G4w+MN8LGUy+elvf3FMiDd4ySO1cMl5vVHfbWfUVG2Nh9pER++uwxyVbX5c3q
WZTJ9wlgzCFgThwdJ8vnpgk8ltrT/WtGGGHjlQ05RnCdNj7eCLQPqYBCgcNOyv3bnPRyP3PWSAIk4EDgz6rRsNKLqvFAGjM82nyO1umWRz88VkZmdnzPSO7QatQ3tBhjUxieBqt2I7+43pzPP75P4LnXz0l7e7t87EOjBB5L7QlTcoihA80ZpugOniyXo2crjJASp1N0TJ4j0HWkPNcO1kwCJODHBPafKDNTIyXlDUa7ARQwEkTKTIswc/FBKlScOF8hF3JrzFvtj54+asKdL9O33H/52zly29Jh5nz+8X0CZ3OqjT0Ppk1uuymzS4dh21NX3yItrRpI52pqbW0zAorKKEweJMApFQ/CZ9UkQALXCCDYmn1apKSiwajG//13J6Ss8vpIoog2+sTP98u0cfHync9Ol/nTkmXt1rxrBXLLpwm88PYFmT89WeAa3Z5KNS4L9qWoLUdxucZo0dfqrLQoadOZl1oVRJg8R4ACh+fYs2YSIIFuCOw7ViqTRsfLtx6bZqKKYtnjKH2QfHDV8O9T942RTTsLOqPBwrCUyX8IYNXS2i1X5K6Vw7t0eu/RUhmbHSvffny6bFbbn2m6miUlMdz8bjDFwuQ5AhQ4PMeeNZMACdgIQN1tV3m//cEVXQ4bISsXpMtnPjLenFmmy1/XaUhzqMwx5fLwVX8MePj84a0LttK46WsErN8GbDes9NL6S7JAtRyp+jupb2g1u9/cnGt+NzfNTZUH7xytUyttgpUtz75y1srGTw8RCNDBuzZ6HmoEqyUBEvBNAkUlVfL5774sQZGjeuwgDD9xN3J8C4VKPCkuTBp1NUJVTXOXcjCH36pz9ZZvji4HHb40Vp6S1379uMNefvUGAo9/8wUprU+SoB7C08NA1HFVCtqP3wimTOwpSJdRJ8aFSnlVUxd7Dvs5jttNdQXyxYemy5qbJjke4vcBIEANxwBAZBEkQAL9J2A38rOXhgcJ5uKdJczXM/kPAWfCBnrvKGxgHwRXV78bHGdyPwGuUnE/c9ZIAiRAAiRAAn5HgAKH3w05O0wCJEACJEAC7ifAKRX3M2eNJOA3BCp0/ryxqVUiImgq5jeD3o+OetykkD/Tfoxez1kpcPTMiGeQAAn0gcDWfYVSUl4jkWHt0lp7pg8lDGyWuJhrsTUGtmSW1l8Cw4fFS+mxS+Jg99nfYrvND9kCNiEB+ml5K1W7ZV39xN9Jt+D6cZCrVPoBj1lJgASuJ1BUWi8faKyTJbNSJCMl8voTuIcEvIgAllTDtwt/r4M/KBQ4Bp8xayABvyAAx1vGKZe+JSKoVkAA3h2ZSGBoENiiGjloPFbOTzeRiYdGq4dWKylwDK3xYmtJwCsJXCmskx3qXGnFvDQNKx/ulW1ko0igJwJwLPfergKZp27yEYGYaWAJUOAYWJ4sjQT8igCcbr23O1/Cw4Jl6exUv+o7O+u7BOCZtLK6SVYtytBIxFzMOVAjTYFjoEiyHBLwMwIXr9QI4p2sWpgh8bGhftZ7dtfXCVTVNMmGHfkyY2KijBneNUCcr/d9sPpHgWOwyLJcEvBRAojqumFnviSqkDF/eoqP9pLdIoEOAvuPl0pBSb2sXjSsczUL2fSNAAWOvnFjLhLwSwJnLlXJkTPlskZvvjFRIX7JgJ32PwIIa79egwZOHBUnE0fH+R+AAeoxBY4BAsliSMCXCTSp8671ql4elhohsyYl+XJX2TcScEngyOlywVTimsXD1G4pyOV5POCcAAUO51y4lwRI4CqB4+cq5PTFKrllyTCJDKevQP4w/JtAQ2OrrFNtx8jMaJk+PsG/Ydxg7ylw3CAwnk4C/kKgvkHVyKrVGJ0VLVPH8cbqL+POfvaOwMkLlXLyfKXRdkRFUBDvDTUKHL2hxHNIwM8IHDpVJjl5tUarERZK1bGfDT+720sCcBS2YUeepCVFyJwpnGrsCRsFjp4I8TgJ+BGBmtpmswJl8pg4GT+SxnF+NPTsaj8InL9cLQdVSF+tS8Rj
o7lE3BVKChyuyHA/CfgZgb1HSwRxJWAQFxJMZ0d+Nvzsbj8JtLS2yUadgoyLCZWFM7hc3BlOChzOqHAfCfgRgQr1qLhpV77MUgdHo7Lo4MiPhp5dHQQClwtqZffhEhOTJTGekWftiClw2GlwmwT8jMCOg8VSrdMo8BYahNjcTCRAAv0mgECGmzQmS1hooNw0J63f5flKARQ4fGUk2Q8SuAECpRUapGp3gSycnixZ6QxSdQPoeCoJ9JpAQXG9bD1QJMvmpEqqGpb6e6LA4e+/APbfrwi0t7fLln1FAvfkKxiG26/Gnp31DAFccx/sLZS2NpHlGk05MNB/NYkUODzzG2StJOB2AogHsXV/kYnqmp7Mty23DwAr9GsCJeUN8v6eQlmkBqWZaZF+yYICh18OOzvt6wRgCBqv1vJImE9+X6dPgnXlyU2q2g0I8N83LF8fd/bP+wls1ykWxGa5ecE1uyksR49U52G+rv3g2jfv/32yhSRwQwR2HS6Wf3v2mECVm1tYKy9vyJHpExJk2dw0Chs3RJInk8DAE1g8K1XmTU2W1zblmLgsqOF/XjxtbKoGvjbvKpEaDu8aD7aGBPpFAJ4Pv/Yvu6W4vNFMncydmiSLZ6b2q0xmJgESGBwCu4+UyGENCLduW57EavTl//j2fKPpGJzaPF8qNRyeHwO2gAQGjMDr7102wgYK3KOOvMZmxw5Y2SyIBEhgYAlM1lD3uw4Vm0KrdFrlxXcvDmwFXlYaBQ4vGxA2hwT6SgBGaa9uzOnM3qwrUV5ef6nzOzdIgAS8i8CrOq0CQcNKa7deMdOg1ndf++SUiq+NKPvjtwTOXqoyrsmT1LthckK4JMSG+rwRmt8ONjvuMwRaW9ulrLJRSioapFSnQuP1uvXV6MwUOHzmZ8uOkAAJkAAJkID3EuCUiveODVtGAiRAAiRAAj5DINhnesKO+AWB6toGqalt9Iu+spMDSyBQ/Y+kpdCIdmCpeq60xqYWKauo9VwDWPMNEQgPCxEKHDeEjCd7msDvXtot77x3UkLDGIXR02Mx1Oqvq62Wd373paHWbLbXBYGjJ6/I9376lkREMhaQC0Res7uttVWGpUVR4PCaEWFDekUAXjMDwlIlKDKhV+fzJBKwCIQ2n7I2+ekjBCKiEiQoYpiP9MaHu9HSqPGbyoQ2HD48xuwaCZAACZAACXgLAQoc3jISbAcJkAAJkAAJ+DABChw+PLjsGgmQAAmQAAl4CwEKHN4yEmwHCZAACZAACfgwAQocPjy47BoJkAAJkAAJeAsBChzeMhJsBwmQAAmQAAn4MAEKHD48uOza4BAozj0kTQ3VN1x4S3ODFF7ae8P5BiJDX9s8EHW7o4yq0ktSV1XojqpYBwmQQB8JUODoIzhm804CeKi//+e/kebGax4IT+97Uc4febPPDW6sr+giYOzb8G9SXnjjPh1qK/Nkx1s/uKF2tLY0SV110Q3lcXZyX9vsrKy+7Du+87eyf+PP+5K1V3lO7v2DXDqxrlfn8iQSIAHPEKDA4RnurHWQCLS3tUhx7kFpa70W8rmqLEdqKnL7XOORLU/J2YMv9zl/fzKiL1te+UZ/ivCKvMPGLJWRUz7kFW1hI0iABDxDgK7NPcOdtXqIwNmDr8jl05tUIGmRzLE3ycR5Dwq0Ioc2/5eUqdYiLmmUzL/t252tu3hsrRRc2iOBQcFSmn9cltzzQ3Os6PJ+Obnned0fKjNXfEliEoYbTcSRrU9JTXmuJGZMlmlLPiPBoZGdZdk3SvOPaf4/mGmA2KSRpoywiHjVxLwhF46+I8EhYTJh7oNydPuvzTkQOibM/YSkDp9lirl4/F0Voq7I1MWPmu/nDr0mbSpsJaZPdFquVTfyHNn2tCy64wdm16m9f5SImBTJnrCqx/b3ps2zVv6NRMamXcezsuSc0RKhfQUXd8upvS9Iq3ofDAmLluRh02XEpDVybMdvJDohU/LObTf9mLHs88o3RPLP75AzB/4irSpEjpx8
q4yaeodp+yVlAM1VeFSSQAsVE59ldZOfJOAxAlWlF/VaKpT0kQv63Ya889vNvQX3F19I1HD4wiiyD9cROLbzN3J4y/+Y/2UFJzqPN9SVy4xlX9T/nzcPONwcoIqv1M8ld/9QMkYt7DwXG2kj5kh8ylhJzZolkxc+LIGBHTJ67unNKqw8pIJBuKkD5257/TsSFZsus1d/XcqLTssxnUZwlfDgHzXldll4x/eNQHH24Kv60Kw00w5zNP+URY9KRHSSEQRCw2Nl0oKHTTus8qLjhsnpfX9SYane7MK0UYQ+eJ2Va+XBZ3NjjWqADnXuqiy9ILUVeeZ7T+13VrZjmwMCA53yhBBWpXVB0Nvx5vdl/JyPydhZ9xublsyxS6W5qVYuHn9Hqssvy7Slj0numc0qGL4v1aqd2vnOPxthC8Lhgff+UzCelSXnZe/6n8jYmfepEDZTSvKOdPaJGyQAApj2fOfZh6S9vc0lkP5OWTrLf3jL/wq0qn1JjuXh/nRy9/N9KWrA89RW5kt7W2u/yqXA0S98zOytBEJCo8T6HxgY1NnMSfMf0jftKvNgDg2PlvqaEiMk4AF25cwH172VRESnSLjGbcFbe1LGlM5yZt38VRVG5uob921aVp6WUywVRWeMRgPTIGERcaoROdZ5vuPGiEm3mDfzsoKTEnK1HRBeIFzgbR7545JHC7QfwaERqgWYqsdiOotJ0u9oF97+ywtPmzf8jNGLVFNwfbmdmbrZ6E37nZXt2Ga8iUHocsUTAhL+o29p2XPMwwAaCiRog+bd8oTR4oAtbF7yL+4yghQEIwgj4VGJUpp3TIouH5AUFTSGT7hZxsy4V4aPv7mb3vGQPxKIVo3X9Js+JwEBrh9z/Z2ydMzfWFehGtHd+ntc2SfkjuWhHAjfzU11fSpvIDOt+/2n9Rrs+9Q02sIplYEcEZblNQTGz35AwiLjTXtqr65egF3H+t8/pg/xEaqynyzS3tHc9JHzZfGd/yQnVc2P/7f81TPmgd9dZ6ybGFT+kPoxLYMUHBJhbnDpI+arkJLusgi8BeWd3ybZE1ebt340Jig4TFY/+L863fBHWf/842ZqBw90ZylAQ60jb86pTUblmjVuucnvrFzH/O1t17/x9ab9zsp21uasccuu42m1Aedj6mj97z4tATpNhXGCIAWBB0zRL6SgTq71RqsUpFNXSBNUM5KaPVugYQJrK0GzwkQCdgLQ5kF7ialTTL9Bu1lRfM4IsmNm3KMvEFOvm7JMzpxmNAr5F3apcJugmsa/NprF3uZvqC01wnREdLJpSsHFPWr/9ZK+EFQZTdzkhY+Y63T32ieN1i4ueZRq7E7KhWNv62/749e1B1OokTGpUqJaSbxQWKmjPWWqST2r/clXDegnjTbw8un3zMvJdNXgtjTVy9Ft/2c0giF6jU1WLWlK1oyrLDSvvqhg6gcvEhDaK4rPCqaEobUcM/0enbq83arOaBNbmxtl34afmhct9MNV3zozOdngVeoECnf5JoEafWOuLs8xb9GwGbBUrVDb40Jcdt+PVXhoMW/SdgIQKrBSxDrffszajoxJMxoV3GhwM4OdQXzKGOvwdZ/5F3bqDefjMmn+JzsfnGaprT5woT1B+0quHDI2DA215UYoaW+/KiFdLS174hq96HeZG0i2no/krNyrp5sPPOyh4cFNCoa0RTn7zf7etN9Z2c7a3B1P1Il+wO5l3ppvyDgVOLpLmM7CDRACFbhmjV+uU03JRmgsuXLE2IWgL8Wq8WAiATsBTNNBE4ZkbJe2Pi0QKEaoHRCm44LDoq6bsjy990/mmpq75u8lNnGETpN+1/xee5sfmr0ovRcg4TrANCW0cDNXfFFfMLbL8Z3PmWMlVw6b6xBfGlUQKs07an7XsKVynELFiwvuXfZk2qPCRJoK3+mqDdz22rdNGdOWPi45JzcYzWdDbZlEx2fKgtu/JymZ03W69memiM68I+fJuFkPmGlKCB+HNv/S2FMtvP0fjJBj
rw9Tl7BjGz39bu3Pqm77Zs/nuE0NhyMRfh/iBDrekOXqBzrT8dYcYC6+ZL3w3n7mQTVWjDRvGjiOGwEMOPEWDVV/Qup47O5MsOvY9sZ3jWrz7s+9asqzNBzWW3lQcKjM0ZsUbmSHPviVWZaLtw68wV9LAdqsjoaNnHKbHNj4Czm+4zeqiUkwNxlM77z/568a7QpUqEvv+ZG2eZi2M1Re/eWdMv/WJ/SBu6KzOGhq0F7csFKyZpr9zsrFATBAWzE1hBvg2t8+ImE6fROl5etRU0dP7XdWtrM2O+NZqIa3qAdTIhDqcA40TkU6/TRrxZclIW1iJxu0V6AG1zZnjFpk3uzefvZBM4WEt7blH/m52Z+QNkHefFo1WTr9FBWbYco3efmHBJwQGDFZ3+T1gYl0aPN/S5NOf9inLLEfv0sI34U5e839oa6qQKC1QOpNfgjH4Ve1GygL9kXQICCNn/1ROXfoVWOjZHY4/MF17tgenAIBG+1wTCgXLzb1NaVyfNdzMv9D3zU2ZbjHQQiHZgdG2aX5R01WXKtWMnnVhgzp7KGXjQ0UpkJzz7wv8aljr5tahuAfoFPTuDfGJGQZLeyN9M2qlwKHRYKfPkEgRN9a7v/KBiONWx2aveprZhMP3JUf/Q/zxoy3iPb2VqOuh70AbkRtOjVit5Ow8kOVee8X3tJZjzZz/ppP/rqzfKj38R1puAoDeBPHiolQvdChGbEnqE/v+uwrZhemBlBnQEBQx3laNi7ouz77sjSp+hXTQZZQc8en/yhNjdVmKsJeHrZXffyXOhnTrud2CDLOysV59jYv1Dce+CmBbYg99dR+Z2U7a3NC2vjreEIFq8DVMHStEfwW3fmPpmoIaFj9gyWzFhscmL3yq0bgQL/wtjlLv7eoEBaqwoXV15vu+9er/YjUHF21P6Zw/iEBGwG7LZeZCtXr3zHBvihQH/x4+QiK0JeIVV83Ai3O601+GJW3Xp1exe8VdmRWCtKVZ7jHmKS/6+40plYefGK6016OdcxqT2BQh42adb9AG1A2bJ22vPpNI5RYLzqOefHdTFfqtYn75LnDrxvDdWiCFqgA4yp12zdXmXQ/p1S6gcNDQ5MAVH/2hAvRuhixH2/EeGhZK06wDwaLzoQNHEOCLYW1xNWxfPt3lAtjTkdho6MUveBsbYMNAt5qkAcPbiS0CVoAe3txDFMhzhKO2fuBc5yWa6sX50Aws7h0qauH9jst20mbHXmaurStmLpq1CkdrB5Yp3YcxZcPqlr3fjSpCxv0y94ucOoQwmyqq85+dGhv7OebAvmHBHoggOvUPmWJN/l2XUmFqQNM4aXptIOraxlFO+bHEnPLUV986jid0tlvNCSw87p8cqMkDeswPA8LjzOraLAqBUvcreRYHvbX63SuZRNindebT2hpktRWDSvyoLWwp8KcfVKrWhO0FQJ/nE7/YmUNrsW5t3yjcyrKnicwMMTYWmFKtLu+2fM4bne9Mzse5XcSIAESGEACWDlw+6MvmBsXbq4QzphIYHAI4L3+qoCqgjSm9KzUsT9ApwjGdZmyhJHo9je+J28+db95CcDy8xUf/XejbetN/sT0SXJm/1+MhmHYmCVy5ewWM4WL3zqEmTlaPhKWhe9+90dmKqTDt05H2xzbM0yXjGNpLGwzuqQu/enIa2n+tOGmvaj/zIGXzbQjplK7pgAV+B81tmGwG0EbNr7wBTWsLTOaRCxBd0yZWt7W174lWWOX6fTNd1z2zTGf/XuASivURdqJcNurCfzXbzbLup3lfFB59Sh5Z+MaK0/Ja792uHF7Z1PZql4Q2Hf4kjz5q+069QE7JOcJfl+gVTTTF/qoszSJ1n7kgvbBccrSLEPV86EJNOfAl0cv8qPcN576sGDK0HLSBwd3qMPSkJoC9Q+0G2gXtKeu2oOVIHtUMLnz8b900fj11B8IHRBA4AwQdk/Q3mIb2tBj258xgsVMtZ3CFE+IzTkhlvViqhWryZwlGInjuKVVddU3
x7w4LzmyjMtiHcHwOwmQAAmQgG8QsKYwzXTbNQXHddN3jlOW9ocwSPQ2P+rDklJoNiyBw9XDG9OEVrLaaeqyTaFeOfuBjJl5bxdhw5xjtBhW7uunI60jEAysqWJLSLCOoV3XPBR17LVcCVjnOH5aZVn7XfXNOu74ySkVRyL8TgI2AjDYwpI1GJZ6U3LVLlf7XbUdXkex0sXxRuLqfG/ajyWPcMZmv3F7U/vYFv8kMHXJpzs0KgPQfRhyDrRt0nhdjg9tjScSjUY9QZ11uo2AY6TX3lSMJWVQgSL1JcJrb+ro6Zyeoqu6aper/a7q83QUWVft6s3+ra980zgu6s25PIcE3ElgoISEgSrH3ndob6ypIvt+d2xT4HAHZdbhMQJ9ifQ6EC58+9thRlftL0HmJwES8DYCnFLxthFhewaMgLNIr0U5B5y6GrYqhV8IuwtfOM+BgdaJ3b/XKKbb1CX6JLPMDFbnrqKYoqw6dae+V90AL7n7/xkDrKrSS8Zt8YS5H3MazXXvuh9LTOJw46oc8R/gbAgGWoiu6izCLepAuxBhNf/CDuOQZ8byL2B3Z4KRGAI/Obpp7jzh6gai4R7f9VuznHbqkseM5T6WyzmLfOtqP9qflDlVck6s1/nxaxF0rbo6vBs6j1Jrj5CLaLNwaOaMLQzP0CYs40vLnittTvwoWPXxkwRIwPsIUMPhfWPCFg0QAcdIr4iK6srVsFWlowtf7IeTm44opo/3GMXUKgfr8bGcLU+DqyEhpgPifeDB6xglFsdL1E4ES9imLPyU8ZppRVfFMWcRbrEf7UIMElibI1rqiV2/w+7O5MpNc+cJVzfwcJ+6+DMm9svON3+ggky7y8i3riLKov0ndz1/XQRdqy5XUWqdRZt1FSEW3mDhin36TZ81UynWtJdVBz9JgAS8mwAFDu8eH7auHwQcI73aXQ3D2BCuhjtcbl+rxNGFL45gOVtvo5haJWHuFevbL5/aaHYhqNKIibd0G8111sqvCNbOw525PTmLcIvjaNec1X9ngjWNnnaXxl7pGqK9OzfN9vIR1wQOuaaoN9CayisaxOmM08i3PUWUdYyga6/D1bazaLMuI8SqsIGQ9nDjPu/WbxlNiqtyuZ8ESMD7CFDg8L4xYYsGiUBf3fFCeLCc6jiLYgrXwHD7nT5qQZeWZ09abaYzsJoCb/iISouIq7vffVK9/OWbtfc6MdKZxx791NppRbjFtEODrpG3nW6s1zvbpUvc4OLcnuxumrHsz+6m2X6eZZgWdDX6akvTtci36Bsi306a/1ddIuLa91tlWeVgusmZ9sFZlFosq0OE3ODgcBMhN/fMByZ8PZbwoQ47W/THWq6IfttWOVpN4CcJkIAXE6ANhxcPDpvWfwJ4+MHuAPYOcMd77vBrxj4CD2C7q2F7TZYL35jEbPvuLtvQhJw7+KqJnYK161gN47gmHecgWuOedf+qcVZWmrX/VsTVkZNvkx06fdFTsiLcrvrEr8yD+PS+P3VmQTyUy6ffN23Aev345DGdx7CB+i03zXhAw5UxeDgmTPdAa5CrZYELYqEgdgPcKcOGBU6JTHwYjT/jbL9jec6+o1wrSi1i2GBqJEa9jtqjzcLmBBFyU7PnOGUbo9E7YUeD4HOIiNna2uSsKu7zKwIqZntoiadfYe5nZ61XIQoc/QTJ7N5NwB7pFYHR0kcucOpq2N4LuwvfiRo+vsu7tHG44zqKKQwe7WmkRqiEVsOKGOks4irON2/sKNtKV10XQ2BxFuEWbpYh6Bzb8YzsWfukRMalm+kVCCFWe126abbq0E/UC4Hslf++Xb0HBunU0TeN50NXkWNd7be3364RsqpyFaXWWbRZCIYImOcYIXaSjsWWV74hr/3yLhP7ITwyET2wquCnnxEIDw9Rx1X10lZ31rd6HqhBFdvqfapPuLNlj8sSujb3qWH1/c70xbU5nGGpjt/YPIBQb9zx2l34Orod1qe0mc7oKKvJGG/ao5g6jgKmReya
BUwNOEaJtdeB/I6ui2FcaY9wi3OgEcDDHW1FQDorOZbl6KbZOg+f1rnQPsAmxO6NEG+OziLfOttvlWOV7fjd2u8YpRbtRz8cI+TifLh+xjSYnS3qtqZWXNVh1eX4SdfmjkT43RsJrN16RW5bmumNTet3m6jh6DdCFuDtBGCYaE+OUx/2Y9a23fOmo9th6xx8wstlT54u7cIG8nSx1dDw9Ej2OvDd2EPYXt4tgSIg4NolawkH1jHkQ3Isy7J76Dja9a91LoQZxwSthbPgas72W+VYZTh+t/Y7cziEfiBCrmNyxhZ1W/1xVYdjOfxOAiTgHQRsOlzvaBBbQQIkQAIkQAIk4HsEKHD43piyRyRAAiRAAiTgdQQocHjdkLBBJEACJEACJOB7BChw+N6YskdDhACMWQsv7R0irXXdzPqaUqkqu3TVr4jr83iEBEjAvwlQ4PDv8Wfv3UAAqy2w9NQx3WhkV8f8jt/tUW4djw3W92Pbn5GNL3xOtr/+XXnr1x9VL6U+tkRxsMCxXBLwQwIUOPxw0Nll9xIozj1o/EcMdq2eiHI7evrdcsdnXpTbPvU7E1Dt5J4XBrubLJ8ESGCIEri2xm6IdoDNJgFvIQB/GEe3/Z+UFZyQkPAYmbzgYfXWmWKixCJ6LJxWTZj7CUkdPstpk51FSC3NP+Y0uqxjhNXT+1/sEuV2ssZFsRKcax3d9mupLD0v0XHDZMriRyUmYbhcOv6uCQxXUXxOoG0ZM+OeTgdl8GB67tCr6g+kVd22f9y4bXcWeRbeSK2EMssKT1pf+UkCJEACXQhQw9EFB7+QQN8JNNSWGVfmC27/nqRkTpf9G39m3IMjiBv8XExSAQTuxp0lVxFSnUWXdRZh1VmUW9QDR1lbX/uWiVSLQG8QhLa++i3jWMyEjN/6tHoynSYjJt8qe9f/ROAkDALI7rU/NO7YJ81/2EyTdBd5FtNF0GycOfAX43bcWf+4jwRIgAQocPA3QAIDRABuzYercFFZct6UCM0CnFfFJo1UL54RJqqr3aGYvVpXEVLhEj08Kkm1JidVWIjWcPQlxvU4BBg84OH0C5oFCDIB6po8IXW8fs/qLLquutBEfp2x/It6bJyGdv+c1FTkqkajwJwzQl2vj9FpEfyHs7H66mIpzNmr5Y0zGo/0kfME2pLuIs8W5x6Si8ffkaX3/quJkNtZOTdIgARIwEaAAocNBjdJoD8ESvOOybvPPSKl+cfVJfeNxUIwkV2dREh1Fl3WWYRVV+2Ga3C4Ube8rVpeUTFVgoT4KVaCR1Tsh/vxqNg0a7f5NO2DV1WN4OoYeRZlQuhJTJ/YJQ+/kAAJkICdAAUOOw1uk0A/CEAzkJQ+WWYs+7xGpr02dYIHeUNtuVk2iikOZwkaCkyVZI1bbjQLWeOXm+kYK7osApdZLtHtEVZHTFpjIqyiTCvKrb2OKLXZgEBw+fR7ptqcExvUXXmiChQZzpph9sUljRJoLRBHBQILQsbbI8/C1iNNNR+Wy/Y4jVKL6LdMJEACJNAdARqNdkeHx0jgBggMG7NEpzleljeffkAQHdVKmMrAQ//VX94p8299QrLGr7h6CHFdOwKmZIxa5DRCqrPoss4irKJAe5Rb2JEgQbMxZ/XXZf+Gn8nB9/7TTLvMv+3bHfFWNC6JPdpqR1sCJHPcTZJzaqO89fRHzfnjZj8g3UWeLbiwU86qgSlC2TORAAmQgCsCjBbrigz3eyWBvkSLdWdHEPkU0ymw1cC2FWCtva1VmhoR1TW+S3McI546i5DqNLqsiwir9ii39oqg9UBEWHugN8eItI5tgXZD52M6NSsoz1nkWZTdjr6qJsebE6PFevPosG0WAUaLtUjwkwRIoFsCEDAsw1BL2EAGGHQ6ChvY7xjxFJoQ/LcnayrF7LOiy7qIsGrVbc+PbURZtQsbHft0RtUWkdaxLQhX75isSK32/Sg7wMuFDXt7uU0CJOAZArTh8Ax31koCJEACJEACfkWAAodfDTc7
SwIkQAIkQAKeIUCBwzPcWSsJkAAJkAAJ+BUBChx+NdzsLAmQAAmQAAl4hgAFDs9wZ60kQAIkQAIk4FcE6IfDr4Z76Hc2JjpMpKlAmlquD/c+9HvX1x4ESmBwlLS11GgBzh2L9bVkX8oXwrudLw0n+zIECdAPxxAcNDaZBCwCB06UyuWCOlkxN0027S6QsdkxMnVcgnWYnyRAAkOMgC/74eCUyhD7MbK5JAAC1bXN8vqmyxITGSJ3rxwusTGhcu+qbI11EiCvv3dZ6upbCIoESIAEvIoAlYxeNRxsDAn0TGDvsRIpLGmQ25dlSnBw13eGSaPjZUxWjLy7LU8y0yJl9uSkngvkGSRAAiTgBgJd71ZuqJBVkAAJ9I1AeVWjvLTukqQkhMsdy7OuEzasUkNDg+QuaD2iQ+TVjTlSDM6XmwAAQABJREFUVdNsHeInCZAACXiMADUcHkPPikmg9wR2HiqWyuomuW91tsZnsfkj76aIsdmxMjIzWtaptiNZhZT505K7OZuHSIAESGBwCVDDMbh8WToJ9ItASXmD/EW1GtkZUXLr0sxeCxtWpcFBgTr1kiXpSeFGO1JW2Wgd4icJkAAJuJUANRxuxc3KSKB3BBCBdev+ImlqbpMP34BWw1Xp2cOiJSs9SjbszJfoiGBZPCvV1ancTwIkQAKDQoAajkHBykJJoO8E8ovr5KX1OTJhVJysWphxw1oNVzVjKuaWxcPMNAu0JkVlDa5O5X4SIAESGHAC1HAMOFIWSAJ9I9DW1i6b9xSqgCHykVtG9K2QXuQalhop96/JlvfUb0eQCiHL1IcHQswzkQAJkMBgEqDAMZh0WTYJ9JJATn6t7DlSIisXpEtinHpTHeQEAePmBRlSrFoOaFMWz0wRCCJMJEACJDBYBDilMlhkWS4J9IJAS2ubrN+eJ7kFtXK/ajXcIWzYm5WSGG60KecuV8tGte+AloWJBEiABAaDADUcg0GVZZJALwjgIX/gRJnaVWSoz4zQXuQYvFNumpMm8PPxyoYcmTs1SUaokSkTCZAACQwkAWo4BpImyyKBXhDAypO1W64Yvxqw1fC0sGE1OSE2zGhZ8ovrje+OlpY26xA/SYAESKDfBBi8rd8IWQAJ9J7AyQuVcvxchdy6JFOidHmqtyZ4J8VUz4wJCTJ2RKy3NpPtIgGfIYApzdMXq+RKYZ0JSwA77ofvGSOR4d57n7hR+NRw3Cgxnk8CfSDQ0Ngqb23OvepXY4RXCxvoHtyiw6akuq5Z3v5A293U2odeMwsJkEBvCSD2EVaOnb5UZT4RoNGXhA1w8B3RqbejyvNIwM0EjpwuF9hr3KaeQsPDgtxce/+qmzUpSSaMbJG3dQpo/MhYmTwmvn8FMjcJkIBTAhPV7860cfFy5EyFOf7ArSOdnjeUd1LDMZRHj233agK1GiIeoeJDNKIrQscPNWHDghupUz9oP1awvPn+ZalvaLEO8ZMESGAACTxw20hT2jw13EYcJF9L1HD42oiyP15BYP/xUskrqjNajdAQ35Drp45Tew4NCIdgcCMyo9S+I9ErWLMRJOArBKDlmD4+QXxRu4ExotGor/xS2Q+vIFBV0ySbdhXINL1pjBke4xVtGoxGnL5YKSfOVxrX69GRIYNRBcskAb8kgKjQcTGeXSY/WOApcAwWWZbrdwR2q6dQRHe9ZckwQZRWX0/NumwW2o705AiZMyXJ17vL/pEACfSTAAWOfgJkdhJAyPf3VKsxb1qyCSPvb0Qu5HY4MLtZA83F++ibmb+NKftLAoNBgALHYFBlmX5DYPuBIqlR49DVAxjVdSjCa21tl/U78ozAsXBGylDsAtvshwR+8j8bZOvucxKgQQyZXBNoa2uV537xiMTH9i/eEo1GXTPmERJwSQCh3T/YWyhLZqVIRkr/LkKXlQyhA0FBAcZAFjFh/rLukqycny5J8YMfhG4IIWJTvZBAWWW9BEZkSUiY760IGUjcLbXnpX0AwixR4BjIUWFZPk+g
Xa86CBqtukQUId4Z1r3rkGelR8mHNeosDGfDQgNl6exUMuqKiN9IwG8JUODw26Fnx2+UAJa5bj9YLMvnpgmirDI5JxCo6unVizKkoKReXlqfY4QOGJYykQAJ+DcB3zel9+/xZe8HgAAcXiHOAbyFItgahY3eQYWQAS3QKY0fA6NacGQiARLwXwLUcPjv2LPnvSBwKa9G9h4tlZsXpguiqTLdGAFMOS2fly6lFY3y8oYcWTg9WTDtwkQCJOB/BKjh8L8xZ497QQCh2eFjAqHaEcSMwkYvoHVzCgxIoR3Kya81UWixqoWJBEjAvwhQw+Ff483e9oLAWY3WeOhUuaxZPMxETe1FFp7SSwKLZ6VKhXpSfGVjjsyZnCijsnzXG2svkfA0EvAbAtRw+M1Qs6M9EUAIdoRiR0h2aDUQop1p4AnAORi0HZhmWatRaOGxlIkESMD3CVDD4ftjzB72gsCJ8xVy8nyV3KpuyREdlWnwCcydmiw1Kty9oRFop46Nl/Ej4wa/UtZAAiTgMQLUcHgMPSv2BgIItY6Q67ApuG91NoUNNw8KAr99ePUIqW9slTc350qDfjKRAAn4JgEKHL45ruxVLwgcPlWmBoz56jNimCD0OpPnCCDU/aoF6fLutity9Ey55xrCmknAywi0tTbLmQMvDUirqstzJe/c9gEpqy+FUHfcF2rM028ChcVV8pV/+LO0eMA3A7yFNqq9BiK6BgcHyqtvh8kzP/2rfveJBfSPQER4sNxzc7YcP1chr2+6bJyH/fhX78qRU/n9K3gAcuvqXvnm51fLvBkjB6A0FjFUCRzf+VtpqC2T2au+5rILjfUV6l03SELD+2YQXVuZL5Ex6qE3MMjUkXNqk1zW/+Nm3e+yzu4O2MsLCAiUPe/+SO58/CUJCg7tLtugHKPAMShYWWhPBJpbWqWlPVgCI0f0dOqgHI+0uYIoKz89KHWw0L4RmDwmXsYOj5F12/Pk1PkSkfAREhjoWQPelvoCqatv6luHmMtnCAwbs1SgceguHdnylETGpsnkhY90d5rLY+t+/2lZ9fFfSWxSx73x0vG1MnzCzS7P7+mAvbzo+GESpf/zzm3tV5k91enqOAUOV2S43y0EGIvELZiHXCWhoUFy54rh8sLLiOIZwHgsQ24EfbPBlSXnpKmhWhLTJ8redT+WpMypknNivQQGhcrMFV+S0rxjUnBpj34PltL847Lknh8ajciRrU9JjU5nJGZMlmlLPmPKOLbjNxKdkGmmOFDejGWfl/2bfiGtzY2yb8NPJW3EXJkw9+NSnHtY5qz+OwO0XaO2ntz7ghRc2CWBqqEYPfVOFRxWSk3FFTmy7WlZdMcPzHmn9v5RImJSpChnf5fyIARljFoo+Rd3eUTgoA2Hb14X7BUJ+ASB0BDeonxiIH2kExAaqkovmN6U5B2Vk7uel4nzHpLgkHA5vOV/VEiYI/EpYyU1a5ZqOB5WzVywbHv9OxIVmy6zV39dyotOyzGdlmluqpWLx9+R6vLLMm3pY5J7ZrNcPv2+jJ15nxFWRk+/WwWCVVrXJY3S2qpTLGmmTggbl06sk+k3fU5GTfmQ7H73SSkvPC3NjTUqmBzqpFypbaytyLuuPJyAtuCYJxKvZk9QZ50kQAIkQAJDnsCsm79qNBEjJ9+mWoY8iYhOkfDIBDOlkpQxRepriqWi6IwEh0aqQHBQwiLiVPNxzPQb++bd8oSkDp9lyqitzDPCCmw3ElLHS0xClmpCqjRPvBFCkCnv7Faj9UgaNkWyJ642gk1hzl6XHCH82MvDieFRyVJbVeAyz2Ae4JTKYNJl2SRAAiRAAj5LAEaYSIFBIYLpDsfU0txgdgWHROi0YKCkj5ivwki62Yfv1pRykIv8EBZamus7i21urpOQ0GsGaEEhYZ31trf1zoFea0tDZ72dBbtpgxoON4FmNSRAAiRAAr5PAMJHXXWRToW0makQCAgR0ckyZsY9MmrqHarFGNMtBBhIQzOC1XSRaofR
2tJobD6QKT55rOSc3GiOoY6iywcEmhRoQaANwYqUmopcY7thVWIvD/uQD+3xRKLA4QnqrLNHAl96cKL88afLZO6UpM5zZ09KNPsef2Bc5z5ukMDEUbHy2yeXyFM/WCSZqZGdQOZNTZI//Pgmma0xW5hIYEAIYH20GjEjQTthaTjs2goYZV7UlSWv/epunc4IlDlr/l72rv+JvP3Mg/LGU/errcZ7pgzNbcoxf6ApMWWLZI5ZIltf+5bsfuf/SVRcpi6vjZWK4rPmtKlLPq3Gp5e17Lvk3ec+JWPU1iM1e7aZwsFKlrW/fUTef/FvNF+GqQOZ7OXhe2XxOTV6nYRNtydOqbgdOSvsDYHtB4tk2dw0E9Nk77FSkwXxTQIDA2THweLeFMFz/IQAvJXCh0dEuMgj94yRJ58+Ynoeoj5W4GcFn0wkMBAEzFJX1TwgrfnkrzttK/DQx3ekjNGL5N4vvCU612GMRoePXyFZ45YL/HOEhkWb6Recd9dnX8GHSbNXfrVT4Jh7yzdkuq5YCQ7FNEyACgxLJVcNSmHrER2fKbc8/KwaidYKplNglGqlhbd/z+xHPnuyl9fW1iJXdEns4rv+yX6K27avtdZtVbIiEuiZwP7jZXLucrWMUX8MMyYkSKs6CBs3IlZOXqiUI2cqTGC1JRp5NC0pXIrKGmT7ATXO0iikSDFRwbJyfrokxYdLeWWjbNpdIFU13a+d77lFPGMoEJipWrDp4xPk8Gnn3kqjI4Nl0YwUyUyLNL8JRAXG74yJBHpDwGg0riomsPTVnuzfsWrFniA4wJjUnuznW06+rON2p2FjZ31Ytr7yRBdnYyFh1+w4rDz4dLXfKg/LZGHYmpI1057NbdtdibmtWlZEAj0T+PO7F+WJz0xT75PDjcCBHC+uvSjJCWHy5FdnS3xsqLSpIAKtx32rsuXvf7rXPER++JXZkp4SYQQQRCYNDAqQl9fn9FwhzxjSBHYeKpaFKkx88q7R8s2f7buuL/jd/POXZ6kgGtZ57KO3jZRf/fGUbN5b2LmPGyTgTQSw0uSOx14ckCZ1aGKeHpCy+lIIBY6+UGMetxCwtBxWnJOT5ys1zkaFPP7AeCNsPPXiaXl/T4E8eMco4yTq5oUZsnVfkRE2dh8pkZ8+e0yyM6IY/twto+X5SjDmEDAnjo6T5Tod19Tc1Wr//jUjjLDxyoYcI7hOGx9vBNqHVEChwOH58WMLXBOwbEVcn9H7IwNZVu9r7TiTk5s3Soznu5XAn1WjYaUXVeOBNGZ4tPkcrdMtj354rIzM7Piekdyh1UAEWBibwvA0WLUb+cXXlpWZjPzjswSee/2cseD/2IdGiaPTMEzJIRotNGeYojt4slyOnq0wQkpctGddp/vsgLBjJGAjQIHDBoOb3kdg/4kyMzVSUt5gtBtoIYwEkTLTIsxcfJAKFSfOV8iF3BrzVvujp4/KxSs1xuj0X/52jty2dJg5n398n8DZnGpjz4Npk9tuyuzSYdj21NW3SEtrh9EfDra2thkB5aodYJfz+YUESGBgCXBKZWB5srRBINDS0tZlWqSkosGoxv/9dyekrPL6gFowLH3i5/tl2rh4+c5np8v8acmydqtnXPkOAg4W2QOBF96+IPOnJ8vorK7ROksrGs2+FLXlKC5vVNsfkay0KLUDEqlVQYSJBEhgcAlQ4Bhcvix9EAjs02Wyk0bHy7cemybrNaIolj2O0ofLB1cN/z513xjZtLPA7MfSdhiWMvkPAaxaWrvlity1cniXTu89Wipjs2Pl249Pl81q+zNNV7OkJIab3w2mWJhIgAQGlwAFjsHly9IHgADU3XaV99sfXNHlsBGyckG6fOYj400NZbr8FeHMoTLHlMvD6o8BCQ+fP7x1wWzzj28SsH4b8MxopZfWX5IFquVI1d9JfUOHy+k3N+ea381Nc1PlwTtH69RKm2Bly7OvdDhVsvLykwRIYHAIUOAYHK4sdQAJfPVHu7sIHJiDf/ovZ+TXL5+R
pLgwadTVCHY/G49/f4eZcmnV8yzfHAPYHBblZQT2HS+VT35zS5dVKXUqZHzph7vNtIkVYgKrVn71p1Py1J9PS2JcqJRXNXWx5/CybrE5JOBzBChw+NyQ+l6H7EZ+9t7hQYK5eGcJ8/VM/kPAcQms1XNL2LC+4xPTJ65+N/bzuO37BEKCA6SlLlekKcj3O9uPHtbU1PYj97WsFDiuseAWCZAACZCAHxH4zpdvMyuV/KjLfe5qRHj/l45T4OgzfmbsL4GWlnbhe0V/KTI/CZBAXwmEhfIR2Fd2fclH2n2hxjz9IgCfGlim2tBQIyEhnnc53tzCJZH9GtBBzAwfK+01V6QNa1jdlLCqqVntf0J19dPVAJ4a+ntgVMpu6gKrIQGvJBCglt3XTLu9solslK8QwE9ti7oeh1+NJbNTJDfPeYAtd/cXsVhGZSe7u1rW1wsCeYUVUl/v/sB7bfpbhat0eCudpQHhEEo8LSVWoqOuxWHpRfN5CgmQgI0ABQ4bDG4OHoGCknrZur9IbpqDCK9dwycPXq0smQT6R6BA3eJvPcDfbf8oMjcJdBCgwMFfwqASgHr6fQ0PH6zqaQgbCNPMRAJDiQA0c3AqhxUvy+elmejEQ6n9bCsJeAsBChzeMhI+2I7cglrZebhEVs5P7xIS3Ae7yi75AQHYHr2/p1AWzkg2LtH9oMvsIgkMKAEKHAOKk4WBADw4btpVoB4/g2XxzFRCIQGfIrBdp1gQe+XmBRkCo1YmEiCB3hGgwNE7TjyrlwTO51bLQY3wumphhsTFhPYyF08jgaFFAB5sN+7IlzlTkmRkZvTQajxbSwIeIkCBw0Pgfa3aZl15skFvwCmJYTJ3Cld8+Nr4sj/OCew5WiIl6u129aIMEyzQ+VncSwIkAAIUOPg76DeBUxoO/vi5SlmzOMMETut3gSyABIYQgZraZlmvwvaUsfEyfmTsEGo5m0oC7iVAgcO9vH2qtobGVtVq5El2RpRMn5DoU31jZ0jgRgkcPFkmuYV1ska1HWGh9KF7o/x4vu8ToMDh+2M8KD08eqZczl+uMVqNiHA6rB0UyCx0yBGob2iRddvzZGx2rNF4DLkOsMEkMIgEKHAMIlxfLLpOrfPXq1Zj3IhYmTwm3he7yD6RQL8JHD9bIWdzqmW1TjNGUiDvN08W4BsEKHD4xji6pRcHTpTKlaJ6uUVVxqFUGbuFOSsZugQam1qNbUdWaqTMVPfoTCTg7wQocPj7L6AX/a+qaZaNO/Nl2rh4GauaDSYSIIHeEzhzqUqOnqkwK1liovof4rv3NfNMEvAuAhQ4vGs8vK41uw8XS1llk7lZwj05EwmQwI0TQMBCrGRJig+T+dO4bPzGCTKHLxCgwOELozgIfSivapRNOwtk7tQkGTGMjo0GATGL9EMCF6/UyL5jpcYxXnwsHeP54U/Ar7tMgcOvh99557dpVNc6tbaHt1CEbmciARIYOAKtre3q+j9foiLU9f8suv4fOLIsydsJUODw9hFyY/uKyxpks0bFXDwzRYapoRsTCZDA4BG4oj47dhwqlhUagTY5IXzwKmLJJOAlBChweMlAeLIZCL8NQaOd4bc9OQys2w8JtLXptacRaAPVPGrZ3DQJCKBG0Q9/Bn7TZQocfjPUzjuaX1wn2w4Um5tdaiLfspxT4l4SGFwCRaX18sG+IlkyK0UyUqhdHFzaLN1TBChweIq8B+rNLayV7SpcfPS2kYI3K4SQDw8NlKVz0jzQGlZJAiTgSGDLvkJpbGqTmxekG/up1zblyKyJiZJNw21HVPw+BAnQJ/UQHLS+Nvk3r5yTo2fLZVhKhBSWNshKvaklxoX1tTjmIwESGGACN6nwX1bZKK9syJHMtEj50zsX5eDJcvn+F2YMcE0sjgTcT4COFdzP3CM17j5SIodPl6tmQ+SFty/IvauyKWx4ZCRYKQl0TwAvAfetzjbXaYuuaDmmbtJ3qnEpEwkMdQIUOIb6CPai/c3qdOi51851nllc
3iivvZfT+Z0bJEAC3kVg7dYrkl9c39mo514/J03qKp2JBIYyAdpwDOXR62Xb8Xb0ga5CSU4IM8vvktXbYXpyhIweHtPLEngaCZCAOwmczamSPI1bVFrRKCXlDeZz6exU2lu5cxBY14AToMAx4EhZIAmQAAmQAAmQgCMBTqk4EuF3EiABEiABEiCBASfg96tU3tp4RK4UVA44WBboGQKB6jjpwfvmSWQE41R4ZgQGv9Y/vLJbauqaBr8i1tBvAgiN8Fcfni9hYYyS22+YPlCA3wscL755UEprIyQwIMgHhpNdaG0skntunU6Bw4d/Cn98Y79ICH3HDIUhbm0slI/eOZsCx1AYLDe00e8FDjAOj4iXwCBK4G74vQ16Fa1SMeh1sALPEghUP+ChUYmebQRr7xWBlvbyXp3Hk/yDAG04/GOc2UsSIAESIAES8CgBChwexc/KSYAESIAESMA/CFDg8I9xZi9JgARIgARIwKMEKHB4FD8rJwESIAESIAH/IECBwz/Gmb0kARIgARIgAY8SoMDhUfysnARIgARIgAT8gwAFDv8YZ/aSBEiABEiABDxKgAKHR/GzchIgARIgARLwDwIUOPxjnNlLEiABEiABEvAoAQocHsXf+8qrSi9KwcVdvc/QzZlnDrwk7W2t3ZzBQyRAAjdCoK21WXBdDUSqLs+VvHPbB6IolkECXkWAAscNDkd54Sl559mHpL29zWXO1pYmqasucnm8pwPO8h/e8r9SVZbTU1anxxvrK6Spobrz2MXja6Xg0p7O757cqK3Mp/DjyQHwwbo3PP+4Cue7u+2Z4zXR7clODjr+bnNObZLL+r+vyV5eQECg7Hn3R4L7gKdTfzl5uv2s37sIUOC4wfGIjs+S6Td9TnBTcJWKcw/Klle+4epwj/sd8zfWVaiAsFuGj1/ZY15nJxzZ8pScPfhy56HsCTfLhaNvd3735Ma6339a8EbHRAIDRWDywk9JQur4botzvCa6PdnJQcff7SUV4ofrddXXZC8vOn6YROn/vHNb+1rcgOXrL6cBawgL8gkCDN52g8PY3Fgjl06sk8yxN8ml4+9KQ125VBSfk9rKPBkz4x5JypgqR7f/WuqqCo3QMWHuJyQ5c5qc3P285F/YJeFRCTJl0V9LfMpY2bvux5KUOVVyTqzX4HGhMnPFl1SQCbouf0NtqcQlj5aI6GTT2oKLe1SAeEka66skdfhMmbzwEQkKDpPda58U1BeXPErKCk7KhWNvS1L6ZKPNCAwKltL847Lknh9K+siFcmzHb67reUd/yqS86Kz2J18mLfikVKtW5fLp9yR52FSZvuzz0tJUL0e3/Z+Wf0JCwmNk8oKHJSVrxlUWmrfwtGp3CmXEpFuUx73K5qwc2fqUtrVSxky/R0ZNvb2z3r3rfyKtzY2yb8NPJW3EXNMPV33rzMQNEuiBAB7UkbGpqiFoNL/z6IRMM0WRmD5RZuhvOOfkxuuuiUK9ps4c+Iu06tTIyMm36u/0DnMN4zpxzL9/0y+6/G4nzP24FOceljmr/860DNOVJ/e+IAV6vQcGh8roqXeqMLLSCPmtrU0yVq8LpK2vfsvkObbjmS7l4XrOGLVQ8nUK1S7E4J6C9kTFZ2h/tum1P0vPWyzH9H4TFBIus1Z+RaLjM/Xe8Ipes5ukrbXF3KcmznvQ9OX4ruf0HpKi96Edev8ZJzOXf0FEX5wObf4vKVPNbVzSKJl/27dN2/Dn4rEOTaj93oGXn6Pbfi2VpeclOm6YTFn8qMQkDO/Mww0S6I5AYHcHeex6As1NtVJ0+YA5UFNxRR+mTxuBYoTepPAADQ6LkuwJqyQ0PFYf2A8bweL03j8Z+4u5a/5eYhNHyLbXv6tTMu1SkndUTu56XibOe0iC9YZxeMv/GKHCMX9lyXmJiukIxw0BYNvr3zE3opkrvih557fL8Z3PmfaUXDmsUydVZrtRBaFSLT9txBzThtSsWfpAf1gCA4MlKi7d3IwhyNiT6Y8KE2nZ
syVdBYBtr33blDFt6eN6k94g+ed3SENtmbmpLbj9e5KSOV32b/yZKaIz78h5Mm7WA3Lgvf80wsehzb9UYWW6LLz9HyQyJtVenYydeZ8KWsEyevrd2p9VRrhx1bcuGfmFBLohgOsKD0ZcqxePv6MatMsybeljkntmsz6I37/umqityJOd7/yzEdbxcMZvFwK1q/yOv9uq0kt6Pbfq77vjGoWwgZcSaEJHTfmQ7H73SXMtVJfnSLWea6WinH3S0lx33XWA41Gx6YJ22ZPVHrwMTFn0qJw98LK+ZPxQJsz7hF7XQSp4PGNOx0vQjGVfNMIVBBTYfyHvhaNvSb1O9c5a8RV9+Tgmx3f9zrSzUo8vufuHRsix1+d478DL0NbXvqUySqARlPDCAaGpu+lle3ncJgEKHP38DYyYrG/y+sDEf0yzNOmNLjZppASHRhitQKhelBAKcDMqzNlrNBF1VQX64O542M+6+avm7X7k5NukRm8wQfpG5Jgf2oHwq9oNlAWtBjQISRlTZPzsj0phN/YYeKMJj0zQN740cz66GxwSISGhUVKrb0yOCeXi7W60aiOQ5n/ouyqAzFGharrResQmjTDCAYQgpPqaEvOJPybvlNv1+EpJGjZZBaoj5saZe+Z9c9NP1XLsCVqeAL1RQv0dk5BlON1I3+xlcZsEnBEIDo2Uebc8YbQB0KJBE+l4TUCTEBGVpG/tF8zvNDwqUQXtY6Y4Z/kdf7cQ8sMi4o3wjEx5Z7eq8PJxvQamSPbE1QJhH9e+q+RYHs4Lj0rW67Pguixoz9w131DhYIHEp4032lJoQ4ZPXCU1KoggTZr/kHnxwEtAaHh05zWKvHP0pQftGjP9LvMyAcEG1/KVMx+o5nNBl/ocOUFzWVF0RmYs/6Jes+OMQFVTkatMr29nl4L4hQSuEqDA0c+fAt4srBQYFGLedKzv1mdLc71RrQbptAluTHNWfV0/48xhyxbE5HWxcgRaidbmBnN+S1OdERassoNCwqTNyhcQ0Ku3DWhXWlo6yrPKsT6t/gQGdfSrs33aBrzJ4Eb87nOPmOkZTK/Yk5UX+9BXzSCzV31NRupb3v6NPzdvY/bzHbe77ZvjyfxOAr0ggN9vgF4XSEG4Pq1rxZbXXJ/6+8ZvFv8nzPmYpOsDHak3+SE0owwrNavWAgK9lXCNdtTbu+sT+Vr1+rTabZWDT3t7cF+AtgEpMDBEr7c2nUZplvW/f0zOH3lDp3srdJ85bP50yav9bNd/6SPny+I7/0nydJpl7W8fNlOf13J03cL1CS0HtLFIeDlCgnaHiQR6Q4ACR28o3eA5EB4aasvNHCoe7niDadf5VEwdwM4jTacdcI6r5Jg/Iialc9VLvL5ZFF3ebzQkuIld1vlovLEghYXHqer2lLFuv6j2JVZCeVg1Y6k+G+vKzA0wUsu90YQ3NdiFYC48PnVsl+yFqiLGWxnqgr1IXMoYs7Jm3Kz7Ze4t3+icirJnwo2yvqZY26acuumbPQ+3SaC/BOzXBK5PaBGzxi0312fW+OWd9lKu6rH/bnEdwV7EWgkWnzzW2IngN41rAVOw0EbiJaNcNQS4bmEvBXsOK9nLwz7ks2y2rHN681mjGhxM3UCrM2LSms5rHnmbG2vNlBLuA9BoxCePMdOYsMFadt+PtV0tRsNjr8fOKUptNiBkwKYLKefEBtWeJqoWM8Oehdsk4JIAjUZdonF1IEDwzyTz5nR1W3d07A8w6kZcmK/+8k6Zf+sTRu25/Y3vyZtP3W+mEGBsteKj/27eYPDWgWR/+4C60p4/MX2SnNn/F3PzGDZmiVw5u0XefuZBI7TgZjlHjVCRxuub2W5dTgfjMBiUoUVIULlue+O7Zg777s+9aoxcw1WFDJVpl9SlPx15O9+y0E49jvrP6Nzxm08/YKZpuuTX+tb97lEjaMEOBW3Y+MIX9E2rTI1N69RW5cGup+u3TC0P88JZY5fp9M13XPbtuozcQQIuCOA323Fd2a5VnHv1
N4xN+zVx12dfkYzRi+TtZx8004/Q3C3/yM+R4eo1jRyabPm7/m6/a2y2YCCN3/zUJZ+W7Wqn9dqv7jLXLIxEU9UuKiZxuJw99Kq5LySkTbiq5ey4zuzlwT6qUg3Rcd13TV3bc62fuDS1HP0Po1FMf+L+EBIWqfeRsM4ioA05vvM3skeNy6PiMnR65e/MUt6Te/5gNDswTHdc3WPnhHvHnNVfl/0bfiYH1c4Fmh0YmcIOi4kEekMgQKVwm9KtN1l865xHvvY7qW3L6Fbj4NhjWH/jIjMaA8WHCw/J2o9tvMU0NVabKRR8R2rWh65m0htBh7rVfj6O27/b82P/G099WBbd+Y9XBQmoXBtNHZiXtSes3Ue7oPa0l9eCKRns1/OxOgbW87Nv/ht71h77gxsubmxt+iaEmzLsU7AN1S4M1iBYzFzxZTPFE2JrFwz4YNNiv/nZK8abIY6jHCRXfbPncbXdWnte/ufJj0hKUoyrU7h/iBO49zNPSWjsBJe9sP/u7du4pvBQtoR8+zWBwnDtQDAOVU2EeYDrvu7y23+35ppSTSKmEK0EjQKmU6zfNfbjdmumDvUeYC8bx6zysP3G/35YFt/1T7oCbCa+diZ7Hlx7mOLobOvVaxEnQ2MDw3VMd6B+2Gm89+JX5J7Pv27qsaZ0cS7ag2lZXM/OkiMn9KHDbqVjWthZHmtfS805+c2/fUJiYyKsXfz0YwKcUunD4FsSvdFKdLHhuCbpQwiBvYY94SFsCRvYb5VjnWP/bs+P/VhSCs2GlfDwdhQ2cAyaEWuO1V4e9uF83Czyzm/TpXn3WUV1fvbUH+vGhhuYdXOy30w76g/T+euuQlBYZLxLYQN5UJa9HFd962woN0igGwL23719G9eUJWwgu3VNWEXh2sFv1fqdY393+e2/27GzPmyWqlpl4RPXuv13jX0o27oH2MvGMau8El1iC+2jo7CBc+x5UHaXtl4V2HEeBAocu77+wKuaFZzVkXBf+P/tnQdwHcd5gH/0ThSiEgAJFpEgBFYQLKAoSqKoLkqWx03uduSJy0wmcZyJ7UxmEsfxjONJ4rEnkaWR7cR2HCeWI1M2SVEkLYtFBEmRYgXFBhC99/oAPOT/F1zgeLh37w7v3nt37/7lEPfubuu35f7d/XdX1mf5THlVcyJ/lQKL0i7/ZgJ6BHiEYx4jHHpAg/mORi6UDeZ8w6KenhyVma8fanfq0Rv1+1Dd8whHqEiHLxx/IxzhiplV9ZPib6Vfwj+s87Rcdj56IYHw5BGOQOhFntvZLnnkpS3iUmSFsEFQrBY2yE/1qAY9Y8ME3ETAqvpJzKz0S/iHozuhFjbclPecVmMEeErFGCe2xQSYABNgAkyACQRAgAWOAOCxUybABJgAE2ACTMAYARY4jHFiW0yACTABJsAEmEAABFjgCAAeO2UCTIAJMAEmwASMEWCBwxgntsUEmAATYAJMgAkEQIBXqSA82rCKNtBh43wCXs+Y8xPBKdAlMDnpnTmQTNcivww7gckxro9hzwQbRcD1Ascnn9sEre3TR7rbKF8si0rPAB641ueF/mEvpCZHQ0ZqFORmRK5wFRVVDCnJs9s5WwaSPbINgc9+aAsMj8yeQ2KbiM0zIqOeKahrxV1D0X3BwhhYkBI5A89R0cWQkOD73Kh5ImNnDiXg+o2/HJpvutH2eqfg3Std0NI+AisWp0HZigw4cKwJHruvEOqaBuG9q92wtCgV1q3K0vWHXzIBJhBcApdv9MLt5kFRN6Ojo+DY2XbAC1RtyA1uwOw7EwgDARY4wgA9WEGOjk3CO+91wODwOGwqXwgFObNbjEuBQ4YtBY+SwlRYX8qCh+TCVyYQCgJ0xMDhky2Qk5U4R/CXdfOJ+4sgPi5yRjtCwZXDsDcB10+p2Dt7jMWuu3cMTl7ogBjsGm3HnlFqiv8hTBI06D81bq8drhe/WfAwxpttMYFACAwOjcN+HHF8oDJfCBxqv6he5qIgsvdIPezYlAd5C/ng
MzUjvncmAR7hcGa+iVjXNg7Axeu9kLUgHrauy4HYWN+9IfUIhzrZsldVsghHPFbziIeaD98zASsI3KjvB5pGeRJHL/TqqwyLRkEysX5vLFsoH/GVCTiWAAscDss6Goo9V9MNja3DUIJ6GGtXZhpKgT+BQ3pC88nk/5JFKbBhNTdykgtfmUCgBN4+0waJ8dGweW2OKa+u3uqDW9i5eHR7IcTEkGopGybgTAIscDgk3zwe1M843wF9g+OwEUcgivJTTMXcqMAhPa1HweMsCx4SB1+ZwLwJkG7VvrcbYQsKGoV5s3pVZjzsHfDAwePNsLuqAEc8eBWWGXZs1z4EWOCwT15oxoQampOoCDqFb6vW50B6WrymPX8PzQoc0j8peCwuSOFhXQmFr0zAIIHG1iE4dakTnthRBIkJgS1Hp9VnVI9p2pNWnrFhAk4jwAKHTXOsvmUIzuPy1QWpcbBtfW7A2urzFTgkHorPWVxqy4KHJMJXJqBP4CSOSHrGvXA/Kn5aad7DkcfO3lHYtbUANyzkKRYr2bJfwSXAAkdw+Zr2/cL73bhyZAinTJJRhyLLsgYlUIFDJqQBe2zvXmbBQ/LgKxNQExif8ML+o01QjqMQy4rT1K8tuW/vHoW3T7fCYzsKcUM//6vSLAmUPWECARJggSNAgFY4pwaqGntD3X0eWLsqUyxRtcJfpR9WCRzSTyl4FKMuScW9rFwqufDV3QRCKQhQu0G6IWvuyQyaYOPu3OTUW02ABQ6riZrwbwDX49NGXRN4NgQta81KD54ymNUCh0wmzVGfwREPEjw2llk3IiP95ysTcAoBmuro6BmFh7eFdqojWFM3TuHO8XQOARY4wpBXze3DQh8iOSkWFUFzA1YmM5KEYAkcMmwpeNBUUAXuGcBzy5IMXyOdAClzvoErSGgpedny8ChzSuVU2t8jIT4w5dRIzy9OX/gIsMARQvZXcMOfG/UDUJCbJD7KdHZCqEywBQ6ZjsY2HPG41AVFuPyPplpY8JBk+BqJBHr7cbnqiWZ4pGoRZOAGXeE0tPz2939sRCXzHFiUO7/lt+GMP4cd+QRY4AhyHtN0yamLndCBSl73ohLZisULghyitvehEjhk6Cx4SBJ8jVQCNbd6obZx9uA1u6TzLVQmTU6Mhc1rsu0SJY4HExAEWOAIUkEYGpmAE+faYQyXxW1dmw3ZmYlBCsmYt6EWOGSsmtqG4TTuQ0AbHm3iEQ+Jha8OJjB98ForLMyIt+1uvDdu90MN7lD6+P2FEBvj+8gDB2cDR92BBFjgsDjT2rpG8APbhVsYx+BBajmQhD0NO5hwCRwy7aS3QiM9LHhIInx1IgHqSOzHlSG0t0auzQ9VI6V0qvcPbs4Pe4fHiXnNcbaeAAscFjG9VtcHV2v7xSmPNJQZSv0MI0kIt8Ah4ygFD5pjrixnHQ/Jha/2J0DnmVx4vwee3FkEcToHJdopJTQa8+aJFsjPSTJ87pKd4s9xiSwCLHAEkJ+knU7TBa2dI7BqaTqU4n+7GrsIHJIPCR7EriA7CSpRQGPlUkmGr3YkcPTdNiFk0PJ1J5pL13ugEac3SbnVbp0hJ/LkOM+PAAsc8+A2Mor6Gbh/Bl3pY5ln86FVSqLdBA6JnQUPSYKvdiQwhocm/h6nUDaXZ5s+MNFu6enuHYPD1S3i1Fk6MoENEwg1ARY4TBDvxE19SA8hFodTt2/IhRTcR8Mpxq4Ch+TX0jGt45GPIx40JcUjHpIMX8NFgBSeaVMtmkIJ9OC1cKVBHS6tmjuA266vLEnH/+FZMaeOE9+7hwALHAby+mbDAFzGPTQW4k6gW9ZlO1Lr2+4Ch8wGpeBRib1KHv6VZPgaSgKnLnTAMO5r8UBlfiiDDVlYZ3A6c2BoAh7cEpnpCxlIDsgUARY4fOAi/YyzNV3Q3DaC5xSkQjmeV+Bk4xSBQzJu7RiB6osdYrrKjkq4Mp58jSwCE3cOXlu9PD1se+aEiigJ9yfOdcATuHTWLqvpQpV2
Dic8BFjgUHGnOVs634SWlNFOmZGyY5/TBA6ZLSx4SBJ8DTYBmjL9w6lWePw+PIE1xR06Dh6ho9Ik2rrFBSnBRsz+u5wACxx3CkBP/xjO13ai7gCI800iTanKqQKHrJ+0Eqgah7lJQZdHPCQVvlpF4Pz73dDWOQq7q0J78JpV8Q/Un2O4CicGNwijbdHZMIFgEXC9wHG7eRDO49r6jLR4Udmcsr7ebIFwusAh00sbq1WjYJiTlQhbcAdX1vGQZPg6HwI0dXoQD14rxt49HT3gZlOL+4xQW+ikfUbcnF9OTLsrBQ7aDOf81R6obxmCxXjC4/rSLCfmnak4R4rAIRPNgockwdf5EugbmD54bdfWAshChXA2AE7aSZXzy3kEXCVwePBcE1rm1osNzfpVmShspDovx+YZ40gTOCSGdhzxoKmw7KwEPLMmh0c8JBi+6hK4WtsnTm4mfY2YmNCd2qwbKZu8lGfFZKXHw8ayhTaJFUcjEgi4QuDoH/SIjbpwYAO24U6B4T5GOhwFJ1IFDsmSBQ9Jgq/+CBw52SLaAP6Y6pOy62m4+rHmt3YmENECBx2Rfu5KN6ShxnkVKkPF44FqbjWRLnDIfG3vHoWTuMooOxNHPFC4ZB0PSYavw3jw2j7cNfS+ijygDebY+CdAo8Gk40Jboruxo+afENswQyAiBQ46N+BWw6A4mXRjWRbvWoklwi0Chyz8QvDA6bPsDBY8JBM3X+uaBuG9q92450QRxMfxce1mygIp1lL7UVKYCmXL3a1Ya4Yb251LIGIEDtqw5+SFTqDzAspXZsCyorS5qXXxE7cJHDKrO3DE4x0UPGiXWFryxyMekox7rsfOtmO+Ty93d0+qrU/pOdwIsQvbV1Ky5aMHrOfrBh8dL3AM4gZdJ/CDQgLHFlQaXIg9WjZzCbhV4JAkWPCQJNxzndnUCkc53aQgHswcJl2pt8+0weM4UuSks6SCyYT9Nk7AsQIH7UB55nInJCfGQhUepBYphysZzzpzNt0ucEhatJsknfRLyyBJr4dHPCSZyLrSKcS0YzBv2219vo5j5450YdaszOSRZOvxRrSPjhM4SHP6et0AKn0lwiY+3Mtw4WSB425USsGDVi7x0si7+Tj57jQeTDbIB5MFPQtJoKPTZ3egEi4bJmCEgCMEjsnJKaBGpL1rFEqX8bHKRjJWbYcFDjWR6XsSPKjhzFxAO83msuChjckRT/no9dBnU0PrkGibn8QplgQXrwIMPXlnhmgbgWN4dEJMjygx0rMT59ph1OMV52fk4nbWbIwToF0Dv/ytkzCCx2zT1kbUi9/zYDF85PGlxj1xiU0peNDSvyqV4EErXrjs2a8gXMDzT3oHxuH+TXlCmfFIdQs8ur0QIu0cJPuRvztGI9hO73u7SShl02GXZ690wcjoJGzfmHu3Rb5zPYFYOxCgJWs//K+r8N2vVog5dVLwO3WxEyXmaNH4JyfZIpp2QGUqDqTU9RjupPibQ/WAe55BfHQUPL6jyJQfbrGcnZkIT6MwRlr4ND8tBY9JXBL4tz84B195vhTK78l0Cw5HpPPVN+vhZsMAdkgmYQrz6YO7l7BOThhyjo62/+AjS+AtPGn38o1e+OW+WnE2VdWGHF7NEob8sHOQYV+QTruAfvfHl8S5Jq8droe9RxrElsP0oXx42yJgYSOw4kP7DpDgRoZ+c+9PnyetciLBg/YbIMHj5f+9Bt19Hvj+z2qgt9+j75jfhozAlZu9UHOrD+i4gt9iu0GjHKwAHDL8mgHdhyMab+AmYYPDE9DYNiw6jZoW+aFrCYRV4KA51+/95Ap09oyJDKBzTvY8VCyG5liJz5oySQLGbtwlMDkxRnxIrfE18n0hweNRFHovXOsRie0bHIfv/7wGaBMkNuEn8OrB2zOR6MD248VfXZu55x/hIfDT126IjqMMnUag2DABJYGwzlW88up1oEOUpKlrHoIbt/thxZIF8hFfLSCw54FiMcTJ6+bNwTyMZ270KEY1
aLj41/ih+/BjJeY8YtuWEriObQTp1Wxdmw0lRamwrDBNXC0NhD0zTYB0wzbjXkh1jYNwq2lAXGmzsA2r+QA40zAj1EFYlUYHcNOuMZx/pTnYMVQMHUPlxizsWfI5BxFa2hyWLCqbpMNBRp4nSjss8p4v4c1IOs2Ud7oMbx4YDZ3zyigpd9gLq8DhDsScSibABJgAE2ACTCCsOhyMnwkwASbABJgAE3AHARY43JHPnEomwASYABNgAmElwAJHWPFz4EyACTABJsAE3EGABQ535DOnkgkwASbABJhAWAnE7j14Hq7VdoQ1EnYKfPeOUlhXFthunK/88gQupxy2U7LCFhdaTfDRPRVQmJ9hOg7XbrXB629ewl1See8LfXhR8PTD5bBqeXAP0Trw1hW49H6zflRc9HbnlhVQub7ElimuPlcLR0/dtGXc7BQpap8++Vwl5GbzVgyhyJfYA3+8Cg0dMRAdExeK8GwdxthIPywuzAxY4Nh76CLuI55v67SGKnJTnm54qGrlvASO+qZuOFLdCLEJ5oWVUKXPDuGMj/bCmtKCoAsch49fg5o6D8TEJtgh2WGNg2d0ENLTkmwrcJy71AiHq9sgPiE1rJzsHviUpwueeLCMBY4QZZTY+CsuIQ1i4/hgtMmJcUuw0xbLcUn8kSSY3qihgJjGYbmMZ5b6DL2j+u8tfEttRVx8soU+OtMrr3fS9hGPi0uGBK47uvnkhQHd9/zSWgKsw2EtT/aNCTABJsAEmAAT0CDAAocGFH7EBJgAE2ACTIAJWEuABQ5rebJvTIAJMAEmwASYgAYBFjg0oPAjJsAEmAATYAJMwFoCLHBYy5N9YwJMgAkwASbABDQIsMChAYUfMQEmwASYABNgAtYSYIHDWp7sGxNgAkyACTABJqBBgAUODSj8iAkwASbABJgAE7CWAAsc1vJk35gAE2ACTIAJMAENArYVOPq76qC1rlojyuYfNd86AQM9DeYdRriLphtHYaivxVAqmeE0Ju/kOFw/96ohZlZYYu7mKZppOwZ6GqH55gnzgbAL0wTM5ItpzzUccN3RgBLmR4YEjkO/+AJ+/E/pRpU+XFPz3O53csIDwwPtd/l/4eiPoL+7/q5nRm/U/lFBv3rqF0adh9zelZP/AWcP/4tuuOo06VrWeDk20gue0dltfD2j/VC9/9sQFR0DPW3vw/6ffBymprwaLqcf2ZFhsMslpVxdruvfPwIN+J+MkXwTFk38UeezHbkbTY6RcqVOr1G/pT0t97LtMBJ+VFQ0nH7jO0D+uMkEu+7o5QtxNhK+2fxQh+nkumM27U6xb0jgKNv6GcjMXambpoM//zyOIjTq2vH1sqPxPTj6f38183psuBdab5+C4pUPzjwz80PtH/nTeP2PMO6x5wmui5bfByX3Pq6bRHWadC1rvLx49CW48d5vZt40XHsL8/QeSE7LhdSMIli740+BGl9fxo4Mg10uiYW6XN++cgCKVz0kMBnJN188fT1X57MdufuKu/q5kXKlTq/aD3/3avfKtsNI+KkZiyAF/zffPOYvqIh6H+y6o5cvBNJI+GaBq8N0ct0xm3an2BeHt/mLLFXG5AW52AsYg8vv/BRSMwvFMGRWPh7lfv8X4eyRf4XJ8TF499D3IG/JJijd/HExotBSWw2JKZlw77bPQkbOCrh95Q0YHe6B3o6b2HNshuXrnoGFBeVw6cQrMNzfJoSOVZs+BqNDXZCevQySUrNF1EYGO+HS8Vegr+sWpKYvgnurPgdpmcVw6sA/AtlPz14K3a1XofbyPlhV8dE5/uUWbxAf1s7G81CwbJu/5Ib8fV/nTTH6QDzPHPwuLCwsh/qaN/EE33hY/8BXUBCImZOm1IxCuHjsJRhEIS+roAzWbP8TiMVDtbTcdzVfRgHuNPoXC10tV2D7M9+GNhyxKli6VaR1fGwQbtcchMIVOzTzaMnqRyAlvcB2DM2Wy7Ktn4aWW+/glMivYRKnRkrKHoWl5U+KsmekXK/a
9FHoaLwAFQ//peDmL9+ojPoqu7WX9mEcPLBi3bPCr2OvfR3rzSfm5LPdy65eZfFXrrTqfnbhGs22Q6tca9ULZdtBbYq/ck3xp3rQgtO3UpDUS1OkvLOq7sy3TfcXPp1e7qvuREq7HyllyUw6fHdpFb50Nl8C6jmMe4ag7sp+oQ+x5r4XxKgB9ZRXrP+A+JgtW7sHK+0uuHbmV0L/YtPur8GCrCVwfO/f4HD9FAz2NuFH8mWgRmUJNvZn3vwnPHo8BRajm/jEBbB6y6eEYNLXeQtS0vJEDMjdsd9+HYf+o0VDH5eYBtQ40/B/Z9MF/FD3C3tjKMh0YTxJSFH7RxaSF+TDIAo5djQkNPR31YqoEeur1b+A0sqPixN8Lxx9UTNNx/d+E1IwTRsf/ir0tF+DyzgtQ0bLfd6SCsE1t2gD9iw+BdHRsdCLjJPT84Ubytf2hnPit1YeyZEhuzE0Wy4HcIru5P5vCSG1tPJ5OPeHH6CgWmO4XPd33cZyN4mC13TZ9JdvemV3oKceBtA/adrr38VTWJMcV3Zl/LWu/sqVVt331XZolWutuq5sO/yFL8s11aOhXnu2DVpcrXhmVd3Rai+08lWZLxR/f+Hr1Z1IafetyEen+WFI4FAminrRlY/8NVDPi0YzaKSCRi9IF4CmXdIyi4CUdahRbqs/AzGxCdiDbBWjFuTPkrJHYDkKJvSfhvA9KMgsWFiCvfMkyF5UjoJHGoyN9EHindGN4YE26G2/Dut2fllMAdDQ/2BvI4bbqozWzO+Y2Pg5/tFLapwoHk4wGx76M8G2pOwxTGszMrw7TZMTo4IJ5QUNIyYkpePIxeWZpKndJ6XmQGJyJgpdeTiidK+w50HGSSnTI0gzDu/8UOfRyECHeGNnhkbKJfVik1IW4khZrRCaE1OyUEid5mbEPQm3dNw3jRRpGTV3s2WXenXqukDh2Jm7Fgdfz9TlSqvu67Udar7qeqFuO9TxUIcvy3Ui1oMhh7QN6jRZcW+k7OvVHTVXrXxVtunqOGuFb7buaJUFCidS6o6amVPvTQscJCRERUWJ9MZgA6mlKDoxPgLR+JGMwSkBaqArdn1VfBTJUTQKJtJQA0s9RrWhHvjk+Kh4PIF6FzR0GhuXKO6pYJER7jAeeoqOwuKdPxPCv+l4K5/b8bfUpRB8NBRxp9MCyCRJMM5fshlWb/7kTFL8uSeL9NGkfNIyvvLIzgwNl0ssW1Qu6f+qio9A/tItAoER9yRU+2JGnqi565ZdiMyyq1We5DNf5Uq+p6te26Hmq3QnfyvbDvlMXn2FTwK8bNOkXTddjZR9kS8+6o4vrkqGevmiFb5u3YnQdl/JK1J/mxY4fIGIjqY5tw4UAKbEiMfU5ATQFAvpaeSVVOIHLs6XU/FudKgHvOiG3Cel5cysWklBnQ0SMhqu/UG4r685hL31LJxOKICExHSxwoK0k+tQP0QaCkvpHz0fwVUwJO061SjTlJSai8PvKSI9xJf0EDJylusmjdzTSiApoNGoh3plkK4H+NKJDNXlknpaRffsFOWyaOVOv2VC6T4ZyyXpMSlX++gx0y27OCrVgyN3JLDTPDjpc5BR5jPVBTJO5C4ibuCPOr00WmpV22EgeGGF6oGT2waj6TRrT1n2KV/M1B11virbdCPx0K07Lmr3jbBykh1DAgdJ/9O9C7zivxlDqxrujHYULt8udC1O7f8HoSTa330bfvfSB+F3L38ITh/4zrQTYXfW/bRfUWKqhISK1/7tKWjC1SRZ+auB5vzo40gjGxWop3D20D/D3hefhYvHX4bNj31D9NBXYg/10okfw94fPQte7ziGMe03rb5Q+uf1TqCORJ3wdybudvqh4DLLerrHLHteyjS11r4DFagfQzow+378PLyOnKVA5ss9KcbV4QqL3/77HmQ1gSxKoQ+Vd6eNIl8VcaF3Mo/syHA2rYr4i0hrl8uCpduE0vC+nzwvlgEf/M/P4TQbLcc25j4lvVDoGvV23KBQKIPo
z52f6AfVB3pyZxRQr+wuLt2FCtTdoszXXt5/ZwRwbl2wI3eRSEN/FFwVrAQjwW1ueknBXKvtmM3rWb7kj7JeqNsOyhv6J4yP8Okd1QNqc9xkZnkqGBEAH226z7rjg6t+vlAeYriivmiHr1d3Iqbdd1OBu5PWqC9+47+nWvszZqYstBjQyIOct1b+FtMpMwUHRM+PdDFo+IyMUMrCXlocKoaSEb1rvKehaTJqvzxjA2IKhp6//tJzsO2pvxO6ImSXenvTc+jpdDtjaHRDCia+/GutOzWQdE8AAAPpSURBVC3W2j/1hV/PfBRmPFD8GEZdhU89sxQ+/FSF4qn5nx944SWIS1tl2KGSizIN5IHynnhLRvSOmND+GvEJqaJnrLavvhdTIiTEoe4H7atyCvcfePoLrwomMhxlXJTujTIkN0rjHW2Cb35pB2woL1Y+NvT70NEa+OHPL0B8yrRyq9qRjDM9V/72Vy6pzNCQbTyOMlDDZ8Y9rZag3tvGXX9+V3lWhj/HPx9ll/KP4kH1Q+lemc9GuI8NtcKXnl8Dj+4sE2kJ1p+vffs1uNGSgKNryYaDkOnyVa7II2V6pcfqtkP6I98r75Xu6bmy7ZD2fIVPAt3rP3oOqp7+e8gpWi+993sdwZV0e3bmwQvPb/drNxwWXvzZUdh3rBNXCWZpBi+50Evlb7N1xxdX8tdIvvgLP9jtvnekAb71F7tg9T0FFBU2QSZgaIRDChsUF+VvEhxkr47ekdKWFDbonhomKWzQvej53aXDMat8R36RvgcZCmP52meAdsKUhj4MpBypNjSSIfU71HGT/jXdeBuWr3/2rriq/QnnvZKLMg0UJ+W9khG9IyakDEofQGmU9umZ8p44kbBBJm9JJQoqaTNKk9KeMi5kTz63I0MZN2U86be/ckllJiE5Q/Aj+2SUfum5X7HhOVwSfly4UbJSup/jn4+yS/kn64fSvTKf7chdJN7gH5kuJas5fBR1X3qrbjukP/K98l7Ji54r2w5pz1f4nbjMmaYXzQgbMg5OvkoulAblb72yT3bVdccXV7JrJF/Inl74kdzuU9rdZma/+DZLefn2z0/3IC2IF/VGqWKwmSVAFfnRT//UMBdmOM2O5rKffOF/ZkEG+RdzNw/YTNuRu3gj7P7Ey+YDYRemCZjJF9OeazjguqMBJcyPbP0VtkpIsMqfMOeV5cGb4WLGruURtZmHoWQRyrBshjmg6JjhZsZuQJFix4Y7OFag4ny1gqK1ftha4LA2qewbE2ACTIAJMAEmEC4CLHCEizyHywSYABNgAkzARQRY4HBRZnNSmQATYAJMgAmEiwALHOEiz+EyASbABJgAE3ARARY4XJTZnFQmwASYABNgAuEiIJbFjg13wfidzbrCFRE7hEvHaVthvN4p3I66zQqvHO+Hd3wooDSMjQ7CpJdZ6kGc8FhTbvXCkO+orZjwDMhb117Hx7BcR02fGmxXCJ7RPtx8i3ZgZuOLgHd82Ncrfh4EArGf+dBmaO3gBkSyLV+1SP6c9/Urn7kfRka5okuAxYsy5U9T19IV+fDFT1SacuNWy2Urg79T4sf2bITG1l63Ip6T7tLl9hU4HqxaCYvy526UOCcR/AAKcplTqIrB/wOnRGRShEx4pgAAAABJRU5ErkJggg==",
      "text/plain": [
       ""
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import Image\n",
    "\n",
    "fig = Image(filename='gfx/intent_out_map-crop.png')\n",
    "fig\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8d64913",
   "metadata": {},
   "source": [
    "A similar process can be used to determine the mapping for `intent(inout)` variables:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "bc9b5b41",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAl4AAAF8CAYAAAAerbUmAAABJ2lDQ1BrQ0dDb2xvclNwYWNlQWRvYmVSR0IxOTk4AAAokWNgYFJILCjIYRJgYMjNKykKcndSiIiMUmB/xsDBwMMgzsDMwJ2YXFzgGBDgwwAEMBoVfLvGwAiiL+uCzMKUxwu4UlKLk4H0HyDOTi4oKmFgYMwAspXLSwpA7B4gWyQpG8xeAGIXAR0IZG8BsdMh7BNgNRD2HbCakCBnIPsDkM2XBGYzgeziS4ewBUBsqL0gIOiYkp+UqgDyvYahpaWFJol+IAhKUitKQLRzfkFlUWZ6RomCIzCkUhU885L1dBSMDIyMGBhA4Q5R/TkQHJ6MYmcQYgiAEJsjwcDgv5SBgeUPQsykl4FhgQ4DA/9UhJiaIQODgD4Dw745yaVFZVBjGJmMGRgI8QHbBUozZ45GUQAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAACXqADAAQAAAABAAABfAAAAACZtezAAABAAElEQVR4AeydB3xc1bH/R733XmxJ7kXuvXfjYGoIKYZHElrqg7yXAqRBkhfyXkJCOsH8qYGEQAIGHGPcwL3J3ZJlucmyrN57139+Z31XV6tdaSWttG2OPtLecur3XO3Ozpkz49HJiSQJASEgBByAQE1tI505X+AAPZEuOCOBKeMTKTQkwBm7Ln12IwLebjRWGaoQEAIOTiA3v5ye/uMO8gsIc/CeSvccjUBzYzX94rH1NHVisqN1TfojBLoREMGrGw45EQJCwN4E/AOCySsgwd7dkPadjIAHtTtZj6W77krA010HLuMWAkJACAgBISAEhMBwExDBa7iJS3tCQAgIASEgBISA2xIQwcttp14GLgSEgBAQAkJACAw3ARG8hpu4tCcEhIAQEAJCQAi4LQERvNx26mXgQkAICAEhIASEwHATEMFruIlLe0JACAgBISAEhIDbEhDBy22nXgYuBITAcBPoaG+lCyf+NWzN1lbmU8GlA8PWntZQaf4pammq1U6Nr3VV+VRddsV4LgdCwB0JiODljrMuYxYCLkzg9N6/WCVsZB16lY7vfLZXEs2NVWYFiF4L6W6als87v4uu8S+SNe3rqrL6sL66kDo7DD6tPDw86ehHv6D2thary/eVsbL4PH348j3U2dlhMeuxHb8m5DNNV7O20cWTwyd4mrYv50LAEQiI4OUIsyB9EAJCwGYEqkouUENtcZ/1JY5eTKmTP9VrvjN7N7Kg8E6veXq7aVr+atZWGjF+pSpiTfu91W3p3rbXHyBoupCCwxMpiH8LLu2zlL3f14PDk2nqkq8ShDpJQkAI9J+AeK7vPzMpIQSEgJMQuHzmA7py9kPy9vGjGSu+RaFRKcaeV5ddUtqsyPgJlLHtlxSVlE5557aTp5cvTV/+TSovyKSiq0f53JvKC7No0e0/p6b6CjqzbyPVsWATmTCJpix6UNWRefAVCo5IUpo21Ddt6dcoL3tnt/Lz1z9Jpfmnadbq76g+9NV+SMQIpbnKzvg7FV05TJ7evjQq/RYW3FbwmLZQe3sLjZl2h6pr36YnVL2ZB1+i9tZmOrbjGYpLmU2T5n+REtLmU2HuYaPAhwK5WR9RXdV1Sl94vyp/6dR71NHRpoSpazm7qKO9jZLGLKEJczZQQ00xZR58mby8/ZQWa97NP6Kr57ap++WFmZR99G8qT2hUquLmFxCu6gS7rMOvMvsASl/0EEXEjlXXtT8NtSU9WHr7Bmq35VUIuCwB+crislMrAxMC7k0AsfuwlDhr9bdp8oL7ycOz+9sdhKeacoO9UVnBWco+/AYLGvewoOBPWK6MS5lF4TFjKDZ5Bgsw95Gnpzftf/8HFBQaTzO5zsqSHMrk5crWlnoWZD5kLdM1mrL4Icq/sJuu5XzSozza6+xsp8CQODUxfbWPTBC6IORAw5TG2rkjHz3N
wk8Ot5VHteVXjRNckneM2lobaMz0O5WgOGrqbSxorVL30d/6qu6Bx4PDEinn2D+4TKPKk3PsLQoIiqKmhkoWGr+hBEcIkzXluTfGt5Vam+towtx7uUwTlVw7ocpBeEubfDNBqISAdvHkJmOfCi8fZMHuQQrk9g9tforH3mm8hwNzLLtlkBMh4KIEur8TueggZVhCQAi4HwEIUL7+oWzM/k8VdBsapN7SjJWPKi1R6qR1rA0qoIDgGPIPjGDBIY6iEiZTY10pYRkTWpnS/JOqTmh8kHBtztrHKXbEDFVHfXXP8i1NNVwmXAlG5vph2j7yFFzcR+Nnf56iEifTyAmrlRBYnJdhrri6BkHRw9OLtUvjKCTCECzaPyia6muKupWJSkxXY4NwBEEOtmgJoxbQxLn3sAavRmnDfP2DecxlqpyXtz/NXvsYa7kWs1bMw1hXysS15M8CW0VRNvno8iMDtIYxydNoMmvd6qqvU3NDhbFcbyyNmeRACLgoAVlqdNGJlWEJAXcngKWx1Ruep/MZb9L2Nx6mueu+T8ljl1rEotkseXr5GI3T9Zmh6UHC0hnyxqfMVdocXMO5JpB4WSgPgUjTMKGMaTLXfitrsXx8g4xZvXjJ1GA479GrcbuxAB+0tzUZ+6ZdR18hyMHYHwJp8thlagzb/vqAWo6NjJ9EpFNQYbnVi5c6TdPpvc9TweX9qi4sT+oLaePxYl5Ieo1XbyxVZvkjBFyYgGi8XHhyZWhCwJ0JKHcGLGBAk5QycQ2VXT/VbxwQwmCLhB18WCKEEBQQHE2jp91OaenreSlydK91di8fw0JQc792SYZHj1G2YhBa0A8s8UH75hcQxkudF5QQdpXttWDvpSVPTx+lndMEHZRDn03TyAlrqIhtv1B+JPOpYy0dljChuQOv3nYtanUVXjmkNHITeQkSAqk+YYkUwlg+L7tC0+cfFGm8PRCWxsJyIAScnIBovJx8AqX7QkAImBC4oX3CMtknbz+qhJTWlgZafPsvumdUS2aGZTNogDQNjV57BcP0/R/8UNlt3fbVTTRrzXcpY/uv6NSe59jmqZ4mzruX4kbOJi7dVTd2+6m6SRm2a+Vv/cq7aumzqvSiWpI05Om9/fRFD9CB939I7z13qxKEYEwfO3ImhUSOoIunNtGmP99CEXHj1Ri5UdWHpNGLaN97T1DymKUEQ/jq0ksUGT+xq383jrDRICx6FC8BVvKS4HSuv52ik6bSlpc2kI9foDKmN2TF6HTjU2eG89TJ6+jEzt9SFtuD+fGyLJZ2kcATAt+7f7qZbeO8WJh7zMD3BnNoz8yxHDfzbkOT8lcIuDABD/5WpFMou/BIZWhCQAg4PIHT5/LpJ7/fTV4BBvukgXTYsDvPS33447ilkW2rAsMNH/y6CpVGh9/+sAQIzQyW07SkP1fLYqzx0nbc4S0TNlG+fsFcxkcV0edXS4E6QU5fHrsnUWbmqv8yaJSsaB8NQMjDMiMM/LWEfrSxQOnjF9Sj/9D2efsaNFAfPP9pWnjrT5VwpZXVXtHXTv7R14tNCRCgIIhp1/XjQ1n9OZZPPTy8DCyYk54n7MXATatHzxz1mGOJ6wNJ7Y359OQjy2jqxIE/OwNpV8oIgf4SkKXG/hKT/EJACDg0AXzIa/ZWOMYSl6bN0ndcabZY6ELSC12m5zDS14Qu3EPdMLrXhC7T/BA89O3py4+Z8Wl2ObEfRVQe5EXqrX3ch3ClCS84R0I/cB3JtLyvf4jKX8buK7BJABotcwntm9aLZUzUrb9uWr/+HEuM0GChjOl4IMDp69EzR3/MsTTXT7kmBFyJgAherjSbMhYhIAQcmgB2Ha5/6K1h6yOWJdfc+8KwtScNCQEh0DcBEbz6ZiQ5hIAQEAI2I6DXhtms0l4qGu72eumK3BICQoAJiOAlj4EQEAJCQAgIASEgBIaJgAhewwRamhECQkAIWEMALiNMg1qX5p/qlxsKtFNRdE7Zk8EQXpIQEAKOQ0AEL8eZ
C+mJEBACQoD2vftYjyDfx3b8WsVJtBYPwhbt5XrgSwte4iUJASHgOAS69iY7Tp+kJ0JACAgBITAIAog9iXA+CNsjSQgIAcciIIKXY82H9EYICAEnIoAg0Wf2v0AL1j+leo3wRAEhMZQ4ahGd2v1Hqig+T2FRaSpcEXyKZR95gwqvHGYXFxEcuPvLKgg3vNmf2beRyguzlDPWDvafZS4VXT1KWYdfVR7i0xc9xPEYx5K5OqvLrtB11nixbwcV2Ho2O31FsO0ibteT3T6MSr+FA2ivUEGtMw++rBylVnI/l939Ow6c/WaP/pnri1wTAkJg4ARkqXHg7KSkEBACbk6gtbmOA2Z3hSKqLr9C9RxgG0t81eW5tOi2nyvv9cCUk/EPFaIHglBoZArtZ4/0cCCaffRvVJJ3nKYu+YpaYjTEYuwJFgGt0xc+qOJDHtr8lCprrk747EKQ7Bj2Qj92xl1K6EJ/pi75KqVN/hQd+ehpFRi7taWecrO2KuFsAof8uXDibbP969kTuSIEhMBgCIjgNRh6UlYICAEhYIZAUGg8VZddZs3THopPnadyFFw+oOI9FudlKC1TQ00RNdWXK6FrDAtIEJjm3PQEO0PtGYwaFWDZMCZ5Gk2e/0WOq3idQ/1UcIDqnnUqp6RBUcpxajhrxQou7lPxFKMSJ6tg1rHJMwh9QPLy9qfZHM4nacxigmCHGIqm/VMZ5Y8QEAI2IyCCl81QSkVCQAi4I4HOjo4ew45PnUsLb/kpFVw5SFtfvY9DDFUTQutgqc+LBSsEjZ616tsqxiKu+3BYHSTl/b1HbYYLmj8urxvBqFXIIAt16qtobeWwQhzcW0sIPaRp1eCBHl7nkSz1Tysnr0JACNiGgNh42Yaj1CIEhIAbEoAAhXiE9dWFKrYhlgxDwpOptiJPaadiR8ygDzZ+mmorryl7rk527TBq6m1KwKpnjRfCDoXwsiPCCI0Yv5LysndQe3uLWZJYLoRWLD/nEyW4IRQSPOGbq1NfQXj0GK53J2u1lqodjnBXMWbanfos6tiaunoUkgtCQAj0m4AIXv1GJgWEgBAQAgYCgaFxSmDa+uoXyY/jEgaFJ/IND7UECNstaLfCokcpmysY0x/44Ee0eeNdKqZhcFgiLf/s72gi21ftffd79N6fb6WwmNEcBzJS1aFnDE1YQ20Jvfunmzn2oRfN4eVBaMAs1YmyKIOUvugBOsD2ZO89dysLhx0sdN1BCCUEI3yOrqjy4E9vdRkzyYEQEAKDJuDB6urOQdciFQgBISAEbEDg9Ll8+snvd5NXQLINahu+Klqb6zmQdoCxQQhFbS0NvOuwnRCwWp9a+TpLQMYA17inLRtiyREOT/VBqHFfuwbtGgJ26wNP475pnWopEUGruR9aQh+xzKgvq9Wr5cGraV36e4583N6YT08+soymTnSuZ8eRmUrfhoaAaLyGhqvUKgSEgBsR8PHrsqHShg0ByVzS7Ln096Cd0q6bCl3Ip13zZa2auaSV1e55sFbMNJnro1avPq9pXfp7ciwEhMDgCXR9HRp8XVKDEBACQkAICAEhIASEQC8ERPDqBY7cEgJCQAgIASEgBISALQmI4GVLmlKXEBACQkAICAEhIAR6ISCCVy9w5JYQEAJCoD8E2lqbqPiqwTlpXVW+2jnYn/KOnhe7IuEqAzssJQkBITAwAmJcPzBuUkoICAE3IdDR3moI68P+ryB4JLM/LHia19w16DHUVxfQwX8/RXd8fTNdzdpGTexdftbq7xB8Z2UfeV05Uo1OnELTV/yn2qnY3FjF3uJj9VX0OEYeDw+vHrsje2QcggumbW977cvKFQY87oexD7Eld/yf0fB/CJqXKoWASxIQjZdLTqsMSggIAVsROLD5x5TPQafHzbxb/eZf3EMt7Ine2gTBbf973+eYjQuUEAbv9XDjUJp/Uvnv6queM3s30sWT7/SVbUjum7a95NO/orX/8RJ96st/o+rSSxxQ++CQtCuVCgFXJiAa
L1eeXRmbEBACgyJQUXSOiq4coU/d/wYh/iJS0pgl6vXiyXfpWs4uJUTh2oQ5G9R10z/wn4VwPFGJ6RQZP0H91nEg7bMHXqSGmmIlfI2f/QUVugdOV3EtNCpVxWYsvHyIiq4eVVql8sIsWnT7z6k49ygHtP4ne7hvpdRJN1Fa+npVJvPgK+zANUF5wYfH/IS0hZTJbXj5+NOMFY9QcHiSWiI8s28j1VXmU2TCJJqy6EH2vF9LKBsckcRlD6j+TVv6NeXt3rRtTTsH1xT+HA8SS6uShIAQ6B8B0Xj1j5fkFgJCwI0IVJZcoNDIkUahSz/0poZKmrb0G/z7NSW41JTn6m8bj/0Cwyl57HLa/a//pqxDrxIEsYDgaBo5fhUvH4bSxHn3qdA/dVXXKW3yzTR//ZNKkLp4chPFpcxS9xDYetL8+6ieBbZDH/6Mg15/QQl6Jz7+A0E4bG2pp9ysD1XooskL7qeLJ96hI1t/TuPnfEF5us888JLqz/73f6DGMnP1t6myJIcy0Z8bZRHWaMrih5R27xqHJTJtW3O8Ck3d8Z3PcvihMs4z2zhOORACQsA6AiJ4WcdJcgkBIeCGBFqb61hj1OWRXo9g4tx7VJxGCEy+/sFKENHf1x/PX/9jmrH8ERaOPqKPXvuS0jJBqwVv99GsCYN3+5SJa5UWqaIom3xu1BcQHMMhhCIIoYmiEiZTYe5hCmBNU3X5FRX/EfEaywsyVVNw2Dp7zfdY0zWPwuPGqRBACWnzacSEVVTHsSQb60qpigVJ5IPw5BcQRuWFXWXnrH2coCmDMAVbNdO2tfHkHH9bxZO86b6XVN+06/IqBISAdQRkqdE6TpJLCAgBNyQQxPEUayvzqL2tmZcC/YwEYLe1/fWHeEkwhZfmJnHMH+MtswcI3ZM6eR0lj1tO21jwun5xNy/9dQ9tc3rv8xzjcT+NnLBaLV+aqxRLltA8IQYk0vhZn1NxF9EftKEZ/COPh6fhe7Wnpw9X1WFcFvRmQRJ541PmskBnWD7Vl/XiwN0q5JDZkZBqO3bETCUkWsgil4WAEOiFgGi8eoEjt4SAEHBvAvGpc9VS3cndf1J2WohjeD7j71RRnKMEMmiJUiauUbsdLZGC+4XCK4dUPMZ2tomCbZYva5s8WcBpqq9UQhZiNSLP+NmfV0GzIRxpCfngvgE7KsN5J2EzG/Ynj11Go6fdzoLcMrVsqeXt7TUwJI7DEgWp/CgL27BwDsrdW9K3reWDAInA35KEgBAYGAHReA2Mm5QSAkLADQggbuGi256mjB2/ok1/Wk+sKiIIY2lTbqPopKm05aUNHOw6UKcN48DU/KMS5+UCSuA6vuu3aidkR0cbjRi3gkawzRe0V168w3HTn2+huTc9rjRiJ3b+lrIOvkJ+vLyoxWXEcuH+D36obK9u/cq7lDBqAW15eYNa5mtraaRln3lWtWNsF2fcNrRYSEoLxudoa9aa71LG9l/RqT3PKVuzifPupbiRs7v6rApwOdV34mXLrrZv++ompW3LOvQaxYyY3qfQphqXP0JACPQg4MHftPpQkvcoIxeEgBAQAkNC4PS5fPrJ73eTV0D3ZbghaayflbaxtssDy3wswGgJ2icISJ2d7UoowXW4ikDwaWio+I/ye4XrzQ1VyqZLv2SJJb2W5lq2twpHFiWMwWcXNE1YHtSCXavdg3yuBd5ub2sh9AeaM215UWsX9UDAQz3Ge3yuGcfjLR/+uXz9gg3tIP+NPqOsWmbUCW6mbatlTQ7CrQl2KOMIqb0xn558ZBlNneh4z44j8JE+OA4B0Xg5zlxIT4SAEHBgAprQo+8iDNSRPDy63kohdBmuQXOkDtUf7G40TRCsNKEL9/RLjMSCk5a82SWEPkH40wuAuKe1q45ZQNQnTejCNQhjMNjXJ31ZTdjT7pu2rYRC7aa8CgEh0G8CBl10v4tJASEgBISAEBACQkAICIH+EhDBq7/EJL8QEAJCQAgIASEgBAZIQASvAYKTYkJA
CAgBISAEhIAQ6C8BEbz6S0zyCwEhIASEgBAQAkJggAS6W2AOsBIpJgSEgBCwFYGODt4NyLv2JAmB/hBQz01/CkheIWAnAiJ42Qm8NCsEhEBPAsFB/hQb4U3tHaU9b7r5ldYOP2KHEuRFbeTt2ezmNHoO3yvQm/D8SBICjk5A/Hg5+gxJ/4SAEHBrAnUNrbTrcBFNSAulcalhlJNbTdmXa2jlvHgWNNjflyQhIAScioAIXk41XdJZISAE3InA8axyKihpoJsWJ5GPd5dJbmtbB23bX0AJMQE0c1KUOyGRsQoBpycggpfTT6EMQAgIAVcjUF3bQjsPFdKMiZGUlhxicXhX8mvpxLkKWjk/gcJDujzqWywgN4SAELA7ARG87D4F0gEhIASEQBeBw6dLqaK6hdYuTCQvL53r+64s3Y7a2ztp+8ECJXjNnxbT7Z6cCAEh4HgERPByvDmRHgkBIeCGBMqrmunjI0U0f2o0JccH9ZtAfnE9HTpVRivmxlNUuF+/y0sBISAEhoeACF7Dw1laEQJCQAiYJYCg1ftPlFJjUxut4iVDT8++tVxmK+KLHR2dyhDf39eTFs2MNQbJtpRfrgsBITD8BETwGn7m0qIQEAJCQBEoLm+kvRnFtHhWHMVHB9iMiqr3WAktZuHLlvXarINSkRBwYwIieLnx5MvQhYAQsA8BaKb2HCumTvYVu3xu3JBopqBJ281CHdpYNiduUJo0+1CSVoWAaxIQwcs151VGJQSEgIMS0GyxlrMwFB0x9A4/B2s75qAYpVtCwGkJiODltFMnHRcCQsCZCGD3IYznYX+FpcXhTvuPl1Bjcxsb3ydYtVtyuPsn7QkBdyEggpe7zLSMUwgIAbsRgL+tjMxyWsMuIuzpbwv+wbYfLGSnq5E0qhf/YHYDJQ0LATcgIIKXG0yyDFEICAH7EICH+V3sCDWS3TvMSY+2TyfMtJqRWUZlFc20akFCN4/4ZrLKJSEgBGxMQAQvGwOV6oSAEBACIJCTW0NncirppkWJDhlTETEgtx0ooMmjw2l8WphMmhAQAsNEQASvYQItzQgBIeAeBJpb2lW4n8TYQJo+IdLhB33qfAXlFzUoH2L+fl4O31/poBBwdgIieDn7DEr/hYAQcBgCmRer6MLVGqXlCvD3dph+9dUROG/9iINujxkZQuljI/rKLveFgBAYBAERvAYBT4oKASEgBECggQUX2HLBYH3SmHCnhZJ1qYou5tXSavagHxjgPIKj0wKXjrslARG83HLaZdBCQAjYisDJcxWUV1SvtFx+vs6/VNfCS6XQfiXFBfLuxyhbYZJ6hIAQuEFABC95FISAEBACAyBQW9/KcRELaRIbp49NCR1ADY5d5CIvmZ7lpdOV8xIoNNjHsTsrvRMCTkRABC8nmizpqhAQAo5BAO4YisualJbL29vTMTo1BL1oY3cY2PkYE+nvUO4whmCoUqUQGDYCIngNG2ppSAgIAWcnUFXTorRcWIJLTQp29uFY3f+rBXWUcbacVs6Pp4hQP6vLSUYhIAR6EhDBqycTuSIEhIAQ6EHg0KlSguf31QsS3TLkDgJ77+ANBMFsdL9wRmwPPnJBCAgB6wiI4GUdJ8klBISAmxIoq2yiT44W04JpMcrg3E0xGIddUNJAB06W0rLZcWoJ0nhDDoSAELCKgAheVmGSTEJACLgbgc7OTtrHgaVbWjs4sHQ8eXp6uBsCi+MFGwT89mImS1kA8/AQNhZhyQ0hYEJABC8TIHIqBISAECgqbaR9J0po6axYio0KECAWCJRUNNGejGJaOD2G4KlfkhAQAn0TEMGrb0aSQwgIATchADum3bys6MkbFZfNiXeTUQ9+mHuPFYtmcPAYpQY3ISCCl5tMtAxTCAiB3glcYyeoh0+X0UpeVowMl517vdPqebeyppm99xex24koGpnoPjs+e5KQK0KgdwIiePXOR+4KASHg4gTa2jvo48NFFCS79Wwy09j9WVPXyo5X48mVfZzZBJZU
4pYERPByy2mXQQsBIQACl6/V0nEO+bNmQQKFhfgKFBsRgOC142ABTR0fwYG3Xc+rv40wSTVuSkAELzedeBm2EHBnAtipiKDW0ZF+NHtytDujGNKxH88qp6KyRhV029cF4lgOKSyp3G0IiODlNlMtAxUCQgAEzl+ppsxLVbR2YSIFB0oMwqF+Khoa21TYofFpoTRxVPhQNyf1CwGHJyCCl8NPkXRQCAgBWxBoam6nnazlGhEfyEtgkbaoUuroB4GzFyop93odrZqfQAH+3v0oKVmFgGsREMHLteZTRiMEhIAZAvjQv8T2XDctSiJ/Py8zOeTScBCA8Iug26mJQSL8DgdwacMhCYjg5ZDTIp0SAq5BoK2tna5er7DbYPBBf+h0KaUkBNHoEaGUNlLsuew2GbqGsdybzb/QfmG5t7CkmhoaW3Q57HcYGx1CIUH+9uuAtOzyBETwcvkplgEKAfsRKCmrofv+6zUKDg4b9k60tXUQ+0MlH29PDmlDVF9XTR/+9ZvD3g9p0DwBbHDYztqvuGh/eumNHVRY1sTBx+2rjWxsbKBvPbCU1iyZaL7TclUI2ICALLTbAKJUIQSEgGUCQUHB5Bkw0nKGIbpj6hzCt+X8ELUk1Q6EgK+PJ61flkyX82uprKqZPP0TydPbvo5rfTqLBjIUKSME+kWAA2NIEgJCQAgIASFgHwKjkkPEh5p90EurdiIggpedwEuzQkAICAEhYCDAK8GShIDbEBDBy22mWgYqBISAEBACQkAI2JuACF72ngFpXwgIASEgBISAEHAbAiJ4uc1Uy0CFgOMR+OaGCfTmM0s5bE+UsXMzJ0aqaw/fPdZ4TQ6EwAT2fP/q04to41MLKCk20AhkTnoU/e2XS2jmJHGKa4QiBw5NQAQvh54e6ZwQcG0CB06WkKenB921NsU4UBzj2sGTpcZrciAE4O8LHu/DQ33pi7ePNgKBuxBv/sWrJCHgDATEnYQzzJL0UQi4KIHjWRXKo/zoESE0bXwEtbPjrbEpocq55pkLVRQa7EOLZsRSXJQ/lVQ00YETpVRVa3C0GRLkTSvmxlNUuD9VVjfTriNFVFPX6qKkZFh6AtNZKzp1XASdzqnUXzYeBwd604JpMZQUF6ieiVPnDZELjBnkQAjYkYAIXnaEL00LASFA9PZHufT4g1Po9pUjlOAFJm9tzaXoCD96+tGZSsPRwQIZtGB3rhpJ330mQ32Y/vyRmRQfE6AEsfAQX/L08qB3tucJUhcncOhUKc1noereW0fRY7851mO0eG5+9p8zWCDv8gn22XWp9Nyb52l3RnGP/HJBCAw3ARG8hpu4tCcEhEA3AprWK31shLqefbmazrK26+G7xymha+NbOfTJ0SLasD6Nblk+glZymJl9x0qU0HXkTBk983ImjeSQQK3sqV6S6xPAnEPQnjAqjJbNjiN4wNenu9akKKHr3R15SoCfMi5cCfb3sKAmgpeelBzbi4AsituLvLQrBISAkcDbrOHS0lusAUMaPSJYvY7iZcj7Pz2GUpMM5wnRBi1XY1ObMsqHgb43a7sKSxtVfvnj+gRee/8SdXZ20uc+lUbwgK9PWKpGjE5oUrF0fTK7ks5erFLCWhgvXUsSAvYm0P2JtXdvpH0hIATcksDxcxVqybCssklpuwABxtRISXEBylbHi4Wrc5er6Ep+ndJy/OKFs5R7vY6Wstbjf/97Fq1bnKjyyx/XJ3Axr1bZ+2E5cd2SpG4Dhu1fQ2MbtbVzoM4bqb29QwlqLKtJEgJ2JyBLjXafAumAEBACIICg1vrlwrKqJrVk9Lu/nqOKaoNBvZ5U9pVqevzZ4zRlbDj94CtTae6UaNq6r0CfRY5dmMDft1yhuVOjCSGH9Kmc4z7iWgzbepVWcgxIVi8kxwVRB69I1rNAJkkI2JuACF72ngFpXwgIAbMEjmWW08RR4fTEQ1No+4EC5S4gjT9Q99wwkP7SnaNp16Eidd2DY87AAF+S+xDALtete6/TrStG
dBt0xtlyGjMylL7/8FTazbaBU3j3Y0ykv3pusPQoSQjYm4AIXvaeAWlfCAgBRQDLQPqloC17rrMbiQBaMS+eHvzMOJWngt1GbGMhDEtJWIq874Y/J3wI/+3fV4SkCxPQng3YdmnpX9uv0jzWesXyc9LY1K4ub96dr56bJbNjacMto3jJsYOwE/Lldy9qxeRVCNiVgAc/xF1PsV27Io0LASHgagRKymroaz98h7wC0/ocGgzk8W5kqpXAUlFUmB818+41Uz9dsPFpZ1sezbdXb400V5+n9158uLcscs9OBB5+7O9U3hhFXt5dLiDMdQWG9Ka7GJEPzwiWEvXJi92PRIb5UmVNSzd7L30e0+OWhiL6xj1Tac2Siaa35FwI2IyAaLxshlIqEgJCYDAE9MbQ+nrwgQpbHXMJ9jyS3IeAOaELozcVunANAryl5wb3JQkBexGQXY32Ii/tCgEhIASEgBAQAm5HQAQvt5tyGbAQEAJCQAgIASFgLwKy1Ggv8tKuEHADAlVsX9Pc0k4BAWJK6gbTPegh2t3kWB7TQc+hVNA3ARG8+mYkOYSAEBgAgX3Hiqmsso4C/Tqpvf7CAGqwbZGwkN4Nt23bmtTWHwIjEsOpPPMqmdjH96eKfueFjAWbMfZEYvR+z/s7eLesPCf9hikF+kVAdjX2C5dkFgJCoC8CJeWNtIdjKS6aEUMJMYF9ZZf7QsCuBOCKBL7h5Hm16zS4VeMieLnVdMtghcDQEYADU+XclLUGCF7sAa+mkoSAkxDYyxpaaMBWzI1n9xTy7DrJtDllN0Xwcsppk04LAccicL24gQ6yk8rlc+IoOsLfsTonvRECVhKAg96PDxfRHA4/NTIhyMpSkk0I9I+ACF794yW5hYAQ0BGA89KPjxSSv583LZ4Zq7sjh0LAeQnA0311bQutWpBA3l6y+d95Z9Ixey6Cl2POi/RKCDg8gdzrdYR4iqvmJ1B4qK/D91c6KAT6Q6CmroV2HCykaRMiafSI7oG4+1OP5BUCpgRE8DIlIudCQAj0SqCtrYN2HCqkSBa25k6N6TWv3BQCzk7geFY5FZU10uoFicbdj84+Jum/fQmI4GVf/tK6EHAqAheu1tCZC5W0hj+EQoJ8nKrv0lkhMFAC9RyUfTsHZ5+QFkYTRoUNtBopJwQUARG85EEQAkKgTwIt7AR1Oy+7JMYG0IyJUX3mlwxCwBUJnMmpJCyxr1mYyHaNXq44RBnTMBAQwWsYIEsTQsCZCWRdqqKc3BpauyiRAv3F57Izz6X0ffAEmprbaRtrv1KTgmnquIjBVyg1uB0BEbzcbsplwELAOgKNTby8wlquUcnBlD5WPmCsoya53IVA9pVqyr5crbRfQQHyhcRd5t0W4xTByxYUpQ4h4GIETp2voLyCeqXl8vOVJRUXm14Zjo0IwOHqjoMFFBcVQLMmyxK8jbC6fDUieLn8FMsAhYD1BOrqW9WOxUmjw2hcqhgRW09OcrozgcvXaukkf1lZza5VQoPFtYo7PwvWjF0EL2soSR4h4AYEMs6WEeLWwXDYx1ucRrrBlMsQbUigrb2DdvLSfFiIL82fJm5WbIjW5aoSwcvlplQGJAT6R6CKPXTvOlxIM9hRZFqyOIrsHz3JLQS6E7hWVE9HTpepmI+R4X7db8qZEGACInjJYyAE3JjAwZOlVMvLi/A+7+UlgYHd+FGQoduQAALG7+KYj36+nrRkVpwNa5aqXIGACF6uMIsyBiHQTwLlVRwM+EgRzZ8aTcnxEgy4n/gkuxCwikBRaSPtO1FCS2fFUiwb4EsSAiAggpc8B0LAjQh0dnbS3mMlhLA/y+fGk6enaLncaPplqHYggP+5PRnF1NFBtGxOnPzP2WEOHK1JEbwcbUakP0JgiAgg3ty+4yW0eGYsxUfLt+8hwizVCgGzBMoqm+iTo8W0gA3vk+ICzeaRi+5BQAQv95hnGaWbEYDBfDjvrkKCvcknvKzozTsVl/CSh4eHaLnc7HGQ4ToQ
gQO89IjYjyvnddlVwo1LIDthFQ20A03UEHZF9owPIVypWgjYg8Dh06X065czCUsc+cX19M6OPJo6PoKWzo4TocseEyJtCgEdgYUzYmlOejS9tytPxX3Erb+8laNsLnXZ5NCFCYjGy4UnV4bmfgTgSfu//vcIlVY2qyXF2elRtHB6rPuBkBELAScgcORMGZ3mwNvb9hdQaJAP/f77c5Xmywm6Ll0cBAHReA0CnhQVAo5G4P2PrymhC/06yg5Rx4wMdbQuSn+EgBC4QWDSqDA6fKpUndXwcuNbH+UKGzcgIIKXG0yyDNE9CMB4d9POPONgW3nn4jvbrxrP5UAICAHHIrCJlxshcGlp677ryjxAO5dX1yQgS42uOa8yKjckcPFqjQr5E8XesqMj/Cki1FeMdd3wOZAhOxeB9vZOqqhuprKqJipnE4Fw/r9NHxvhXIOQ3vaLgAhe/cIlmYWAEBACQkAICAEhMHAC3gMvKiWFwPAT+Ofm4/T2lpPD37C06PQEgoN86cVf3ev045ABCAEh4NwERPBy7vlzu94XldVSY3sE+QWEu93YZcCDI1BanjO4CqS03Qicyb5OzS1tdmtfGrYtgSkTkjiOpfuKH+47cts+R1LbcBJgB6AenrIvZDiRu0Jb4jjWeWfxx8/8mzx9Qpx3ANJzI4HG+krWPN9DCXFhxmvudiCCl7vNuIxXCAgBIeBsBDjYgldAorP1WvprhkBAZ4uZq+51SdQG7jXfMlohIASEgBAQAkLAjgRE8LIjfGlaCAgBISAEhIAQcC8CIni513zLaIWAEBACQkAICAE7EhDBy47wpWkhIASEgBAQAkLAvQiI4OVe8y2jtQGB0vxT1NJU2++a2lqbqPhqRr/L2aLAQPtsi7aHo46a8qvUUFM8HE1JG0JACAiBQREQwWtQ+KSwoxGAcPPJ29+i1uZ6Y9dyjr1Fl89sNp7396C5saqboHVsx6+psvh8f6uh+uoCOvjvp/pVrr2thRpqS/pVxlzmgfbZXF0DuZZ16FU6vvPZgRS1qkx2xt/o6rltVuWVTEJACAgBexIQwcue9KVtmxPo7Gij0vyT1NHeFXi2piKP6qryB9zWmb0b6eLJdwZcfjAFMZa9735vMFU4RNnE0YspdfKnHKIv0gkhIASEgD0JiB8ve9KXtoedwMWT79K1nF0smLVR0pglNGHOBoKW7NTuP1IFa7HCotJo7rrvG/uVm7mViq4eJU8vbyovzKJFt/9c3Su5dpyyj77B131p+vJvUkjECKWZOrNvI9VV5lNkwiSasuhB8vYNNNalPygvzOTyf1PLY6FRqaoOeOO/fOYDunL2Q/L28aPxszfQ2QMvqjwQvsbP/gLFjpihqsnN+oiFyeuUvvB+dX7p1HvUwUJnZPwEs/VqbaPMmf0v0IL1T6lL5zPepICQGBo5fhUVXj5IF078k9pZaE2ddBOlpa/XiqlXa/o8Y8W3KDA0rgfP6rJLSmuI/hXlHqHzGX+n9rZm8vELpujEqRTIfWhqqKSq0ktKMzh62u2UMnGtatdSv64yA2gy/YOiCFrJkPDkbv2VEyEgBISAIxIQjZcjzor0adAEMg+9Qqf3/kX9VhSdM9aHD/dpS7/Bv1+jzIOvUE15rlqiqubXRbf9nBLS5hvz4iAuZRaFx4yh2OQZNGn+feTpafiukp+zm4W2e1hA8ldtIO/+939AQaHxNHP1t6myJIcyeXnNUoIAlDb5Zpq//kklWF08uYmFh2q1HDeLy09ecD8FBEcpgcjXP5QmzrtP9UOrLzgskXKO/YOFxkZ1CcupASyAmKtXK4PX1uY61gieMl6qLr9C9VUFVMtawUMf/kwJdxBGT3z8B9JzQwFzdZv2GREFsORnyhPCaA23BYH34OYnadysz9GYGXcpm7ekMYtV3Wf2vUDRSVMohYW+jO2/otaWBov9qi67rPKMmX4nC6PTqazgjHFMciAEhIAQcGQCIng58uxI3wZMwMc3iLRfT08vYz0T
597Dmpca9UHv6x9MjXVlSljCB/n1C3soPnWeMS8OAoJjyD8wQmlxohImG+/NWPkoC2WzWTO0jusq4HpKqarkgtJwYXnQLyCMNWSZxvymB9DmQFNTUZRNPjf6ASEOQha0TigfFj2KoA3z9g1grVA63+sKmRLF5+gXtEGVxTlK45MwaoHSEpnWa9q2ufPC3MNKcIMgVlt5jfsWSeUF3ftvTZ+h+YPwaYknBEX8YmxxI2dRZ2eH4oA+pUxaS6On3qZ+PTw8qbG2lCz1q+TaCYphgWvE+JU0etodNGLcSnPDkmtCYMgI4EtbEf/fDFe6cOJf1NnRPlzNSTtDSEAEryGEK1Xbj8C4mXezlug/1G947DjVEdh9bX/9IbWc19RQRdRp6F986lxaeMtPqeDKQdr66n1K89RXzyEYIHl6+ag3QyxXInn7BJAXLz/Gp8yliXP/Q10z9+f03ufpyEdPU31NodICoTNe3n60esPz5O3tT9vfeJjyWRC0lBB3cOSE1ZR3fhcvnX5MyWOXqfLm6jWto7Ojw/SSEoagzUPf8TueNVLxad2FUHN1m+tzbzyRH0uq2//6AH302pcJ86QJlHoBWXHtbLfYr7aWRsVaG4jE7tRIuN8rNrp8+PI9Soi3NPrBblIxVx7/D7AftaZ9S/3q7brppp7cLIPZQ29lhutefXWhCIGDgC2C1yDgSVHnIlDHuwprK/NoztrHWTO0xvhGjWW2mORptPTOX/KbSZvS+OhHBiEAOwuhnbGUAkPilIYtIDiaNTC3K/uo8JjRlrJT4ZVDvKz3eRbO7jUKEMpFBQtU0Kahf2XXTynBromDymKJrrPzhqR4o9aRE9aob9ywdRrJ+ZHM1Xsju3qB0AONH944seGgJO+4uo7lVCwbQoBD/5PHLWNtX7S+qNm6zfW5N55oE+OAXdycNd+jsSx49ZYs9Ss0KoX5nFF2YxhLKWvAJLkngWC27Zu65KukfRkyR2Gwm1RMyzfzF7eiq0dY07qCrGnfXJ/6uma6qWcka3evnN3SV7Fhub/t9Qf4fXLgG5aGpZMO3IgY1zvw5EjXBkKAo+ki3XhRhyzM4EJweBLbEE2lLS9tYKPuQKUhwv2CyweUQTo0PVgCi7ihIcM9JNh97f/gh6yB2k23fXUTv8F7GN/k8WaPcy9vX5q15rvK7ujUnueUO4uJ8+5VGh1DLfjL5W50LHXyOjqx87eUdfAV8uMlQywxYtnzk7cfVcuMsG9afPsvuM+Jqu5Nf76F5t70OAtEy43VQfhAf5vZbi0mebq6bq5e3ND6DMN3LM9tffWL5MdtBnH96FdC2gLCUuWWlzeoJUxolJZ95lle6kxR9eKPubrN9dkcz2LeoIB2sIQJ4RZ5oIEs4WXZGcv/Ex1U99EOkoGT5X6hvxFx42nzC3crXkGhCaqUKix/3IoA7BZhV4jNMvgSYrpJIyohvccmFdgSZh95g79MHOZnMoJtKr+sbCgztv2SopLSKe/cduPGGQ8Prx7lm+rL1f8evpzAf5zWPo5hOxockUQFlw6ozS6wJ8WXt6Lco7w7+l/8BadG2SVOmv9F9R50ZOvTyrYyLDpNmR5cydxCUfGTemzqiU+dr+o2nVzDmCvYrvSi+kKF9x18+YEmHCYKU7l9/D+f3f//lN2mD5ssTGKbUXzZNJZlc4WG2mJlqoCl+6rSi4SNQvgyNnoqvkjebGwW9pftrc10bMczytwCXx6zebNMEbP05PfBUem38HvMCmN+OehJwIO/fXb/Gt0zj1wRAg5D4I+v7KZthyqVcGCpU9AOYReiljRNlfaNGG8mEHQ6eSlLM5ZvY0Gng+0ntGUvraz2qpYSWeOFXYqm9evP8e+EJQJf3q2HN1vTpM8LWye8qat8XLcH26JhZ2ILvzH7BYbzPYNCGnYdLc21aonOtD7c6+QfbRy4b7ZeEybwcwbbMS1pbWFJBSx82cYMwpppMlu3mT6b8lRzwGywXIIdmKvv2aiqxps42p656r94
tbVTMcANPSecW+qXYRyBnMPwNqaNA2VMU3P1eXrvxYdNL8u5ExC486GN5BMy3mxPYU/48VuP0B1f30yZB16icyxQQWuMdPKTP9D6h/5JeSyQwUZq3s0/otDIFLp8+n0W/vfTrNXfpbzs7WrJ/ub7/04f8RcS/E/hecQOY/w/zr/5SX5mN3Urj125EG4W3vY/yp5Rax992fbX+9UGEdh/Hv7wf5Q2LpK/JGx7/UGaveY76gsgnvvEUYtoyuKHaMuLn6c5Nz2hBCHYbGLX8ZI7/48Fm9+o97m0KesJ9qX433v3j5+iWx/+l9EuEmNUY+Yd1rNW/bcSArMOv0ZJ7L5lNG88ObzlpzzG7/Cu65FKO44vV9dgnsC/N33x1a6yvKHH2ztA9XfVF55TG4ZiR8xUAhQ0yrBn1RKEsl1vfkPVGxk/kb+QfqIEzzlrHlOmE0e3/R+t/Nyf+IuRwcRDK6e9tjfk0h9/cgclxIVpl9zutevTye2GLgN2VQJ6oQtjNP0whuG64XrX42/J7YPKyH9g+K4l0/r15xBWYPRuKenzwh7MmFgAQ4IABa2QPkEgwxKhuYR7puKR2Xp1gijq8fELMled0q5Be2cpma3bTJ9Neao54I7iW3bW4b8qmxwvtmVrZ9s4uOjQ7mvt6jnhGvpkrl9d4zCloNUkr+5GQNukgXGf2v0nauFlQf0mFVyHxhXmAcV5GUrr1FBTRNBiIWkbZ6CRPc1aHzx3puXx5c3fZCleFeY/ePZhzoD3AggscJyMurH7VnORMm7mZ5UwB8HLXDK3qQf/e9gwVM9aNWyg0SfUC/cvjXXl/P/1Gs391A/VexY0/BCcoA2E65bywrOqGDTVWlJleYc10sVT76gdwtggA4EqPHZMjw1HWP7H+w5WBkIikqng4j5lNhGVOJnwezVrm+JqSfDS2nXn165PHnemIGMXAkJgWAjAHgaaBewChaavNyF1WDokjbgcAXObNEwHCe0RlsXUZpIANhNY9W21ZI182hc1PJ+WdhHiCxK+NJhLKK9pi71u1AENMIQmLXmxnz5o2FViAU3Tymv3zb1Cm97WZr5NbcyeXoYvcMYxcD9RN3Yo7930mBLONHMHrQ2tLM7BA5pnaPwusVYQ0SawLDuPBTlLqbW159gscbNUh7tdF+N6d5txGa8QsDMBfCgFhsSK0GXneXCn5iFE6TepQGvTycvvo+C+hDeTxKXOMWsaoDEyLQ+nw/0J5RUeO5bgdBmaLwgl17J3Ku0Q6vfzD1M7I7GcDsfIWkKb+k09zQ0VqiycDfc3QbMHuzHYm0GLpU/FecdYi1ak2oKT6DDeFITdmmPZz97std/jfvfcuOLp6aO+PEEYDI8ew8u1O1le61R1IL/e9Y6+LTk2EBCNlzwJQkAICAEh4MQEujatsKqJx9G17Kxt0ohgwQdLhtomFRjTH/jgR7R5411q2QwOiZd/9ndKU6Vpi/SaK9PysG26cPyfNzRVuvbVtpCu9rlyqNAocfQiun5xr9rYA4EKgt8s7gMSnAkf+egXaonQEJnCUN50Uw+iOmCJEcuQ3VK3MRvKaho3ffsXTryjNqNgg0335KHs0mBXiQgW6MPOv3+dNylUKHtPOFQ2TUk8nn3vPUHJY5ZS+qIH6MD7P6T3nrtV8RjDxvmxI2eaFpFzHQExrtfBkEPHJ2CNcb0tRwGj+vKCs92MS21Z/0DrstQvS9cttQMv9tgZaWlTgaVyjnBd+2ZtzvbLXP/EuN4cFee41ptxPUagbcbQNnHABkl/HcfQNJluUsHuYSytabaCWj2qMP/Rn+vL4/oHGz9NC275iRJU9Pn0x2rJjQUjTZhDmCxcM7WBhLYLfYctqb68flMPdlxieXTmym9p3VOvfY0ZwhcEMWwUwO5G/K/jGMulMMyHgDWddxZj6dNHF+IMLjOwAQe+98wluJLBfW1jDza6YAlVOzdXBtfEuJ5teS3BketCwBUImDohtGZMMEZVb5icGYaxB//9lDXF
bJoni8MNwb7CUrLUL0vXLdVzbMev1TKHpfuOfH3fu4/x0kaxI3dR+jZMBLTNGEpLdUPoQtPadRyb26QCQUMTukzzm57ry6NeuFmAFss0n2mbmtCFfBBiTIUuw3Vf4wYefXkIYsiPZTzswhwz7U5k75b6GrOm/YJApH3BMhWO0C+90IUGsLPaktCF+6hLXw846s+RR5J5ArLUaJ6LXHURAnBCCNU6fOZYm+AccNXnn+vmw8rasrbKl8jbwbGrSpIQEAKOSQBLbErbNAzdg/B068PvKOHRls2NYyfO0PhJGl4CIngNL29pbRgJ5GYaQmzgGySMRuG2oCTvhFknhlq3TJ0DYhs23lzPHXmdHSLuZ4eIE5WBKuw04HMHcRXbWUBK5cDO2M6tJThSzGAHg4vYzw++NdaUX1VOGMfP/pxy1or72KIOD+5wFYFlhJDIEcqfELxwwwgXqvzI+Anc33fZGeIutQSB/mg2F+gXnDUWcqgjbO2etuzrWvPqFcsJ5pxEdsvEJ0Xs3DTr8Kv8jTuA7TUe4rrGKiNZOFBEcOvIhEk0ZdGD6ps3jH3NXTfneBJxG7WEANvwT7Rg/VPq0vmMNwkGyrApgb+kK2c/5Pb9aMaKbynnk+b6jWUatI25jBs5mzrYD5skIWBPAnpt1lD3Q1s+tWU7plouW9YtdVkmIEuNltnIHScnEJcySxmxxibPYI3XfVRfVUD73/+B8tw+ffk3lC+frEOvdRvlGHY6CEENu51GsFCAhK3gCBw9ZTHiJ+5mIegT5Tzx0Ic/Ux6nIQid+PgPyiu0VhmECgTRLWDhDAmerRFPEAJIGvvMmb/+SeXs8OLJTep+GduRwfh18vwvcZDpBCXw1HDAaiR44p629BtK4IOghXqR0C+4ZYB9RlnBGTrH/rH0KSfjH8pp4mz2qA+nkfvZABZLFqYJAmT6wgdZMxhPhzY/pfKAE3z5zGTHipUlOZTJS59Ilq6j/9mH32Ch8B61ZHJ671+6NQPv4rAn0xKCcWM+4A8JS6qzuJ3JC+5XjCz1O/vo31SIo6lLvqKWGLXlYK1OeRUCQkAIOAMBEbycYZakjwMiYOqEEE4TNSeG2O4MJ4aGUDZd1Zs6B8QdzSEidvtoDhELcw9TAO8wggABoQxOT+ErR0v4JgxtzrXzO9UlhO9ImbBWOVDEzqSKomzy8Q9mwanLkeGMFY+o3U8IE6RPE+feo+IrQmjz1ZVBv+CVGmFBRk25VcUu1JfrzUmkPh+0bnBsOpmXY+uqr3O4kAtUVXJBjRsx6uBwtrwwUwl55q5rdWmOJ+Gxu46FKmsSbFgQRQCaQ7QDLZmlfiOu5Bje4o7wSPD07QmfQ5KEgBAQAk5GQAQvJ5sw6e7ACfTqxLCXapXxqtqyDQeDBqeKygEjG6sqB4wsAIznLeHxafO61TJy4moVCw6776DxiU+dy6E4nuet40+r0BrYvcQGFsYy3bzC37gKO6/trz+kluOaeJeRLrvaKaUZzmI5E6GD9EnvJBLLmXonkfp82nKJ1w1P+m0tBieN6A/GF58yl4N5/weHLDF/XatLqwfLsOa0UZ0dPYOMo9+rNzzP4Ur8afsb0CjuUaFRNOeW+n5jPNrSCMZt2DivtS6vQkAICAHnICA2Xs4xT9LLARLQOyGEE8NLp99T9lP4QNc7MdRXrzkHDIkcqb/c7RiasUu8TJg8dpna/YPdk6Y7gJAHgbkRu2zEuBVqCbPwyiEVXgNaoYO8rNdXquNdlbWVeYT4aRA8co79w1gE27ex7Ik+XL+4hx0ZjjbewwHa15xEQlCBk0TwME1YBoUWKZ/rAheE+oCXbQQAhk0ZBEQVf5I1U+aum9Zn7hz1tjTVqPAliJEJ7VUIe7GHHRv8HEFbBpu0suunLPY7hJdLYWeHIN952TvYtq7FXFNyTQgIASHg0ARE8HLo6ZHODZaA3gnhrV95R8Ud2/LSBiWAQDDRnBjq29E7B5ww917lEtF4/4ZDxIS0BYSAs1te
3qA8sMM/zrLPPNtjJ2TqpLVKy6XFaEudvI5O7PwtZbGtlh/HdMQyG5LS4KBuLSkNm4cS3BBvDX328QvUCXceSuDLPPgSHd36NAWGxatlRwhjmi7IkpNIrQm8ol0YzL/7p5t5K7gXx5h7TNlozWK7MGw0OLXnOdbW1dPEeffy0uzdZOm6vv96DaHWFnaWQmDaykGI/XjMQWo51UMttX7y9qNqmRE+lRbf/gu1Td2cc8uJPBd73/0evffnW5V3bf9AxLQUvZfGWF6FgBBwDgLiQNU55kl6eYPAQByo6p0QohpLTgz1kPXOAfUODdUSGgsr2rIaHB9iCdOX7ZMgfJhLWC7Ua5qgufLgoNjqGu9MxG4lfRuow9QpIozQIaRBW6T5yoGGCP1AX7XA3yhrWpepk0jk0ZKWF9oo2Ixpj7JZdQAAQABJREFUdeM+DPGVpouD6+r7b+66Vo9pvdq59gohDk4XtYT+YxwtjTVKkNS44r65fqNtbcnRtE2tTkuv4kDVEhnHv37ngxvJIyDF8TsqPeyTQEfTdXruf+6ihLiwPvO6agbReLnqzMq4jARgwK1PpkuC+nvaseZoEOd6h4amW7rhNb0vz+l6oQX1dbPlYgEMSd8GzpUAopPjNMHKw6PrX1YTkrR7KIdkWpdmF2W42/2vllfTvOnvQpA0F8Ta3HWtHq286bl2Xe+sUruGcWBzgmky12+0rV231IZpPXLu/ARmTxtJRaU1zj+QXkbQ3uFJzW1+1NbhRd5ebRTg3cTvA70UcNJb3l5h5O/f0+TBSYczoG53vYsPqLgUEgJCQAgIASEwtAR+8J/rhrYBO9aee72OzuRUUliILy2YHkM7DxXS4pmxtCejmPx8vdQxXiW5DgERvFxnLmUkQkAICAEh4AQEsGR+KruS8grrKSUpiG5d0eVsGN0PDvShm5cmU11DqxLEfH0MApi/nwhgTjC9fXZRBK8+EUkGISAEhIAQEAKDJ9DS2kGHTpVSVU0LzZgUSdMn9lxi17eiCWD1jW308ZEi8vH2VBowEcD0lJzvWAQv55sz6bGLEIDRfzl7fIdTVmdOjXXlbAhfR8FhST3sy5x5XNJ3IWArAtW1LXSQBS4EjsByYjgvK/YnBQV406eWJFGXAObBAlgciQDWH4qOk1cEL8eZC+mJixLAzkfsDgwMie02wnr20XXw30/RHV/f3O36QE/qqwtVG6YbAAZanzXlMg+8RFcyEWfRn4Wvelpy5y+VHy5rykoeIeDqBPKL6unEuQoKCfKhlXPjyXeQtloigLnGE6NzHOQaA5JRCAFHI4CwO/A/NdRp2+sPsLPV/KFuplv9iGm5/sG3aN2X/qoCV2cf/Xu3+3IiBNyRAIzl3991jUoqmuiW5cm03AZCl56jJoAtnB5Ln/AS5I6DBdTULEHj9Ywc+Vg0Xo48O9I3pyIAf1pn9/8/FSzbxz+EJs27j72/x9DZAy+qgNgQvsbP/gLHi5xhdlwIVo2Yhe3s9yt10k2Ulr5exUhEcOiGmmJ2zprKAbG/qbzLXz7zAV05C02TH81Y8S3KOf4Wtbc207Edz6ily0kcd1FLiAd5dv+LHFfyMi8HJtLkhfermIhXsz5SAbirSi+xR/kCGj3tdhVLEuXgEf/SqU3Kb9j4WZ9X4ZCyj7yhQiD5B0VwQOsvK80WvNtrCXEWK4qztVN5FQJuRaCtje23TpdRRVUzTRkXTret7G4wPxQwAnkJch0vQTawDRgEMC8vD1oyS5Ygh4K1LesUjZctaUpdbk2gqb5CeZqfd/OPKIa9zR/f+RsVdgfBsuEnayILYvCWby7VVuTRoQ9/pgSzCXM20ImP/6AEOATGTpt8M81f/6QSvi5ymCI4Uz2+81n2VP9tFoDuZwesnjRm+p3KvgoaqBHcnpawe2rfe0+oPAioDYFw36YnWKDq4EDW1+nMvhcoOmkKpbCgB0/1cFoKQezI1p+rMEcT597H5xcpJ+MfVMSBwWezR/tQDt2z//0f
KgeraAee76HpgtAI7/SShIA7Eairb6VtBwrU76RRYUrgSksOGVYEmgC2aIZBA7ad+9PYhFiwkhyRgGi8HHFWpE9OSSA0KoXD+gSzluqs6j80TXCuCk0VvLVHJ6ZbHFchCzUBQVGslbqi8sChaHlBJo2d+RkWwLLVr49/sAqxA3sqCHIQdCawBg2aJiTYdkXEjuPzZHWOPw21xVRVcoGW3/075Xh06pKv0iYOD1RfXaTypHBIo9EsrCGd2v0naqwtpeK8DBYQxyoNGK7Hp86hXW9+g+3H4tQ9OKBt4LiPTfXlSrAszT9FuVkf0uI7/o8i4yegiCQh4PIECksb6FhmOQX6e9NSB9EyaQJYAwtdu9kPmKcna8DYJ1gA91GS4xCQ2XCcuZCeODkBCEp7Nz2mlgi1eInWDglhcODB3cvLsNtp/KzPUezImSrOY8Hl/TRywmoVCogD+ah4jas3PE/nM96k7W88THPXfZ8DZS812xTCGSE8kea9X/Oyj9BDSIjPqCV42Md1hPUJ4tiK+qT6By/93D+vAF+aterbxjBFqBPCnwhdemJy7KoEzl2uogu5tRQfE6B8bUG4cbQEYXDdYl6CFAHM0aZG9UcEL4ecFumUMxKApigqfhJNW/o1ysveYRwCBJqm+kolOEErZS6mI5YgL/EyYvLYZSpmIXZBQrNUeOUQLz9+nm2+1tHBzU+pOmFLhlgiM1Y+quIcll0/pQQvT08f1oiVUkjkSGMbQWzTBcHoWs7HhCXPvHM7OAxQJAtWCcb+mR6ERaXRlTOb1U5MCFpFV6EBG0Od7W2EpUz0v541XloopLDo0ap/pvXIuRBwFQLt7Z105EwZG8s30qTRw2O/ZQt25gQweMXHdUn2IyD07cdeWnYxAomjF/Hy3zu0+YW7KVCnMYqIHauEn01/voXm3vQ4JY9bfmPk0IsZvi0npC2ghFELaMvLG1R8xLaWRlr2mWcpdfI6OrHzt5R18BUWyCLUEiOWMD95+1GlcYJN1uLbf6HqS+L2Yc+VPGYpwc4MCZou2IId3/EbOsl2YxD8oCFTcQ5VILiub+uGvnhQ0tgllHd+J/37hc+q/GNn3q2M6Q988CPavPEudQ1G+ss/+zvVRhELhxfZED9pzBJ1Ln+EgKsQgNH6gZMl1NTSQfOmRCsfXM44Nk0Ag90XQhHhy9PiWSKA2WsuPdj4ll26SRICzkHgj6/spm2HKs0Gb3aEEXR0tBGEJgTZxrEWyLqzo51ammvVjkR9PztYi6QP9gyfX1ge9A0IM2qtsMyH5UKlYWKjeAhPqLulsUZpx1RA7RuVQhsGezKtXa0t/Ju3NHF+rldLMLDntUVVH66Z9gX94IZZeAvQiijje5TRB7tG3Z0YK2v2HDk1V5+n91582JG7KH1zEAJwA3GUNVx+vp60kA3Wh1NDtHXfdbVMOJQoRAAbSrp91y0ar74ZSQ4hYDUBCDwQupD0wg+EJb+A8B716IUu3MSyoGaHpWXWCz7EAhgS6oYBvmnS2ja9jm+4eqEL95XA1qXw6iYA4r63byBeuiUfM9dQt4eDC13dBiEnQsACgZzcGsq+Uk0xEf5K+IF7BldMMLa/iW3AIIDtPVaihqiWINk9haShJyCUh56xtCAEhIAQEAIOSqCjo5MyeHdiUWkjjU0NodtMAlY7aLdt0i0IYGsXJSrnq1iChPYafsCwO1LS0BEQukPHVmoWAkJACAgBByUAT++w36pvaKM56dE0l2243DUh5qMIYMM3+yJ4DR9raUkICAEhIATsTKCcPcsfPl1K3l5sv8UBq4M5jqIkA4FuAtixYrYlNWjAEKJIku0ICE3bsZSahIAQEAJCwEEJXM7nkF4XqigyzI/WLkwkb28J3GJpqpQAxoygFdwjApglTAO+LoLXgNFJQSEgBISAEHBkArBZOnGugvKLGigtOdit7LdsMS/mBDAY4QcHipZwMHxF8BoMPSk77ASwg66zpZw6POuHvW1p0LkJtLS0OPcA
pPdWE2hpYfutU6VUU9dKMydF8m+U1WUlY08CegFsL2vA4FAWfsBEAOvJyporInhZQ0nyOAyBu9dPp2XzxzhMf6QjzkPAk4V2Sa5NoKqmhQ6xwIW0cEYMhQYbQnC59qiHb3QQwNbcWIIUAWzg3EXwGjg7KWkHArHRoYRfSV0E6tm7dsbZMpo3NYbwxihJCLgbgbyCOjp1vpLCQnxp5fwE8vUR+62hfAY0AayZNYsQwNraWAOGJUjZqGAVdhG8rMIkmYSAYxI4ca6crrH9yvLZcQSP12NGhlD62AjH7Kz0SgjYmMCp8xV09Xo9jUwIoluWJxujPdi4mWGtzpn0sn6+XrR6QSJpAlgrC2BLRADr83mRkEF9IpIMQsDxCNTWt9LHh4s4YG8YjUnp0gCeu1xFF67W0mr+1i9OEB1v3qRHgyfQ2tahlhMrq1to2oQISkkMHnylDlIDbKd2HS5Uy3kO0qV+dcMggJUQ5kgEMMvoRPCyzEbuCAGHJJCRWUbFZU10E3ucNrclHobFH+0voKS4QDEqdsgZlE4NhAAM5Q+yw9N29i21gP1vRYT6DaQahy4DgQUe5FfxFydnTngP2sOhiDAeLEGGyBJkt+kUwasbDjkRAo5LoLKmmXYdKqLZ6VFWfcu/mFej/BatnJfARsay/dtxZ1Z61huBgpIGOsYhfbCDDgbzWN5y1QSBZf+JUloxL94lhqgJYC2t7SoUkQhghmkVwcslHm8ZhKsTwE6t6toWtQTh6Wm9FUhbewdtY+1XNAf9deeQKK7+fLji+DIvVtGlvFpKjA1Qmtv+PPfOygMOS+FVf9kc1xC8tHmAALb3eImyBVs8M87tvwiK4KU9GfIqBByQQFllE31ytFiFNkmMDRxwD7Hr6+jZcvVNGp67JQkBRySALwpHTpdRWWUzTR4bTqNHhDhiN4esTw28Q/lYVrnSDg1ZI3asWAQwA3wRvOz4EErTQsASAXjc3sffEFtaO2jF3Hiyxbd9xF3bcaiQgjnu2sIZsZaalutCYNgJwCXKgROG533e1GiloR32TjhAg3W8aQZuMRaxXZQrJ7yv7WM3FE2sCXNHDZgIXq78dMvYnJJAYWmDsvNYyi4iYiP9bT4G2MwcOFlKQ1W/zTssFbosgeLyRqWJDWD/cwhYHeDv3h6OaupaKPNitdo84LKTrhtYdwEs1m0c3orgpXsI5FAI2JMANFK7eVnRk30/DrWNBzRqHx8pIi+2F4MAhlBMkoTAcBHIvlJNOVdqKC7an+akR9tEoztcfR/KduB5H2zmT4sZymYcrm4lgB1nDRjbuGEXpKtHHBDBy+EeQemQOxLIK6yno2fKht0Gq7SiiXbz9nVoGwZjQ+aOcyZj7h8BfLE4ws94SXkTjU8L5d+w/lXgBrkrqprp4rVat90IoxfAFrE5BCIRuGISwcsVZ1XG5DQEYEwMR6hBdra7QtgPvOnZyp7MaSZAOjrkBBqb2H6Ll7bxOmdKNMVFBQx5m87aADbT5PJGmNmTo511CDbpN/x/4T2pscmgAXM1AUwEL5s8JlKJEOg/gUv8zfbEuQpauxB+tuz/za6/fsL6P2Ip4U4EIEQc5h2KiJuIzRz4ciGpdwIlrIHOL6oXx8c3MEEA28eOWBtYaMeGg3AX0YCJ4NX7/4HcFQI2JwDN0i7eXRgb5e+Qb7DwGQYv4SvZiaM5z/g2ByIVuhQBOO7NYgPx6Ag/mss7FL29JGC1tRNcVNZI+J0+IdLaIm6Rz9UEMBG83OKxlUE6CgEYzmZdquJwP0kOrQGA4LX9QAFNGx/RLRako3CUfjgWAdhvHWf/UwUljTSaA7VPHhPuWB108N5A+/3yuxepjTU8bRyvMZRD7Ny6IplmTIxy8J4Pb/cggO1nNztwP2JOA4b7Pt6OL+iL4DW8z4205qYEsFtnJ2u5RiQE0dRxEU5D4cS5ciosbVRBt31dOFSL00yInTt6Ob+W4yWW0j23jFI9wXON87qGVpo1OUo2aAxwfrDL
+LvPHCNsskEKD/GhP/xgnkuHRxogKlUMAir8HNaxALaYl7HDQ30JPtCefuEM/fhr08if3ZM4chLBy5FnR/rmEgTO5FQSvtGuW5zk8G8I5oDDm/Y21n6NSw2lSaNFk2GOkTtcq+KQVU88e5zqWch6+lsz6TQ/13BHAvsticE3+CcADmR/+9dzqqIv3TGabl6aPPhKXbwGJYAxt7qGNipi/4fbDhQq9zjf3DDBoUfu+Do5h8YnnRMClglAHf7+x9eU6vuOVSOdUujC6ALZKBr9x3LS5k+uqd1plkctd1yRAD7gnnk5k8rZ3UFTSwe99dFVFTf0Jv4yIUKXbWYcvrsSYgIoMsxXsbVNra5dC2xQl3NcS2i9EJUDaQ+7x9l12HDsqKOXbSaOOjPSL6cmYLB3aVBaLuzqcoWUPpbtvUaGqqDbKUlBbP8lBsCuMK/WjOGFf16gnNwaY9ac3Gp2ums8lQMbEEBYsDtXj6RW3nzjDHZKNhiyzarYuu86fzHsqu6ldy6q96qRbNrhiEmWGh1xVqRPTksAIT92sV+uKWzH5coBfvHBe+5yNa2an0DBgT5OO1/S8b4JNHM8PSwrNrM9F7RdOMfv4llDE9Kq7x65bo52NqzHj+wE7d8cZ/N7EewNSfdlICrcj0bEi+DVP5KSWwg4GQF45YbvorWLEt3ijRM7iLbtL6D46ABlWO1k0yXdFQJCQAjYhYBovOyCXRp1JQIV1c3K+zy8cjuqansoeV/hnW5wBLuStV+u4uBwKHlJ3UJACLg3ARG83Hv+ZfSDJICdSNjSvJqFDthouGvCEsn2gwVK8HK3AL/uOucybiEgBAZGQASvgXGTUm5OAKE9sHtm0QzsRAp0cxpdw0e4k0McJgYxH2FjIUkICAEhIAS6ExDBqzsPORMCvRKAo0MIXO3sWgHChYds7erBC24nsMHAz9eTFnN8NWHUA5FcEAJCwI0JiODlxpMvQ+8fgYKSBjrAXrqXzY6jmEj//hV2w9yIOQfv0hC+YIAvaXAEDh6/TC28m1CSgcD0yckUFjLw56q1tZ0OHLssOHUEFs0ZPaCNQZeullJ+YZWuJjm0RGDOtBQSwcsSHbkuBG4QgAbn4yNFBH9cS3gLvSTrCWgaQvjYWTYnzq3t4KynZj7nbff/hbx8JXIA6LQ0VdP/Pn4LTZmQZB6WFVerahroC998hfwDxR8dcDXWl9E7Gx9mh8m+VtDrnuWZ53fQJ4fzycdXvpB2J9P9rKWxiv70P3eTOFDtzkXOhEA3AlcL6ijjbDnv2IuniFCxWeoGx4oTLDMuY8/S8Hj+zo48mj81mpId1LeOFcOxaxZPT0/yCUywax8cpXEvT9to/gL8/clbmBqmtb1uwNPLFhjk5RdJPgFhA67DHQp6ebSqYbqGS213mDEZ47ASQIgU+KhCgOi71qaI0DVI+jC0/wxzRBDg7Rz3EbsgJQkBISAE3JGAaLzccdZlzL0SuHi1hk6dr1Tx0kKDxSt7r7D6eRMBlRFs+d2deTRrUiSlJYf0swbJLgSEgBBwbgKi8XLu+ZPe25AADJe37Mmn2oZWpeUSocuGcHVVwckqtF9Yfty69zrBA74kISAEhIC7EBCNl7vMtIyzVwLnLldR9uUauonD/QQGyL9Fr7BsdHN2ejTVsZD7wSfXKH1MOI1LFfsQG6GVaoSAEHBgAqLxcuDJka4NPYHGpjbazB/8sDm6c/VIEbqGHnm3FhBg+9OrU6iRA9xu3p1vCHTbLYecCAEhIARci4B8tXet+ZTR9IPA6fMVlHu9XgW19vfz6kdJyWprAtPGR9K4lFD6aP91Gj0ihNLHRti6CalPCAgBIeAQBETwcohpcL9OFJfW0CM/fpva2EfWcCf4lmpmey5vL0/y9vakTVv86KVn/mO4uyHtmRAI8Pem21eOpKxLVfT+rmu0ekEC/fK5j+jM+UKTnMN/igAFj31tNc2Zljr8jUuLQkAIuBQBEbxcajqdZzCtbe3U1ulNnoEp
dul0YFBXsxWVOV0ncmR3ApNGh9MY1nptY7cT5y+XEfmnsONV++4ubWssoobGFruzkQ4IASHg/ARE8HL+OXTqEUgcP6eeviHrvK+vF92yfAT9/R1WNZGHxHscMtJSsRAQAsNNQIzrh5u4tCcEhIDVBBCmSZJjE6gpz6Wi3MM26eSFE/+izg7beKW3SYccoJKO9lYCF2tTweUDVFt5zdrsLpvPls+lNZD6w13e1awhKnmEgBAQAg5OoLL4PH348j3U2WnZL1p7Wws11JYMeCTmyp/e+zzVVOQNqM5mjl3X0lRrLJubtZWKrh41njvaQdahV+n4zmd77ZbpmHrNbOZmfXVhN+Ez7/wuusa/SNa0D4Ej+8gbZmq2zyVrnkv0zHTc/eltX8/ljjce5i8HR/pTZZ95TdvsD3cRvPrEKxmEgBAQAo5PIDg8maYu+Sovy1p+Wy/NP0l73/3egAdjWr65oYoFpSM0YtyKAdV5Zu9GunjyHWPZkeNX0pWzW4znjnaQOHoxpU7+VK/dMh1Tr5nN3Nz2+gOssco33rnKwugI5oJkTfuYi/wLu6m1pcFYhz0PrHku0T/Tcfenz309l5Pmf4kiYsf1p8o+85q22R/uYuPVJ17JIASEgBBwfAKtzXV09dw2ShqzhK5mfURNDZVUVXqJNQkFNHra7RSVkE5nD7xIDTXFSvgaP/sLFJ00RWlHCq8cJv+gCJq84MsUHjOGMrb9kqKS0inv3Hby9PKl6cu/yQKdV4/yTfXlFBY9igKCoxWgotyjLEj9i5obayh2xHSaNP+L5OXtR0e2Pk1oLyw6jSqKsulK5haKip+ktFueXt5UXphFi27/OcWnzqfMg684LOzqsktKQxcZP8Eso/KCzB5jKmYmF078k30FtlLqpJsoLX29mgOMMzgiiQouHSDUN23p1+j4rt9Se2szHdvxDMWlzGZmn6fS/NM0a/V3FJO+2g+JGEFBYQkUGBJLZfmnKGHUAruz1D+XePasGfeEufeYfS6tfa5Nn8uCS/soMDSW2tuazbbv6eVDjXVldHb/i1RdfpmCwxJp8sL7CTzNPbvjZ32+x/9C7IgZVnO3/NXI7tMlHXBnAt/cMIHefGYpzZ4cZcQwc2Kkuvbw3WON1+RACExIC6VXn15EG59aQEmxgUYgc9Kj6G+/XEIzOSakO6TWlnoquXZCDbWu6jqd2feCEqxS+MM+Y/uvyNsviEaOX0W+/qE0cd59SsDKyfiHss+avea7FBqZQvvf/yEvVXZSWcFZyj78Bk2Ycw95+/jT6b1/UcKVafnqsssUFBKn2qzl5cb97/9AaWemL/8GweYl69Br6l7Z9dMssNSo42YWCMu5/riUWaoPsckzWEC7j3euerPQEK8+HPHB6YipjjVRNeVXDGMyw8h0TPVVBXTow58poXPCnA104uM/sOB5jrVR9ZSb9aGyxZqy+CGlobqW8wmNmX4nC7reNGrqbcxxFbd1leejnT/QDYz7al9jFhgaT3UscDtC0j+X1o7b0nNp7XOtfy7BAM8ztLOW2sczv++9J8jD01MJuT7+IbRv0xNq2d7cs4svGqb/C2jHWu4ieIGWJIcjcOBkCb8Re6iYiVrn7uL4frh28GSpdklehQDB+z18gIWH+tIXbx9tJOLDPtrgpw2v7phSJq2l0fwBjl8sP7bwB09oVCp5+wZQdGI6C2AhSjjCh3pxXobSTDXUFJEm9MxY+ajSuqROWkd1LEB4efv2KN/cWE3+N7RdELSg5UqZuJa1a5Np3MzPUnEv9loBwTHkHxjBH1ZxKj/myNsngHx8g6ieNSPOkEwZmY6pkDcdBARFsRblihKy/IMiWejMVEPz9g2kOWsfZ2YzFGdoJqFt9PD0UstiIRHJSlj1CwhXwpg5Hqbta3kgGGAuHTFZM248S5aeS2uea/1zacrAXPsNtcVUVXKBpi37BrMfq5bs66ryWVtsnqG5/wW0Yy13WWo0nRU5dwgCx7Mq6NK1WuXFfNr4
CGpnR6tj2bN59pVqOnOBP0CCfWjRjFiKi/KnkoomOnCilKpqDX6WQoK8acXceIoK96fK6mbadaSIaupaHWJc0omhJTCdtaJTx0XQ6ZxKsw0FB3rTgmkxlBQXqJ6JU+cr1XNmNrOTX/TkD3AtYSkFmhPT1NbaSJ4sUHnxcqJXgC/NWvVt8gswxMzUbMVUWQs7DaGlam9tUtW2sU0RhCYtefn4UYdWjj3Q9mb0r5WB5qGtzVCfds2RX/tipPgyI/BFGj/rcxQ7ciZhpyLKau50vDA/GivdgCGEoQ5LyVL7bTwn+rmwVN4e160Zd2/PpTXPtf65NB2jufbx7GIpHdpdJAhWSOp/xspnF/mt5S6CF2hJckgCb3+US48/OIW9mY9Qghc6+dbWXIqO8KOnH52pNBwdLJBBC3bnqpH03Wcy1Ifpzx+ZSfExAUoQCw/x5W+LHvTO9jyHHKN0ynYEDp0qpfksVN176yh67DfHelSM5+Zn/zmDBXI/473Prkul5948T7sznEPDYuz4AA8gRDXVV/IHf5vSrEDD0snHWNqCEFDPWhLksZRMyweExFBJ3nGVPZw1BZdOv6c0ZtDSXMveSVGJk9U9P/8wwu42aMJy2f5MS6gPuywhlOEDsbmhQgkggVyvsyb9mMD30slNlDx2GfkFhrPtW5XSLGIHn6UEZ8GNdaUUEsmxY5kD7JKw8xMaSmtTIzONjBtvbXaHyKcfty2fS2sGF8Q2XRC2ruV8rJYQ887tYG1sJAWFJlBvz67+fwn/P9Zyd08dvDUzIXnsTkDTeiFuH2L5ZV+uprOs7UJQZSwrbXwrh+59bK8Kch3GAtbK+QkslPkroevImTJ6+MmD9J1fZcjSpN1ncng6gDnHM5KaFEzLZhtsYvQt37UmRQld7+7Ioy98Zw89vfG0un0PC2qukdjRLP+oxB8CcDyrJcN1D7WMgg+YTX++ha7zzjcY09dUXKXNG++izS/cTUe3/uJGcTitNXw86DUEWIbRl4+Mn0iwp4HglDh6ERvHz6MtL22g9/5yu9pVh/qRxrGm5+yBl+j95+9gLRi0z4a+JaTNZ0FsK7333G18vU1tBvDnpTks2Tlk0nHFB605RvoxxaXMUQbuW17eoFx9bHvtfl4ChDsPzEjX/LAUzJcM50nMEfZGRz78H7Z5S1I2eVWlFw04rGgfHOHaAHPjGEk/Vv0xMJgft6Xn0sCoi5uBYc/nWv9cgkHXXJlvH5quWau/Tcd3/Ibe/8sddGb/CzR33ffVEq+lZ9f0f6E/3D1YtTv8wfIc42mQXtiRQH5hJT36k/fJKzC1117AoP7xh6aoPD997pQSvP7vv2dSWnII7TiIb42dFB8doIIq7z5aRC/88wK98JMF5Meez/cdL6Ete/Lpcn5dr220VJ+nTS8+3GseuWkfAl/+9utU0xbPSzWWtTDYgPG9B9Lp96+fo6KyRvr5ozOoorqF/vHhFfr6FybQr1/JpMOny+iZ786m2Eh/uv+H+6mt3fC298OvTlVLkw/9+ABV97Ic3dpQSN/60kxaNn+cfUBwq3c8uJF8Q3vXYkCTBeNstazHb+1YqkLSruMYS1otzbW8pBiOU5WU6wHO78MG+Ej6/Kbn+vLI98HGT9OCW36ibJWQFxoa5IEtjT7B7xH6hQ85ff1YnuECKj92U2Lpc+bKb+mL9jjuaMqnJx9ZTlMmJPW4Z+2FqpoGuv87b5J3cJddYF9l9Vz1Y0A5/bl+TLiHsWM5y5eXcSEEIOnzq2VGnSAHDRds8bBkppjw8z9z1X91m1d9eX192Fl69KNf0C0P/5PbMgjPqsE+/rTWXqA3fn8fBfKSc3/Tr/6yg/adajAuU5uW1/dVf9zbuFGH6XOp54/7pnVpzzWu659LfT79sWn7EIewCURbbkcbSJaeXZTX2rSGe0djHj3z/ZvJ+lkxtC9/hcCwEjh+rkItGZZVNimhC43DmBopKS5A2ep48VLiuctVdIUFrJbWDvrFC2cp93od
LWWtx//+9yxatzhR5Zc/rk/gYl6tsvfDcuK6Jd0/lGH719DYZhS6QKO9vYM/zDr51zXYQOhCUlqqbjZeXVYlEMb0Qhfy+7CQpAldONfqwTGS/lxfHtdHT72drl/ca8jIf+E+wlTowk1oyjQbGn19uIb8mIeCy/tpzLQ7jXU52oGeq34M6Kf+XBuT1n+MHUuNmtBlmh9M9UISlhUhdCGNmfFpdjmxXx1b0/71i3to9PQ7utWnCtvxj56N/ri3caO7ps+lfvy4b1qX9lzjuv651OfTH5u2j/kxFbrQjqVnV/+/0B/uXf+NqF2SEHBAAm1tHdTKv1oqq2pSS0a/++s5pdnQrmuvMMB//NnjNGVsOP3gK1Np7pRo2rqvQLstry5O4O9brtDcqdE0irWi+lRe1ayuxbCtV2llM3+wESXHBfESF3vNZoFM0sAIpC96wKCJGVhxYyl86N368DtGLZ3xhpsfwN5p/UNvWU0BmjG9EGd1QRfLaKvn0los/eEugpe1VCWfwxA4lllOE0eF0xO8BLn9QIFyF4Clxz03DKS/dOdo2nWo6P+3dx7AcR5Xnn/IOefADDGAmWAOoiiSClbWeh0oZ1s+p7LPJ3trHe6uzj7ZVV7XhvKWg1ySk7w+B2WLpEiKCsw5gSDFBBJEzjljcO812MCHwURgwjcz/2YRX+r46+5v3vf6dbe6L1p9McCHCx0CMst194EqemTLtHGFPlnaREXTk+m7X1xCMiy9mGc/ZvHQo7QbmTULN3kCnvqhFw0C3EQC7vB1x+/ElILrji9ZuJMWBK/gamdBWRoZBjIOBe18v4qXkYijLWty6QsfHrG5aeZlI/awECZDSTIU+ak76znJj/B/vVkelFxQqBECum0YzVVf2nuL1rDWK5vbSU/vyDIKf3+vUrWbTSuzacfDs3nI0UIyE/I3r9wxXAZQEAABEPABAQhePoCMJKZG4Bs/Pj5O8BLDaDGif/7lq5SREkN9bNdlXKdLZjOKjc8Q+9Nre00tBwhtZgKnyprU7Fax79Oum4Wtrz17XA0nylCiOHn+iz9/QM/99Qqlp0RTSzsbPN8xsh/xgb8gAAIg4H0CELy8zxgpTJGAvR9H+UEVWx1bTux54EKHgFHoMpZaC13GezKsaK/dGP3hHARAAAS8QQCzGr1BFXGCAAiAAAiAAAiAgA0CELxsQMEt3xAYHIRBs29IIxUQAAEQAAGzEMBQo1lqIoTyIWtyyfIOvb2dFBVV4feSDwxiKQG/V4KdDMgabcOdVWSRtR985GQW7ADbfkXz5toyK1Zcf2/XyIkf/w5Lvvr8nw8/IhhN2jLgmb1XBweHaBhMFVdhMRUni8bqtcemEk8wh9XtFoJXMNeyycoms84OnKrnTXAt9MnH5tGWVdmmyKHs9QhnTgI/eOYh6unxzI+sOyW0cFuVLYiio8JpOe+eIBuN5GQluxOFx/1uXD2Hmlu7PR5vIEYYEZFBqSnjV8Z3txzRUZG0tDiP30f48BJ2UVEFvEPE5D5wlhYX3GmbYOmoHUZGZFNiQgxhyyBHlPDMYwRkKxfZwmdTSbaa0u+xiBERCHiRQG0Dt9szaLdeRIyoQSDkCEDwCrkq922BZdjm3eO1FMnDNiJ0GbfM8G1OkBoITI6AaGplkVWZIbl5VQ4Pp0BDOjmSCAUCICAEIHihHXiNQGVtFx3lzYm3rM5V62p5LSFEDAI+ICC2ie+eqKO1SzPVVkM+SBJJgAAIBCEBCF5BWKn+LpKsCL7/WC2vIB9J65eZw47L30yQfvAQOMxDj7K3471r8tgmBtqv4KlZlAQEfEMAgpdvOIdMKjcqO+jspWbaujaPUpKiQ6bcKGhoEZAdEd4+UkMlCzNoZkFiaBUepQUBEJgSAQheU8KHwJrAAM9U3Mc/RFnpMbRyYaa+jSMIBDWBE6WN1Mi7J2xbl6c2ZQ/qwqJwIAACHiEAwcsjGEM7kg/K26jsehttX5+nNqgObRoofagR6OwaoL380bGw
KJXmzvTvkhOhxh7lBYFAJADBKxBrzSR57u0bYi1XNU3PS6Al89JNkitkAwT8Q+Ds5WaqrOum7az9iomO8E8mkCoIgIDpCUDwMn0VmTODpVdb6MbtTqXliovFOrzmrCXkytcEenoHac/haiqanqw0YL5OH+mBAAiYnwAEL/PXkaly2M2zufayluuuGclUPCfVVHlDZkDALATKrrXStYoO2sbD7/H4MDFLtSAfIGAKAhC8TFENgZGJM5eaqKq+h+7joZRoDKUERqUhl34j0Nc/pGy/CrPjaRlvOwQHAiAAAkIAghfagVMC7Z0D9PbRGlp8VyoVsaYLDgRAwHUCV2+1U+nVVjXzMSkhyvWA8AkCIBCUBCB4BWW1eq5Qx883UHNbv/rRkG1/4EAABNwnIBvDy8zHjNQYWr0Yy624TxAhQCB4CEDwCp669GhJWtr7aP/RWlq5KINm5GOBSI/CRWQhS+BmVSedutikFhhOTcYCwyHbEFDwkCYAwSukq9924Q+drqdunp0lq89jQ2DbjHAXBCZLYGhomLfUqqGEON5Sazm21JosR4QDgUAlAMErUGvOC/luaO6l907W8f6KWZTPBsFwIAAC3iNQxWt+HTnXQPesyqHMtFjvJYSYQQAETEUAgpepqsM/mRkeHlYC17CFaDP/CEDL5Z96QKqhR8Bi4b53oo77HNHdK3MoLAybbodeK0CJQ40ABK9Qq3Gr8tY0dNOhMw3qpZ+djq9uKzy4BAGfEKhv6qH3T9XThuVZlJcFbbNPoCMREPATAQhefgLvj2Qr67roMAtZH3lgJsmX9v5jtRQbHU4bS3L8kR2kCQIgYEXgwKk66uu30L1rcpXm+bX9FbR8fjpNxwQXK1K4BIHAJYC9XgK37tzO+W9fuU6l11ooPyuO6pp6aQu/3NNTYtyOBwFAAAS8Q2ATfwQ1t/XRK/sqqCAnnv686yadvdxC//srS72TIGIFARDwOQEszORz5P5J8PiFRjp/pYU1XUR/2llOj2+dDqHLP1WBVEHAIQH5GHpi23TVTwd5BuRF3n7oKBvhw4EACAQHAQhewVGPDksxwIs3/v6166N+Glr66LV3KkavcQICIGAuArsPVlFNQ89opn7/+nXq5y2I4EAABAKfAGy8Ar8OnZZAvpbf52UiMtNi1LT1TF49OzczjmZPS3IaFh5AAAR8T+BaRTtV876oTa191NjSq44bV2TDHtP3VYEUQcDjBCB4eRwpIgQBEAABEAABEAAB2wQw1GibC+6CAAiAAAiAAAiAgMcJhPysxjffvkBVtW0eB4sI/UMgnBeg3PHEKoqPwz54/qkBz6a6c38pVda0ejZSxOY3AtI/P/bYSkpMwGxqv1UCEvY7gZAXvP7y97PU1BVH4WERfq8MZGDqBIb66umx+5dA8Jo6SlPE8Nc3z1JDeyyvaYX+aYoKmWImLP319Mj2xRC8psgRwQObQMgLXlJ9sXGpFB4RFdg1idwrAkME7UiwNYWY+FSKQP8MimodCsPoQlBUJAoxJQKw8ZoSPgQGARAAARAAARAAAdcJQPBynRV8ggAIgAAIgAAIgMCUCEDwmhI+BAYBEAABEAABEAAB1wlA8HKdFXyCAAiAAAiAAAiAwJQIQPCaEj4EBgEQAAEQAAEQAAHXCUDwcp0VfIIACIAACIAACIDAlAhA8JoSPgQGARAAARAAARAAAdcJQPBynRV8ggAIgAAIgAAIgMCUCEDwmhI+BAYBEAABEAABEAAB1wlA8HKdlc98WoYG6OqZl3yWXkdLJVVfP+yz9JAQCAQyAfTPQK495B0E/E8AgpeTOig7+js6/fa/OfTV19NK/b0dDv04etjVVkPDlqFRLxUf7Kfb/F+cK+mPBnTjxJhmWFg4nXjrxzQ02O9GDN7xOlWW3skVYjUrAVf6x1TblLGvCAf0z8m/68zajpAvEPAlAQheTmjnz9lIMxc+6NDXhQPP0bWzLzv04+jhnhc/T6J10u5W2W6aNu9edelK
+jqcO0djmomp+ZTA/6uvH3QnCq/4nSpLr2QKkZqWgCv9Y6ptythXBAT65+TfdaZtSMgYCPiQADbJdgK7rfG60mal586nk3t+QhkFi6ji0l7eVDualt3zNWqqvki1t07wdSQ11ZTRhseepbqbJ3io8G80xEOGM4vvp1mLHqLu9jq6eOS3lJhWoIb1JL6ld3+ZTu//dxoa6KNT+35KOTNW0ryVH6OGyvNUsu1bKmfO0k9Km6a0ZZdP/olqy49ReGQ0zV70MAtuW6i8dCfnoZ+Klj6u4jr46ndUvBePvDAuzeK1n6a8WWup5uaxUYFPAug8J6TmcZ4PUfa05exvPV08/DxFRMXS8i1fp8TUAhY6X6HbV/aTZWiQCoo20fxVO1TYsmO/p7jELKopP0KpWXfRss1focjoeLpx4Q3O2y6KjIrhOP47JWfMUPm7eXH3BJZ93a1Ueuh5amu6QYkp+bRw/edIygwHAkLAWf+w1T97u5rpwsHnqJM/dtLzimnxhi+oPo7+if6JXgUCviAAjZcTyvJybm8qV74aq0vp8rE/smDxFAsNsXT+wC9ZWCphoaKIsguXU/HaT1FXazUd3fVDFqA+rgSQM+/8jJprL9FAfxfdLNvFmq3btHjj01R59T0WVt6lomVPKKFt9pJHWejZymndouHhIYpPylFpOktfPInQdevSHlqy6Us0i7Vzx9/6EbXUXeG0KqiD49OuvuIUDQ50T0hTnick56q8a79y1HmWoZaF6z5H1868TMd3P0vzVn2cwsMjWAB7QXnv7W5hIfKrSpCUH6/2ppsqbHnpm9TTUU/L7/k6C6UXqezYH6ivp00N3ZZse0bFGRY+1gStWYaFRdDB175D4kcE0ajYJBLhcXjYYswmzkOYgLP+Yd2mwsMj6dDr31PtfQW3wZb6K3SRzQl0W0f/RP8M4e6EovuIwNivno8SDPRklt/7DaWZmln8AHWykCUandj4NIpPzqGMvIVKaxSXkMEamnIlZMUmpCutmJRbtD2r7vtnpTkS7VZXW7US2sJYiEnLnsuanEL+8m6nmLhUJYzZYmWdvvipvnZQacoy8hfS9PnblBBYV3HSVnB1TwRFY5pyMzYhk7raayeEkTyv3P5PrOlaQ6k5c1lY+qzSjk2bv5U6WSATt2D1Uyrfna1VFB2bSD2djeq+hC3Z/m2SfM1Z8ghzKFUCa3RsstIIxsSljNNeWbPs7qij1vqrtHTzV5nPXUqw7GytZG4T86kSxJ+QJ2DdP6zbVE9ng2pT0jYbKs9yX0tRHwUCDv0T/TPkOxAA+IQABC83MYshurjwiKhxBvE6msGBHtYGRVIED0XK/3klH6VcFlrESdiwsDB1HmEnvAhEEoc9Zyv9AdZiRUUnjAaJ4CG8EWP9MJe1Q0ODvaN5G42IT4x5lnJpDVV4eBRxIjy8OEB7X3xaDR/28rAgDY+FHheWWQzzv4jIGNq241cUGRlLe//4Rdb8vT8WwOpssL+b049Qwpo8iuBhVHGiEYQDAVsEbPUPo7/BgV51GRkVp/pn7ozV/OHwSXXP2F7RP9E/je0G5yDgSQIQvDxAU4Swbh5SkyEw0SbJcFrhXZtpztLHqHDuZtaKZTpMRYQY+RIfHh7mIcYsnl3Yp2xOHAYyPEzNLKKKy2+r8JKP+ttnlPZNvuZbWGMkQtitsreUvZcOZkxT7kk4Z/nUYY3HTtbayZCmaPJmLNiuGOjnA31dajhVuFSxgJWaOWekXCx8imZC/DdWndPe1dHIMoFtukTYun3lHfWs4tI+1i6m8zBR3rgwuAABRwSMbUqG8OUjRdq69E+xv0zNmuMoOH9IoX9qQEaW6J+aCo4g4B4BGNc746U0VCNaKtFW6S9q49exGKYfeuP7ym7rkf/2CuXNXkc7f7NDDUEO9vfQ5g/LchQclv+NOtGc3dF+FczZoGyZCovuptUPfp+H65KpteGaGpIc8eM4/UUbPk+HX/8+vfaLR5TgI8b02dNXUFL6NLp2
7lV69ecPU1rOPDWsIvkQZ0xzzYf+J7U1XKf03AWj2Rs5GZ/n8eXneDj/YlyfWbCEdr6wg6Ji4pVGS0ci2rGyo7+lE7t/RAkpeTzs+C01DPnuX7+h8jLAGq2Nj/1Ye1dHI8tHv/Qq23Y9Q6f3/SudZVs50QaufuC7dodhx0WEi9Ag4Gb/VG2Kh79P7v0XOvf+L0g+Dhas+QTlTF+J/on+GRp9BqX0O4Ew1rIYBof8nh+fZ+DT3/wDdVny1NChrcSVITcjkh99mbUnsxe1M16rIQzW7IidiDhZE0uGyqJZ66SHF43+1VCgQZCTdcAio+PUMKXMnpQvyxVbvzmiQXIhfUlTfkRkmFGGBLWT6pV8RMUkTMi/TlP8vvGrJ2n9Iz+grMJlOqg6GvNssQxyWSLGysPXOi3R8onAKMOAcq+t8Qa985ev02Nffl1puUT7pp3E09/DtmzxqRzXRKWrNUspw4jt21gcOi7r41DXDfrljz5MWRlJ1o9wHYAEPvvMi9Q+mMvDgjy0bcNNtn9Km5L1vaJjEkf7vrGto396qX92l9PPf/gk5WQl26hN3AKB0CAw8VcvNMrtcimVZouFLnFGocv6WmY5aqFLnskQ2YhgMablMoYXQc4odETzjD0txBQtf1It3yDxuJq++BXhSsch1+JE6JP74ozpy7VOs5GXrxAjZGuhyzqMxK2FSPXMIOCJYCXPJqYffkfTJiFGnPiRSQfG8utncrRmKfEaBTejX5yHNgFX+4etNiWTYuQDRztj/0D/RP/U7QJHEPA0AWi8nGi8PA3c1fjkS96eYOJqHO7483R6ojGQZSYmYzfmTr6t/ULjZU0ksK+dabz8VTpP9xdn5fB0en7rn9B4OatqPA8BAtB4mbSSfSl0CQJPpycaA18LXSatSmQrCAl4ur84Q+Tp9NA/nRHHcxDwHgEIXt5ji5hBAARAAARAAARAYBwBCF7jcOACBEAABEAABEAABLxHAIKX99giZhAAARAAARAAARAYRwCC1zgcuAABEAABEAABEAAB7xGA4OU9togZBEAABEAABEAABMYRGFtpc9zt0Lro7W7mWX0ja3WFVsmDr7SW/r7gK1SIl6i3q5nXh0P/DIZmgP4ZDLWIMkyVQMgLXp98ciXV1rdPlaNpw7d0WKipzULt3RZKjA+n1MQwyk4N3h+xsLBplBAfY9r6QMbcI/DUE9w/69rcCxRgvmubh6irZ5gsvIfI7PwIiggfW3Q5wIriNLvSP5MSY536gwcQCGYCIb+AajBWroXf4KfKmqimvoeKpidRcVEq7T5YRQ9sLKCbVZ109nIzzSpMpKXz0oOx+CgTCAQEgaGhYXrrUBXNnpZE82elUHfPIO08UEUbl2dTblZcQJQBmQQBEHCfAAQv95mZNkRv3xAdOdtAnd0DtHJRBuVljewbKRnWgpfOvBbAZhYk0rL5EMA0FxxBwBcEWtv7ac/harpvQz6lJkWPS3LfkRpKTY6ilQszx93HBQiAQHAQgOAVBPXY3NpHR883qCGKDfy1nJgwtv+cLp614KXvQwDTJHAEAd8QuHSjlcorO5UGOtzOsOKVm2109VYH3b8xnyIjMAfKNzWDVEDANwQgePmGs1dSKa/soAtXWyk9OZrWLs2iyEj7L2h7gpfO2KgAls8asAXQgGkuOIKAJwnsPyrarGhaUZzhNNr2zgE1FLlldS5lpsEuyikweACBACEAwStAKkpnc3h4mM5caqbK2m6ayXZaS+am6UcOj84ELx34VnWnin9GfgItX+D8x0GHwxEEQMA+ge5ett96v4o2lWRTTobr9lvS3986VE35bPO1BDaZ9gHjCQgEEAEIXgFSWf39bL91roHa+Ct4BWukCnMT3Mq5q4KXjrSCBbDTLOBBANNEcASByRGQj5nTZc300OZCio6yr5V2FHvp1RaqrOum+9bn89IawTvr0REDPAOBYCEAwcvkNdna0U9H2WCeZ5rT+mVZlGJliOtq9t0VvHS8WgCb
npfg0vCIDocjCIAA0eGz9WSxEG1ckT1lHC3tfSSG99tZ+LI2yJ9y5IgABEDAZwQgePkMtXsJVdR00Tle9iE5MYrWLcue9JeyTnWygpcOL/k5zUtUQADTRHAEAfsEBgYtPLRYqZZskZnDnnKyVIwsOTFnWiItmJ3qqWgRDwiAgA8JQPDyIWxXkjr/QTOvtdXFQ4nxbGOVzivqe2ZYYaqCl8777douOnURApjmgSMIWBOob+6l90/U0oN3F1JCnHfWqD5zqYmaWvtp69pcj70jrMuBaxAAAe8QgODlHa5uxSpfx8fYfqu5rZ8NaNPIk1/IOiOeErx0fFoAm8a2ZiULYYSvueAY2gRkceIGFry2rcvzukAkAt57IuDxwsi2lpAJ7ZpA6UHAvAQgePmxbjq6BtSCp4NDFrUcRHqK97a68bTgpbFVsgbsJGvARABbUew5DZ2OH0cQCAQCMgS4h2cfylC87BThKzfIH21v8pDmQk6zaHqyr5JFOiAAAlMgAMFrCvAmG7S6vlvZS8XzMMR6tt+KjfH+3oneErw0Ay2AyRBpCa9R5KkhUh0/jiBgVgLtnf1qyYeta/PImx9Pjsp//EIj9fCSFZtX5TryhmcgAAImIADBy4eVUHatla5VdFBedpwSTnw5LdzbgpfGWFnHGrDSJirMYQGMhyAhgGkyOAYjgSs320lWmX9wUyFFRHjGHnOynOSDTrYM+9DdBRQX6x3bssnmDeFAAATGCEDwGmPhlTMZRpSvUbH78OdwgK8ELw0RApgmgWOwEniX7asSWWu9cpF59lSU9f7+zkOPonWewbtQwIEACJiPAAQvL9VJV88gHT5TT30DbL+1JNPvW374WvDSWKt40ccTpY1UwBqwldCAaSw4BjAB2Yz+zfcqeZmXLMrPHtuI3kxFOnS6nrXNvPYf790KBwIgYC4CELw8XB91TT0saDRRbHQEbVieZRqVv78EL41XhkFE8wcBTBPBMRAJiCb3+PlGtQp9DPdxM7vRFfN56DHa5Hk1M0fkDQQ8TQCCl4eIip3H5fJ2yk6PpdWLM023rYe/BS+NWQtgoilYtQg2YJoLjuYncPx8A/WwtiuQDNjF4F72iDSzds78NY8cgoBnCUDwmgJPmUIuw2i1jT00b1YKzef/ZnVmEbw0HxHAhF1eZhytYkEVRviaDI5mIyB2mrtYeFkwJyVgl2x451gtJcRHqo9Cs/FFfkAg1AhA8JpEjctX5GGePSRHERpyMuImEYtvg5hN8NKlhwCmSeBoRgJNrX20/1gNPcCLlCYlRJkxiy7n6dqtdrp0Q2ZgFlBk5OQ263Y5MXgEARCwSwCCl100Ex80tvQqOyV5aW1go1VvbQcyMeWp3zGr4KVLVtMwYgOWyxowGaqFBkyTwdFfBEqvtlAlTw65jzel9uXSL94sbycv2rzrYBXdw+t9ZbFZBBwIgIDvCUDwcoH59dsddJHX4MrgleXXLM2kyIjA+1o0u+Clq8EogK3iafrB8oOny4ej+QkMDw/T3iM1JB8BS+ammT/DbuZQly87LZaW8X6wcCAAAr4lAMHLDm+x3zrNG9FW1/XQ7GmJtOiuwH4BB4rgpaujtqGHjl1oUMO4ZpysoPOJY3AR0BqhLatz/b4EjLfJll1vpZtVnWoYFR843qaN+EFgjAAErzEW6qyPFyCU1Z9lH0VZed2s6/RYZdvpZaAJXrpAEMA0CRy9TeAGa7ZLr7aqld9DxQaqtaOf9h6uVpt6pyV7b69Yb9cd4geBQCIAwetObbW099HRc40jiw7y/onJiYFtSGvdCANV8NLlkJmjx3g6v0xkgAZMU8HRUwQOnKqjKLbdXLs0y1NRBkw8ot2X94OsdC+7a8CBAAh4l0DIC16yyOC5D1ooNSlarXUjL99gdIEueOk6kQVqj7GALIbBa3hHAAyRaDI4ToaAaLjf5C12xJ5wWm7CZKIImjDnPmimuqZe2r4uD5NbgqZWURAzEghJwUuMS89dbqGKmi6a
np9Ay+YHv4FpsAheuhNBANMkcJwsAZnIcfgMNpU28pOZ2+8crw2K5TOM5cI5CJiJQEgJXv28b+LRcw0kdg3L5qWx0BU6m8gGm+ClO1E9a8BkiDgzPYb3xMyCBkyDwdEhgZMXG6m9c4DuXZPn0F8oPtQLxsqi0HNnJociApQZBLxKICQEr/bOfrXgKSu6aB3bcKQmR3sVqhkjD1bBS7OGAKZJ4OiIwNDQiD1T0fQktduEI7+h/uwk7yzR0TVIW9bkhjoKlB8EPEogqAUv2dD2TFmzWncj+kwAABUVSURBVHF6/bKskN4oNtgFL90r6pt76SjPSs1MYw0YC9mwAdNkcJQJNPt4fS5ZEDWFbTrhnBOQWcUHz9TTh3i1+/i4SOcB4AMEQMApgaAUvGTF6Ru3O6kgJ55WFKfDUJSbQagIXrrFKwGMh5UzUyGAaSahfMSaVZOvfTHR2MkTEMQWdmZB6JhnTJ4YQoKAYwJBI3gNDrL91vlGaua91RbNTaXZhUmOSx5iT0NN8NLV28AasCMsgMmuA+tY6wkNmCYTOkfRcmWkRtPyBRmhU2gvlPTw2Xqy8FDtxpIcL8SOKEEgdAgEvOAlK00f5h9WEbzWsHF1Bms44CYSCFXBS5OAAKZJhM6xu2dQaWo2rcwJiI3sA6Fmbtd20Qm2/Xro7kKKiY4IhCwjjyBgOgIBK3iJ7YHMTIqPjaT1vGF1bAxeAo5aV6gLXpqNTJc/zDZg6awBE7s/aMA0meA6ylY4Zy838yr0hRQdFZxr8/mrxnr7hpRAu5rX0SvMCe21z/xVB0g3sAkEnOB16UYrXb3ZwRvYxtJKbKLscuuD4DUelVEAk5muERFh4z3gKmAJHGZjcJnBvGFFdsCWIRAy/t6JWqX1CsXV/gOhfpBH8xIICMFLpoCLerueV1WePxtry0ymOUHwsk1NBDDZmzONlxhZx1tFQQCzzSkQ7g6wucGb78EI3Jd1pfe3fPDuArXlki/TRlogEKgETCN4dfcOqmFDI0i5J1+vvf0WtT9fNm8TA+c6gS62cfnqD49SDw8NiD5HhIpHt0yjjz44y/VIQsSnFsBkjbf1VgKYzJBE2zNfQzhd1kTSxjexsbes4/b+yTo1tIhlD3xbV1IHu3jW491sS5fNe6mKCcjg4HBI7nvpW/JILVAJmGJhFrHH+M//ukw/eaZE2dyIIfTxC42sxg5XP4J4kU6ueSXwujsPbCygl/dVEI+8UHR4GD24qXBykQV5qMy0WHqEhdImnhUrU+e1ADbEGwj/r5+doa/tmE+L7koLcgqBVTxp1/LuECN6+aj4h/tmYOkYP1ShvGeE/dtHa5l/K734xg01yUn2Ug0LwxC+H6oESZqcgN+tTmVV+Z+8UKr2TXz17Qp6ff9tulbRoQSGbevysWjfFBuQGBeLACtOzpMTo6YYY3AHl1mxIoAVz0lVAtiv/3qFmtv66T/+cIla2/uDu/ABVDpZq+/KzXaSNabeeLeSNvAEG/zI+68Chf09q3Jo14EqpYWUfXBPlDb5L0NIGQRMTMCvgpfsCfbT35RRY0ufQiT7KD567zS13hJsbTzTakTQ2s4rdcfHRiiBwjOxBn8sIoDdz9rC81daVGHbeF+//3jxEllYAwbnfwIv760YzYQMBf/iz1dGr3HiHwK/e+26+oDWqb+055Y+xREEQMBAwK9Djc+/dJUul7eNZudmdRddu9VORTOwMesoFA+cPHrPNErlLVJkSADOdQJvH62hFoOW6+K1Vvob/5h85IGZrkcCnx4ncJXfEY2tvWrf1VmFibxYciLNwoLJHufsboQy3FhSnEE3qzvVEHA5DwOfudSEhWvdBQn/QU/Ar8b1Hbz4aV//EBvPD/HRQn1sBJ7OmobczLigB48Cmp+AtE2x8RKnLVVkSAVrxvm37oZ5rQgMK/q3DlxNHXXlKin4CyUCfhW8Qgk0ygoCIAACIAACIAACfrXxAn4QAAEQAAEQAAEQCCUCELxCqbZRVhAA
ARAAARAAAb8SgODlV/xIHARAAARAAARAIJQIQPAKpdpGWUEABEAABEAABPxKIPL1PefoSnmDXzNhpsS3b5pPS4untrr78386zMsQdJupWH7Li8w++9ijJVSQm+p2Hq7cqKM39pbyqvtYO8sxvDB6ZNsimjcnx7G3KT7d/W4ZlX5QPcVYgif45jVFtGrZTFMW6NiZcjpw/Lop82amTMn76ZNPrqLsTCxhZKZ6Cfa8RO5+7zLdboig8AisaN7X007TC9KmLHi9vu8C78+TG+xtx6XyDfc3073r505K8Kqoaqb9xyopMsZ9oc2lzAWJp4HeVlo8P8/rgtfbh67QpZv9FBEZEyTkJl+M/t5OSkmKM63gdaa0kt4+VkfRMYmTL2QIhBzub6IPbSmG4BUCdW2mIqoVNaNikigyChtQDw0OeKRuwnlPxKg4CAsC0xLWNSWmUdwuo8HSMUNLr+PnHnwq74qo6HgPxhiYUVksQ6bPeFRUPMWg7zisJwt1OHyOhyDgDQKw8fIGVcQJAiAAAiAAAiAAAjYIQPCyAQW3QAAEQAAEQAAEQMAbBCB4eYMq4gQBEAABEAABEAABGwQgeNmAglsgAAIgAAIgAAIg4A0CELy8QRVxggAIgAAIgAAIgIANAhC8bEDBLRAAARAAARAAARDwBgEIXt6gijhBAARAAARAAARAwAYBCF42oOAWCIAACIAACIAACHiDAAQvb1BFnCAAAiAAAiAAAiBgg4BpBa/2pptUe/OYjSy7f6v6xmHqaLntfsAgD1F17QB1tdW4VEowHMFkGRqgq2decomZJzyBu/sU3Xl3dLRUUvX1w+4nghBuE3CnXtyO3EYA9B0bUHDLFARcErz2/fGLLAQdd5hh+QEfnuQ2GkOD/dTdUT8u/vMHfkXtzRXj7rl6YR2fdPjLx//oanCf+ys7+js6/fa/OUzXukwOPdt42NfTSv29Y9tj9Pe207Fdz1JYeAS11H1Au37zFA0PW2yEHLllRobebpdScut2XfHBfrrN/8W5Um/Koxt/rOvZjNxdLY4r7cq6vK7Grf3ZCq/fHa6kHxYWTife+jFJPKHkvN13HNWLcHYlfXfrwzrNQO477pYd/gOLgEuCV/Haz1Ba9lyHJdvz4udZq1Tp0I+9hw2VZ+nAK/80+rivu5Vqbx2naXO3jN5z58Q6Pomn8up7NNDf7U40PvObP2cjzVz4oMP0rMvk0LONhxcOPEfXzr48+uT2lXe5Tu+i+KRsSkwtpCWbvkTyI2TPmZGht9ulsLBu17fKdtO0efcqTK7Umz2e9u5b17MZudvLu/V9V9qVdXmt43B2bR3e+O5wJf3E1HxK4P/V1w86Syqonnu77ziqFwHpSvruArdOM5D7jrtlh//AIqA2yXaWZXkpxSdn81dhH1088ltKTCtQ6vn03Pm09O4v0+n9/05DA310at9PKWfGSpq/+imlYaopP0axCWm0cN1nKTWriG6VvUW93S3U2nCdNQnVNGfpY5SRt4hKDz9P3e11Sviat/Lj1NvVRCmZsykuMVNlraezkUoPPU9tTTcoMSWfFq7/HCWlTaPju39E4j8lcxY1116m8os7aV7JxybElz1tuRIwGivPUd7sdc6K6/PnbY3XlTZKeJ7c8xPKKFhEFZf2UnhENC2752ssEEVMKFNiagFdOPgcdbKwm55XTIs3fIEiefNiW+Gbqi+yIHuC44ukppoy2vDYs1THGsy8WWtVWQf6OunWpT1UULTJZh3NWHAfJaTkmY6hu+2yeO2nqebGER4q/BsN8ZDhzOL7adaih1Tbc6Vdz1v5MWqoPE8l276luDmrN2mj9tpueelOzkM/FS19XMV18NXvcL/5xIR6NnvbddRZnLUrW30/s2CxzXeHrXZtq18Y3x3yTnHWriX/0g9q2KxBC9SOyhQszzzVdyb7TneWfnhElN2+Eyzv/WBpSyiH+wTsqzgMcTVWl5J8SQ70d9HNsl3KXmrxxqeVFkk0J0XLnlA/
6rOXPMovr6105eSflX3Wyu3fpuT0GXTo9e/zMNYwdbZWsbDwa5KX6wz+0Tu5918oMiaBpnOY6NhkWrDmU0pAa2u8QQlJOSoHEu7ga9/hIbFw9YMXFZtE8iMlw2KNVedZYGlX/vpYoGvifIqwZh2feIhPzqVOFvbM6ER4am8qV1kT1peP/ZHmr3qKIqNi6fyBX9os06HXv0cJXKYV256hlvordJGHK8XZCp8zo0RxzS5czl+an6Lw8EhqZcbxKbkqjNRr/e0z6txWHWlNodkYutsuO3jo+uiuHyphff6qHXTmnZ+xwH7J5Xbd3nSL290QC6AjbdNZvTlqux0tFdTB8WlXX3GKoqLjAq7t6vzbOjprV7b6vr13h612bauvG98dztLX7Vr6UVerOd8Ntrh64p6n+o6t94WtejXWi+TfWfqO+k6wvPc9UY+IIzAJuCR4GYsmWpVV9/0zyZe4aLdEcyXaLLEVkuHIpLRCEqNG+XGqqzhJEZExrFGoVVosiWdG8X00hwU0+S9DW/0s0CVnzGRtTRxl5i9iASyJ+nraKPaOtqu7o45a66/S0s1fVUNjMiTW2VrJ6dYaszV6HhEZPSE+eSgvaclHILjl935DsZ1Z/ACXtZoZji/T0GCvYiJ1Ier1mLgU1mRdHC2adfi4xCyKjU9j4TOHNYwLlb9+ZhyXMKJRHA1458S6jno6GtQTMzN0pV2KViMuIYM1p+Xq4yE2IZ2F9RFuroQXIT8mLlV9ZFgzk2tr7u62XfnKt+4LEq+ZuUv+XHXW7cpW33f07rDma90vrN8d1vmyTl+361juB10B8m6wLpMnrl1p+476jjVXW/VqfKdb59lW+u72HVttQdIJlr5jzQzXgU3AbcFLhKWwsDBV6gj+obBlUD840EPhLCxE8FCZ/FCVbH1GCQcSKJwFNO3kh0Y0CNZONDJDA73q9iDbZcmQgmh/xEkHE6fCcT4cGYQrj3f+DKr4RvJtvG/Gc21rpfjYmLAwUhZiJnGKce6M1bRg9SdHi+IsvHiUYUepJ1vOXh2ZmaHL7ZLblrRL+T+v5KOUO2uNQuBKePm4sMdMIrHm7rDtUnC2XVvtSd+z1670czk6endY8zWG0+fGd4e+p4/20pcPGf1O035D6ehK21f1Yqfv2ONqZOioXmyl77DvBOl738gL58FNwG3Byx6O8HAZk29gQWhYacCGhwZJhh7Fjitn5ir+oY+yF1Q96+1qIQuHkfBxSVmjsxwT2KZLhK3bV95R4Ssu7WPtTToPs+VRTGyKmpEns1lusv2YdpKWMT6538OzJuXrJ1CdsUxxidk8LJWgyiN8xU4pNWuOw6JJeJk5qgVV0YJZzyR1GAE/DESG1u1SvrwL79qs2mXh3M1O24QxfDy3S7FzNM4OdcTMYdtlLWULa3Llw0XsZMTeS5yxnqUviAtE7irjLvyxLq9ozz317nAheeVF+kEgvxtcLae7/oxtX+rFnb5jXa/Gd7or+XDYd0Love8KK/gJPAIuCV7yNTjytclH/jfqZBbcHe1XwZwNyhbr+K7/q4zp25tv0d+f+wf6+6//kU7s/vFIEOV3LPxIXGFqCFGEq1d//jBV8ezD9NwFJDYBIiSIpquE7ZhO7/tXev2Xj9OFQ7+m1Q98V2ls5rLGovTwC/T6rx4ni2WA0xiJW2brGeOzWAbZhuqminc072Y6MXAZYz2iQdFf4sYy1ZYfoRK2nxMbuZ0v7KA3mLMWTO2FFwPimzwj77VfPMqsBpnFfGrjSQ4jzlCvhrzIM11HZmQ4VlZD/lWmbbfLvFnr1OSKnb/ZoZbP2PP7z/Hwsyxj4lr4hJQCZYvY2nBNUpEKkj93TjkO6Q9y545W2FHbnT5/K080aVZtvvzirjsa4Yl9wYzcVSFd+mPgamClGCluE8srE3FsvTvG6nqMr8Rj7BfW7w6pG/mnnJ305Zn0A3nnhJIb42lgJADsvNPt9h07XB3Xi9Qhp6v6i+30HfWdoHnvh1KDQ1nH
EQj78nf/33Bte+roUN64p3cuRBMlQ1PijOdqmHG0A5HSBIitlqiVxSnjVf5qj2IDenFK28LXMmQjzjqu/r4ONTQp99947kla9/D/UbZk4le+/kdsbFLkctSJtksLaPbiq715Qq3V8/AX/zb64zgageGkm22ZPvXYLPrIwyWGu+6fPvH0cxSVNM/lgEYuxjJIBMZr4a0ZyTNhIutzRcckKk2JtX/razVUKMIs24bJumzHef2iR774kmKi0zHmxRjeVYYSxugsvVX0va9souWLphlvu3S+78Al+s8Xz1N0wsgkAOtAOs9y33jurF1Km5GhjGjWOskPgDvhZXadfM2v2PrNce3ZmP6E+Oy0Xak/yYf0D2N4Yz27wr2vq5a+smMx3b+5WJXFW3++/eyrdK0mhrWt8S4noctlr11JRMby6oit3x06Hv3ceG0ML/eN7w7tz176Iti+8asnaf0jP6CswmU6eqfHHp55/ejmHHp6xwanfv3h4Zd/OEA7DzbyrPJ0m8lrLvLQeO5u37HHVeJ1pV6cpe/t976l5zb98H9spQV35UlW4EDAJwRc0nhpoUtyZDwXAUp/5cszMW7VQpdcywtaC11yrTQB42y8xlazkLjEHkycpDFnyWMkK6trJz+QYkRu7USzpe2/rPOm46u69j7NWfb4uLxax+PPayMXYxkkT8ZrIyN5JkzEaF4EAe2M/uWe8Vo4idAlLmfGKhbYkkaNy7U/Y17En75vRoY6b8Z8yrmzdiltJiY+VfET/+KMcTkKX7T8SV5K5ZAKY2RlDD8hPjttV+pP9w9jeGM9m5G7KryLf3S5jKwm8DH0fR2t9btDx6OfG6+NvOS+8d2h/dlLv5GXB5Fhd3eELp2HQD5qLlIG47mjti9+rfuOPa7i15V6EX+O0g/m976UHS40CYxJPiYr/6INnx/RKHggX6KdkBcE3BgBeaHd/+nfuswFDEfYia3LQ0//ZQykl8/A3X3A7rw7sqevoO2f+LX7iSCE2wTcqRe3I7cRAH3HBhTcMgUBU0sjnhKWPBWPKWrMg5lwh4s7fj2YRVNG5UsWvkzLlLAnmSl3uLnjd5LZQbA7BHzJ2pdpoYJBwB0Cpha83CkI/IIACIAACIAACICA2QlA8DJ7DSF/IAACIAACIAACQUMAglfQVCUKAgIgAAIgAAIgYHYCELzMXkPIHwiAAAiAAAiAQNAQgOAVNFWJgoAACIAACIAACJidgFpOoq+7iQbuLHpq9gx7M38DfZ0eid5iGeZtXuo8ElegR2IZ6JpSEfp6O2nIApaOIA72e6bdOkpDP5N3xWB/h74M2eNAH7frsBxTl7+/t40XMZUdPeDsEbAMdNt7hPsg4DUCkZ/5x9VU24AXqSa8aF6+Pp308WufuZt6evHC0wCn5afpU7eO84ty6cufWOVWmFD1XDzX+ytvf/zRFVRZ2xqqiCeUe/4c8wpeW9bPpfzciQtOTygEblBeNjihGfiWwP8HuIad4qd/EeMAAAAASUVORK5CYII=",
      "text/plain": [
       ""
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fig = Image(filename='gfx/intent_inout_map-crop.png')\n",
    "fig\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15545c2e",
   "metadata": {},
   "source": [
    "You may be wondering why `intent(out)` has been included as a permitted value in the rightmost tree of the above flowchart. It accounts for the following possibility: an `allocatable` variable is allocated in subroutine A and passed as an argument to subroutine B. Subroutine B must therefore declare the variable as either `intent(inout)` or `intent(in)`. Subroutine B then passes the variable as an argument to subroutine C without using it first. Since B has not used the variable at that point, subroutine C may legitimately declare it with any intent, including `intent(out)`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2320a8a",
   "metadata": {},
   "source": [
    "## Checking `intent` consistency across function calls"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "112d9adb",
   "metadata": {},
   "source": [
    "The code below gives an example of how a `visit_CallStatement` method can be implemented:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ae85e899",
   "metadata": {},
   "outputs": [],
   "source": [
    "intent_map = {'in': {'none': ['in'], 'lhs': ['in'], 'rhs': ['in']}}\n",
    "intent_map['out'] = {'none': ['out'], 'lhs': ['in', 'inout'], 'rhs': ['in', 'inout', 'out']}\n",
    "intent_map['inout'] = {'none': ['in', 'inout', 'out'], 'lhs': ['in', 'inout'], 'rhs': ['in', 'inout', 'out']}\n",
    "\n",
    "def visit_CallStatement(self, o):\n",
    "    \"\"\"\n",
    "    Check intent consistency across callstatement and check intent of\n",
    "    dummy arguments corresponding to allocatables.\n",
    "    \"\"\"\n",
    "\n",
    "    assign_type = {v.name: 'none' for v in self.in_vars + self.out_vars + self.inout_vars}\n",
    "    assign_type.update({v.name: 'lhs' for v in self.vars_written})\n",
    "    assign_type.update({v.name: 'rhs' for v in self.vars_read})\n",
    "\n",
    "    for f, a in o.arg_iter():\n",
    "        if getattr(getattr(a, 'type', None), 'intent', None):\n",
    "            if f.type.intent not in intent_map[a.type.intent][assign_type[a.name]]:\n",
    "                print(f'Inconsistent intent in {o} for arg {a.name}')\n",
    "            if f.type.intent in ['in']:\n",
    "                self.vars_read.add(a)\n",
    "                self.vars_written.discard(a)\n",
    "            else:\n",
    "                self.vars_written.add(a)\n",
    "                self.vars_read.discard(a)\n",
    "        if getattr(a, \"name\", None) in [v.name for v in self.alloc_vars]:\n",
    "            if not f.type.intent in ['in', 'inout']:\n",
    "                print(f'Allocatable argument {a.name} has wrong intent in {o.routine}.')\n",
    "\n",
    "IntentLinterVisitor.intent_map = intent_map\n",
    "IntentLinterVisitor.visit_CallStatement = visit_CallStatement\n",
    "routine.enrich(source.all_subroutines) # link CallStatements to Subroutines\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83a0eaa7",
   "metadata": {},
   "source": [
    "In the final line of the above code cell, we called the `enrich` method of `routine`. This uses inter-procedural analysis to link `CallStatement` nodes to the relevant `Subroutine` objects. Also note that in the above code cell, `intent_map` has been attached to `IntentLinterVisitor` as a class attribute because it is the same for every instance of the visitor. (For brevity, `visit_CallStatement` reads the module-level `intent_map` directly; `self.intent_map` refers to the same dictionary.)\n",
    "\n",
    "We can now finally run our intent-linter and check if any rules are broken:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1baca9f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All rules satisfied\n"
     ]
    }
   ],
   "source": [
    "intent_linter = IntentLinterVisitor(in_vars, out_vars, inout_vars)\n",
    "intent_linter.visit(routine.body)\n",
    "intent_linter.rule_check()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.8 ('loki_env': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "5b6429b76fde06fc4400bf3c27b3ae893ffb7a047f8b8ee9418a3bc77878d107"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
loki-ecmwf-0.3.6/example/04_creating_new_visitors.ipynb0000664000175000017500000002542615167130205023362 0ustar  alastairalastair{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fa3b1299",
   "metadata": {},
   "source": [
    "# Creating new visitors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23abb66d",
   "metadata": {},
   "source": [
    "In the previous notebook, we relied heavily on the [_FindNodes_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.find.html#loki.visitors.find.FindNodes) visitor, which looks through a given IR tree and returns a list of matching instances of a specified [_Node_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Node) type. Although this functionality is sufficient for most use cases, there may be scenarios that require the implementation of bespoke visitors.\n",
    "\n",
    "For node types that could appear in a nested structure, for example [_Loop_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Loop) or [_Conditional_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Conditional), we may be interested in knowing at what depth they appear in a given IR tree. The following illustrates how this can be achieved by building a new `FindNodesDepth` visitor based on `FindNodes`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d758ee9",
   "metadata": {},
   "source": [
    "## Dataclass to store return values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0537525",
   "metadata": {},
   "source": [
    "The default return value for `FindNodes` is a list of nodes. For `FindNodesDepth`, we would also like to return the depth of each node. We can create a new dataclass (essentially a C-style struct) called `DepthNode` to store both these pieces of information:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "547ef8bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import Node\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class DepthNode:\n",
    "    \"\"\"Store node object and depth in c-style struct.\"\"\"\n",
    "    \n",
    "    node: Node\n",
    "    depth: int"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d63f588",
   "metadata": {},
   "source": [
    "## Modifying the initialization method"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee2518e0",
   "metadata": {},
   "source": [
    "`FindNodes` has two operating modes. The first (and default) mode looks through a given IR tree and returns a list of matching instances of a specified node type. The second, which is enabled by passing `mode='scope'` when creating the visitor, returns the [_InternalNode_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.InternalNode), i.e. the [_Scope_](https://sites.ecmwf.int/docs/loki/main/loki.scope.html#loki.scope.Scope), in which a specified node appears.\n",
    "\n",
    "For our new visitor, we are only interested in the default operating mode of `FindNodes`. Therefore, let us define a new initialization method for our `FindNodesDepth` class that always selects `mode='type'`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "37350bd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import FindNodes\n",
    "\n",
    "class FindNodesDepth(FindNodes):\n",
    "    \"\"\"Visitor that computes node-depth relative to subroutine body. Returns list of DepthNode objects.\"\"\"\n",
    "    \n",
    "    def __init__(self, match, greedy=False):\n",
    "        super().__init__(match, mode='type', greedy=greedy)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "613e4b3a",
   "metadata": {},
   "source": [
    "## Modifying the `visit_Node` method"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77845b25",
   "metadata": {},
   "source": [
    "In order to achieve the desired functionality of our new visitor, we will need a new `visit_Node` method. We start from a copy of `FindNodes.visit_Node` and make only a few changes to it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "67983caa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visit_Node(self, o, **kwargs):\n",
    "    \"\"\"\n",
    "    Add the node to the returned list if it matches the criteria and increment depth\n",
    "    before visiting all children.\n",
    "    \"\"\"\n",
    "\n",
    "    ret = kwargs.pop('ret', self.default_retval())\n",
    "    depth = kwargs.pop('depth', 0)\n",
    "    if self.rule(self.match, o): \n",
    "        ret.append(DepthNode(o, depth))\n",
    "        if self.greedy:\n",
    "            return ret \n",
    "    for i in o.children:\n",
    "        ret = self.visit(i, depth=depth+1, ret=ret, **kwargs)\n",
    "    return ret or self.default_retval()\n",
    "\n",
    "FindNodesDepth.visit_Node = visit_Node"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5d6ba1d",
   "metadata": {},
   "source": [
    "The first change to `visit_Node` is the addition of a line that sets `depth`. If `visit_Node` is called on the root of the IR tree, no `depth` keyword argument is present and `depth` is initialized to 0. If, on the other hand, `visit_Node` is called recursively, the current `depth` of node `o` is retrieved from the keyword arguments. The second and final change is the addition of a `depth` keyword argument to the recursive call to `visit` for the children of node `o`. As each recursive call moves down one level in the IR tree, `depth+1` is passed as the argument.\n",
    "\n",
    "Having now fully defined our new visitor, we can test it on the following routine containing nested loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cf2196d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "DO k=1,n\n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = var_in(i, j, k)\n",
      "    END DO\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "  CALL some_kernel(n, var_out(1, 1, k))\n",
      "  \n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "    END DO\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END DO\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from loki import Sourcefile\n",
    "from loki import fgen\n",
    "\n",
    "source = Sourcefile.from_file('src/loop_fuse.F90')\n",
    "routine = source['loop_fuse_v1']\n",
    "print(fgen(routine.body))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02b6e47c",
   "metadata": {},
   "source": [
    "`loop_fuse_v1` contains a total of 7 loops, with a maximum nesting depth of 3. Let us see if our new visitor can identify the loops and their depths correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "df95eda3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 Loop:: k=1:n 1\n",
      "1 Loop:: j=1:n 2\n",
      "2 Loop:: i=1:n 3\n",
      "3 Loop:: i=1:n 3\n",
      "4 Loop:: j=1:n 2\n",
      "5 Loop:: i=1:n 3\n",
      "6 Loop:: i=1:n 3\n"
     ]
    }
   ],
   "source": [
    "from loki import Loop\n",
    "\n",
    "loops = FindNodesDepth(Loop).visit(routine.body)\n",
    "\n",
    "for k, loop in enumerate(loops):\n",
    "    print(k, loop.node, loop.depth)\n",
    "    \n",
    "depths = [1, 2, 3, 3, 2, 3, 3]\n",
    "assert(depths == [loop.depth for loop in loops])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74ce856d",
   "metadata": {},
   "source": [
    "As the output shows, the depth of all 7 loops was identified correctly. Note that the subroutine body itself is assigned a depth of 0, and because the outermost `k`-loop is a child of the subroutine body, it has a depth of 1.\n",
    "\n",
    "We can also use our new visitor to find the depth of the [_Assignment_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Assignment) statements within the bodies of the loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2aa221a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 Assignment:: var_out(i, j, k) = var_in(i, j, k)             4\n",
      "1 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)    4\n",
      "2 Assignment:: var_out(i, j, k) = var_out(i, j, k) + 1._JPRB  4\n",
      "3 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)    4\n"
     ]
    }
   ],
   "source": [
    "from loki import Assignment\n",
    "\n",
    "assigns = FindNodesDepth(Assignment).visit(routine.body)\n",
    "\n",
    "for k, assign in enumerate(assigns):\n",
    "    print(f'{k} {str(assign.node):<60}{assign.depth}')\n",
    "    \n",
    "depths = [4, 4, 4, 4]\n",
    "assert(depths == [assign.depth for assign in assigns])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "067e5919",
   "metadata": {},
   "source": [
    "All the `Assignment` statements and their respective depths are identified correctly. We can do a similar test on nested `if` statements:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "56f6f076",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 i 1\n",
      "1 j 2\n",
      "2 k 3\n",
      "3 h 3\n"
     ]
    }
   ],
   "source": [
    "from loki import Subroutine\n",
    "from loki import Conditional\n",
    "\n",
    "fcode = \"\"\" \n",
    "subroutine nested_conditionals(i,j,k,h)\n",
    "    \n",
    "    logical,intent(in) :: i,j,k,h\n",
    "\n",
    "    if(i)then\n",
    "      if(j)then\n",
    "\n",
    "        if(k)then\n",
    "          ! do something\n",
    "        else\n",
    "          ! do something else\n",
    "        endif\n",
    "        \n",
    "        if(h)then\n",
    "          ! also test h\n",
    "        endif\n",
    "\n",
    "      endif\n",
    "    endif\n",
    "\n",
    "end subroutine nested_conditionals\n",
    "\"\"\"\n",
    "\n",
    "routine = Subroutine.from_source(fcode)\n",
    "\n",
    "conds = FindNodesDepth(Conditional).visit(routine.body)\n",
    "for k, cond in enumerate(conds):\n",
    "    print(k, cond.node.condition, cond.depth)\n",
    "    \n",
    "depths = [1, 2, 3, 3]\n",
    "assert(depths == [cond.depth for cond in conds])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.8 ('loki_env': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "5b6429b76fde06fc4400bf3c27b3ae893ffb7a047f8b8ee9418a3bc77878d107"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
loki-ecmwf-0.3.6/INSTALL.md0000664000175000017500000003113315167130205015372 0ustar  alastairalastair# Installation

There are multiple different ways of installing Loki, tailored towards various
use-cases:

- via `pip install` as a pure Python package
- via the provided install script to ease the setup of optional dependencies
- via CMake/ecbuild to enable installation as part of a CMake project
- via ecbundle
- manually

Loki is a pure Python package that depends on a range of upstream packages. We
recommend using a
[virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/#creating-a-virtual-environment)
to avoid conflicts with versions of system-wide installed packages. The
CMake/ecbuild installation method as well as the provided script `install` will
do this automatically.

## Requirements

- Python 3.9+ with virtualenv and pip
- For graphical output of Scheduler dependency graphs: graphviz

### Optional requirements

The following is required to use the OMNI frontend:

- JDK 1.8+ (can be installed using the install script or ecbuild)
- libxml2 (with headers)

## Installation without prior download

The easiest way to obtain a usable installation of Loki with the fparser frontend does not
require downloading the source code. Simply run the following commands:

```bash
python3 -m venv loki_env  # Create a virtual environment
source loki_env/bin/activate  # Activate the virtual environment

# Installation of the Loki core library
pip install "loki @ git+https://github.com/ecmwf-ifs/loki.git"

# Optional: Installation of IFS lint rules for the use as a linter
pip install "lint_rules @ git+https://github.com/ecmwf-ifs/loki.git#subdirectory=lint_rules"
```

This makes the Python package available and installs the scripts `loki-transform.py` and `loki-lint.py` on the PATH.


## Installation from source

After downloading the source code, e.g., via

```bash
git clone https://github.com/ecmwf-ifs/loki.git
```

enter the created source directory and choose one of the following installation methods.

### Installation with pip

This yields an installation that uses the fparser frontend and is suitable for
development of transformations and working with the example notebooks:

```bash
python3 -m venv loki_env  # Create a virtual environment
source loki_env/bin/activate  # Activate the virtual environment

# Installation of the Loki core library
# Optional:
#   * Add `-e` to obtain an editable install that allows modifying the
#     source files without having to re-install the package
#   * Enable the following options by providing them as a comma-separated
#     list in square brackets behind the `.`:
#     * tests    - allows running the Loki test suite
#     * examples - installs dependencies to run the example notebooks
#     * docs     - installs dependencies required to generate the Sphinx documentation
#     * dace     - installs DaCe
pip install .

# Optional: Installation of IFS lint rules for the use as a linter
#           (again optionally with `-e` for an editable install)
pip install ./lint_rules
```

### Installation using install script

The provided `install` script can be used to install Loki with selected
dependencies inside a local virtual environment `loki_env`. This is the
recommended way when additional optional dependencies, such as the OMNI
frontend, are required.

After downloading Loki, call the script with `-h` to display usage information:

```text
$ ./install -h
Loki install script. This installs Loki and selected dependencies.

Usage: ./install [-v] [--hpc2020] [--use-venv[=]<path>] [--with-*] [...]

Available options:
  -h / --help                  Display this help message
  -v                           Enable verbose output
  --hpc2020                    Load HPC2020 (Atos) specific modules and settings
  --use-venv[=]<path>          Use existing virtual environment at <path>
  --with[out]-jdk              Install JDK instead of using system version (default: use system version)
  --with[out]-omni             Install OMNI Compiler (default: disabled)
  --with[out]-dace             Install DaCe (default: enabled)
  --with[out]-tests            Install dependencies to run tests (default: enabled)
  --with[out]-docs             Install dependencies to generate documentation (default: disabled)
  --with[out]-examples         Install dependencies to run the example notebooks (default: enabled)
```

On the ECMWF Atos HPC facility, the `--hpc2020` flag is recommended as it loads
required modules.  Omitting all (other) options (i.e., any of the `--with-*`)
will install only the Fparser2 frontend.

After completion, this script writes a `loki-activate` file that can be sourced
to bring up the virtual environment and set paths for the external dependencies.

#### Examples:

The default command on ECMWF's Atos HPC facility for a full stack installation is

```bash
./install --hpc2020 --with-omni
```

On standard Linux hosts with up-to-date JDK and ant, it is as easy as

```bash
./install --with-omni
```

To update the installation (e.g., to add JDK), the existing virtual environment can be provided, e.g.,

```bash
./install --with-omni --with-jdk --use-venv=loki_env
```

### Installation using CMake/ecbuild

Loki and dependencies (excluding OpenFortranParser) can be installed using
[ecbuild](https://github.com/ecmwf/ecbuild) (a set of CMake macros and a wrapper
around CMake). This requires ecbuild 3.7+ and CMake 3.19+.

```bash
cmake -DCMAKE_MODULE_PATH=<path-to-ecbuild>/cmake -S <path-to-loki> -B <build-dir>
cmake --build <build-dir>
```

The following options are available and can be enabled/disabled by providing `-DENABLE_<OPTION>=<ON|OFF>`:

- `NO_INSTALL` (default: `OFF`): Do not install Loki but make the CMake
  functions below available. This is useful if Loki is available on the path from
  elsewhere and only the CMake integration is required
- `EDITABLE` (default: `OFF`): Install Loki in editable mode, i.e. without
  copying any files
- `OMNI` (default: `OFF`): Install the OMNI compiler as well as its
  Java dependencies as required. Note that this is an experimental setup and comes
  with no support or guarantees.

This method is also suitable to create a system-wide installation of Loki.
After running the above steps, install Loki to a chosen prefix using

```bash
cmake --install <build-dir> --prefix <install-prefix>
```

*Note: Using this to install Loki system-wide does not currently install the OMNI frontend with it, even if the relevant ecbuild option is activated. It is recommended to install OMNI separately, if required.*

The ecbuild installation method creates a virtual environment in the build
directory and downloads OpenJDK and Ant on-demand if no up-to-date versions have
been found. This installation method is particularly handy when used as a
subproject of a larger CMake build.

When used this way, it exports a number of CMake functions that can then be used
elsewhere:

- `loki_transform`: A wrapper for calls to `loki-transform.py` that takes care
  of automatically setting path and environment.
- `loki_transform_plan`: A wrapper for calls to `loki-transform.py` in `plan`
  mode to generate CMake plan files.
- `loki_transform_target`: A wrapper that takes care of calling the plan mode
  during configuration and applying bulk transformations at build time to a CMake
  target. This includes updates to the target's source file list as determined
  during the planning stage.
- `generate_xmod`: A wrapper for calls to OMNI's `F_Front` frontend to generate
  xmod dependency files.

This makes it possible to apply transformations as part of the build process
without having to take care of PATH handling on the user side. See the
[CLOUDSC dwarf](https://github.com/ecmwf-ifs/dwarf-p-cloudsc) for an example of
how this can be used.

### Offline installation using CMake/ecbuild

When the CMake/ecbuild installation procedure is required on a system without
internet access, then the required Python wheels can be downloaded and transferred
to the target system. To do so, run the [`populate`](populate) script from the Loki
main directory.

This will download all required Python wheels into a directory `artifacts`.
Transfer this wheelhouse directory to the target system and provide
`-DARTIFACTS_DIR=<path-to-artifacts>` to the CMake command when installing Loki.

The behaviour of this script can be customized using the following environment variables:

- `ARTIFACTS_DIR`: Choose a different target directory (default: `artifacts` in the current
  working directory)
- `LOKI_INSTALL_OPTIONS`: Add additional PIP install options to ensure dependencies for this
  are included in the wheelhouse. Most commonly required is `[tests]`.
- `LOKI_WHEEL_PYTHON_VERSION`: When using a different Python version to download the wheels
  than on the target system, specify the version here (e.g., `LOKI_WHEEL_PYTHON_VERSION=312`
  to request wheels for Python 3.12). See the
  [PIP documentation](https://pip.pypa.io/en/stable/cli/pip_download/#cmdoption-python-version)
  for more details.
- `LOKI_WHEEL_ARCH`: When the system that downloads the wheels uses a different architecture
  than the target system (e.g., an ARM-based MacBook is used to download wheels for a Linux
  x86_64 system), specify the target architecture here (e.g.,
  `LOKI_WHEEL_ARCH=manylinux_2_17_x86_64`). A list of typical platform tags is available
  [here](https://packaging.python.org/en/latest/specifications/platform-compatibility-tags/#platform-tag).


## Installation on MacOS

Although tailored to the Linux environment commonly found on HPC systems, Loki
can also be installed on MacOS.

This requires installing some additional dependencies using
[Brew](https://brew.sh) to allow running the Loki test suite:


```bash
# Install dependencies with brew
brew install gcc@13 graphviz python@3.11

# Install Loki using the install script
# NB: we explicitly select Python 3.11 (in case a newer version is the default)
#     by adding it in first place to the search path
PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH" \
  CC=gcc-13 CXX=g++-13 FC=gfortran-13 \
  ./install --with-examples --with-tests --with-dace

# Amend the Loki environment with correct compiler variables
echo "export PATH=$(brew --prefix)/opt/python@3.11/libexec/bin:$(brew --prefix)/bin:${PATH}" | cat - loki-activate > loki-activate.tmp
mv loki-activate.tmp loki-activate
echo "export CC=gcc-13" >> loki-activate
echo "export CXX=g++-13" >> loki-activate
echo "export FC=gfortran-13" >> loki-activate
echo "export F90=gfortran-13" >> loki-activate
echo "export LD=gfortran-13" >> loki-activate

# Activate the virtual environment to run the tests
source loki-activate
pytest --pyargs loki
```

## Installation as part of an ecbundle bundle

Loki being installable by CMake/ecbuild makes it easy to integrate with
[ecbundle](https://github.com/ecmwf/ecbundle). Simply add the following to your
`bundle.yml`:

```yaml
projects :

  # ...other projects ...

  - loki :
    git     : https://github.com/ecmwf-ifs/loki
    version : main

```

See the [CLOUDSC dwarf](https://github.com/ecmwf-ifs/dwarf-p-cloudsc) for an
example of how this can be used.

## Manual installation

The following outlines the manual steps for installing Loki using a virtual
environment. This installation method is not recommended but may be used when
maximum control over all steps is required or none of the above methods work.
You can create an empty directory and copy-paste the following steps to obtain a
working version:

### 1. Clone the Loki repository

```bash
git clone https://github.com/ecmwf-ifs/loki
```

### 2. Create and activate virtual environment

```bash
python3 -m venv loki_env
source loki_env/bin/activate
pip install --upgrade pip
```

Note that we need to make sure that we use a recent pip version (21.3 or newer)
that has support for editable installs using `pyproject.toml`.

### 3.  Install Loki and Python dependencies

```bash
pushd loki
pip install -e .[tests,examples]
pip install -e ./lint_rules
popd
```

### 4.  Install OMNI frontend -- optional

#### Option a: install latest xcodeml-tools

```bash
git clone --recursive --single-branch https://github.com/omni-compiler/xcodeml-tools.git xcodeml-tools
pushd xcodeml-tools
# Now build and install OMNI in the venv:
cmake -S . -B build -DCMAKE_INSTALL_PREFIX=../loki_env
cmake --build build
cmake --install build
popd
```

#### Option b: install (older) OMNI version and CLAW

```bash
git clone --recursive --single-branch --branch=mlange-dev https://github.com/mlange05/claw-compiler.git claw-compiler
pushd claw-compiler
# Now build and install CLAW in the venv:
cmake -S . -B build -DCMAKE_INSTALL_PREFIX=../loki_env
cmake --build build
cmake --install build
popd
```

### 5.  Verify everything is working

```bash
pushd loki
py.test transformations lint_rules .
popd
```

Note that the order is important to avoid clashes with conftest utilities.
loki-ecmwf-0.3.6/AGENTS.md0000664000175000017500000000245615167130205015253 0ustar  alastairalastair# AGENTS

## Purpose

This repository contains Loki source code and tests. When editing or adding tests, prefer assertions that match Loki's native IR and expression semantics rather than assertions that depend on rendered source formatting.

## Loki Test Assertions

- In structural IR tests, prefer native Loki node comparisons over `str(...)`.
- Loki symbols and expressions compare directly to strings, so prefer:
  - `node == 'a + b'` over `str(node) == 'a + b'`
  - `loop.variable == 'i'` over `str(loop.variable) == 'i'`
- Loki numeric literals compare directly to Python numbers, so prefer:
  - `loop.bounds.start == 1` over `str(loop.bounds.start) == '1'`
  - `literal == 5` over `str(literal) == '5'`
- When creating local test helpers, return native nodes whenever possible.
  - Good: `(assign.lhs, assign.rhs)`
  - Good: `(loop.variable, loop.bounds.start, loop.bounds.stop, loop.bounds.step)`
  - Avoid: `(str(assign.lhs), str(assign.rhs))`
  - Avoid: `(str(loop.variable), str(loop.bounds.start), str(loop.bounds.stop), ... )`
- Use stringification only when the test is explicitly about rendered output, pretty-printing, or a node type that does not compare reliably through Loki's native equality support.
- If stringification is still necessary in a structural test, keep it narrowly scoped and document why.
loki-ecmwf-0.3.6/.pylintrc0000664000175000017500000003745015167130205015617 0ustar  alastairalastair[MASTER]

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=

# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS

# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
# ignore-patterns=

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=

# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1

# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100

# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=

# Pickle collected data for later comparisons.
persistent=yes

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=W0511,  # Disable TODO warnings
        unspecified-encoding,  # It's fairly safe to assume we'll only support Linux
        duplicate-code,  # That's something that happens with tests and visitors...
        invalid-name,
        missing-module-docstring,
        missing-class-docstring,
        missing-function-docstring,
        raw-checker-failed,
        bad-inline-option,
        locally-disabled,
        file-ignored,
        suppressed-message,
        useless-suppression,
        deprecated-pragma,
        use-symbolic-message-instead,
        too-many-instance-attributes,
        too-few-public-methods,
        too-many-public-methods,
        too-many-return-statements,
        too-many-branches,
        too-many-arguments,
        too-many-positional-arguments,
        too-many-locals,
        too-many-statements,
        too-many-nested-blocks,
        pointless-string-statement,
        attribute-defined-outside-init,
        protected-access,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member


[REPORTS]

# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'error', 'warning', 'refactor', and 'convention'
# which contain the number of messages in each category, as well as 'statement'
# which is the total number of statements analyzed. This score is used by the
# global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=

# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text

# Tells whether to display a full report or only the messages.
reports=no

# Activate the evaluation score.
score=yes


[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=5

# Complete name of functions that never return. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes

# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes

# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1

# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1

# List of decorators that change the signature of a decorated function.
signature-mutators=


[SPELLING]

# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4

# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
      XXX,
      TODO


[SIMILARITIES]

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no

# Minimum lines number of a similarity.
min-similarity-lines=4


[LOGGING]

# Format style used to check logging format string. `old` means using %
# formatting, `new` is for `{}` formatting, and `fstr` is for f-strings.
logging-format-style=old

# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging


[BASIC]

# Naming style matching correct argument names.
argument-naming-style=snake_case

# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=

# Naming style matching correct attribute names.
attr-naming-style=snake_case

# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=

# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
          bar,
          baz,
          toto,
          tutu,
          tata

# Naming style matching correct class attribute names.
class-attribute-naming-style=any

# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=

# Naming style matching correct class names.
class-naming-style=PascalCase

# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=

# Naming style matching correct constant names.
const-naming-style=UPPER_CASE

# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1

# Naming style matching correct function names.
function-naming-style=snake_case

# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=

# Good variable names which should always be accepted, separated by a comma.
good-names=i,
           j,
           k,
           ex,
           Run,
           _

# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no

# Naming style matching correct inline iteration names.
inlinevar-naming-style=any

# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=

# Naming style matching correct method names.
method-naming-style=snake_case

# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=

# Naming style matching correct module names.
module-naming-style=snake_case

# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty

# Naming style matching correct variable names.
variable-naming-style=snake_case

# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=


[FORMAT]

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '

# Maximum number of characters on a single line.
max-line-length=120

# Maximum number of lines in a module.
max-module-lines=1500

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no


[VARIABLES]

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=

# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
          _cb

# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_

# Tells whether we should check for unused import in __init__ files.
init-import=no

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io


[STRING]

# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no


[IMPORTS]

# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=

# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no

# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant

# Couples of modules and preferred modules, separated by a comma.
preferred-modules=


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp,
                      __post_init__

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls


[DESIGN]

# Maximum number of arguments for function / method.
max-args=5

# Maximum number of attributes for a class (see R0902).
max-attributes=7

# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5

# Maximum number of branch for function / method body.
max-branches=12

# Maximum number of locals for function / method body.
max-locals=15

# Maximum number of parents for a class (see R0901).
max-parents=7

# Maximum number of public methods for a class (see R0904).
max-public-methods=20

# Maximum number of return / yield for function / method body.
max-returns=6

# Maximum number of statements in function / method body.
max-statements=50

# Minimum number of public methods for a class (see R0903).
min-public-methods=2


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "builtins.BaseException, builtins.Exception".
overgeneral-exceptions=builtins.BaseException,
                       builtins.Exception
loki-ecmwf-0.3.6/.pylintrc_ipynb0000664000175000017500000004014515167130205017013 0ustar  alastairalastair[MASTER]

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=

# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS

# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
# ignore-patterns=

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=

# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1

# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100

# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=

# Pickle collected data for later comparisons.
persistent=yes

# Specify a configuration file.
#rcfile=

# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=W0511,  # Disable TODO warnings
        unspecified-encoding,  # It's fairly safe to assume we'll only support Linux
        duplicate-code,  # That's something that happens with tests and visitors...
        invalid-name,
        missing-module-docstring,
        missing-class-docstring,
        missing-function-docstring,
        raw-checker-failed,
        bad-inline-option,
        locally-disabled,
        file-ignored,
        suppressed-message,
        useless-suppression,
        deprecated-pragma,
        use-symbolic-message-instead,
        too-many-instance-attributes,
        too-few-public-methods,
        too-many-public-methods,
        too-many-return-statements,
        too-many-branches,
        too-many-arguments,
        too-many-locals,
        too-many-statements,
        too-many-nested-blocks,
        pointless-string-statement,
        attribute-defined-outside-init,
        protected-access,
        trailing-whitespace,
        trailing-newlines,
        line-too-long,
        pointless-statement,
        wrong-import-position,
        wrong-import-order

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member


[REPORTS]

# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'error', 'warning', 'refactor', and 'convention'
# which contain the number of messages in each category, as well as 'statement'
# which is the total number of statements analyzed. This score is used by the
# global evaluation report (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=

# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text

# Tells whether to display a full report or only the messages.
reports=no

# Activate the evaluation score.
score=yes


[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=5

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes

# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes

# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1

# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1

# List of decorators that change the signature of a decorated function.
signature-mutators=


[SPELLING]

# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4

# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
      XXX,
      TODO


[SIMILARITIES]

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no

# Minimum lines number of a similarity.
min-similarity-lines=4


[LOGGING]

# Format style used to check logging format string. `old` means using %
# formatting, `new` is for `{}` formatting, and `fstr` is for f-strings.
logging-format-style=old

# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging


[BASIC]

# Naming style matching correct argument names.
argument-naming-style=snake_case

# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=

# Naming style matching correct attribute names.
attr-naming-style=snake_case

# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=

# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
          bar,
          baz,
          toto,
          tutu,
          tata

# Naming style matching correct class attribute names.
class-attribute-naming-style=any

# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=

# Naming style matching correct class names.
class-naming-style=PascalCase

# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=

# Naming style matching correct constant names.
const-naming-style=UPPER_CASE

# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1

# Naming style matching correct function names.
function-naming-style=snake_case

# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=

# Good variable names which should always be accepted, separated by a comma.
good-names=i,
           j,
           k,
           ex,
           Run,
           _

# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no

# Naming style matching correct inline iteration names.
inlinevar-naming-style=any

# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=

# Naming style matching correct method names.
method-naming-style=snake_case

# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=

# Naming style matching correct module names.
module-naming-style=snake_case

# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty

# Naming style matching correct variable names.
variable-naming-style=snake_case

# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=


[FORMAT]

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '

# Maximum number of characters on a single line.
max-line-length=120

# Maximum number of lines in a module.
max-module-lines=1500

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no


[VARIABLES]

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=

# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
          _cb

# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_

# Tells whether we should check for unused import in __init__ files.
init-import=no

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io


[STRING]

# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no


[IMPORTS]

# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=

# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no

# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant

# Couples of modules and preferred modules, separated by a comma.
preferred-modules=


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp,
                      __post_init__

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls


[DESIGN]

# Maximum number of arguments for function / method.
max-args=5

# Maximum number of attributes for a class (see R0902).
max-attributes=7

# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5

# Maximum number of branch for function / method body.
max-branches=12

# Maximum number of locals for function / method body.
max-locals=15

# Maximum number of parents for a class (see R0901).
max-parents=7

# Maximum number of public methods for a class (see R0904).
max-public-methods=20

# Maximum number of return / yield for function / method body.
max-returns=6

# Maximum number of statements in function / method body.
max-statements=50

# Minimum number of public methods for a class (see R0903).
min-public-methods=2


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "builtins.BaseException, builtins.Exception".
overgeneral-exceptions=builtins.BaseException,
                       builtins.Exception
loki-ecmwf-0.3.6/populate0000775000175000017500000000217015167130205015520 0ustar  alastairalastair#!/usr/bin/env bash

# (C) Copyright 2024- ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# Abort on the first failing command or unset variable, so that a failed
# download of the Loki wheels is not masked by the exit status of the
# subsequent lint_rules download.
set -eu

# Directory containing this script, always with a trailing slash, so the
# script works regardless of the current working directory.
if [[ $BASH_SOURCE = */* ]]; then
    SOURCE_DIR=${BASH_SOURCE%/*}/
else
    SOURCE_DIR=./
fi

# Destination for the downloaded wheels; may be overridden from the environment.
ARTIFACTS_DIR=${ARTIFACTS_DIR:-"${SOURCE_DIR}artifacts"}

# Download the wheels required by the pip requirement spec given as $1
# (a local package path, optionally followed by extras such as "[tests]")
# into ARTIFACTS_DIR via the loki_get_python_wheels CMake script.
# All expansions are quoted: paths may contain spaces, and extras in square
# brackets must not be subject to pathname expansion.
download_wheels() {
    cmake \
        -DWHEELS_DIR="${ARTIFACTS_DIR}" \
        -DREQUIREMENT_SPEC="$1" \
        -DLOKI_WHEEL_ARCH="${LOKI_WHEEL_ARCH:-None}" \
        -DLOKI_WHEEL_PYTHON_VERSION="${LOKI_WHEEL_PYTHON_VERSION:-None}" \
        -P "${SOURCE_DIR}cmake/loki_get_python_wheels.cmake"
}

# Download dependencies for Python packages in this repository
download_wheels "${SOURCE_DIR}${LOKI_INSTALL_OPTIONS:-}"
download_wheels "${SOURCE_DIR}lint_rules"
loki-ecmwf-0.3.6/docs/0000775000175000017500000000000015167130205014671 5ustar  alastairalastairloki-ecmwf-0.3.6/docs/sites-manager.py0000775000175000017500000001204515167130205020007 0ustar  alastairalastair#!/usr/bin/env python3

# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import os
import click
from sites.sdk import SitesClient
from sites.sdk.sites import Site, Authenticator
from sites.sdk.sites.site_content import ApiCallException

class NotRequiredIf(click.Option):
    """
    Custom option class that makes an option not required if some condition
    is fulfilled.

    The option is made mutually exclusive with the parameter named by the
    ``not_required_if`` keyword: passing both raises a
    :class:`click.UsageError`, and any interactive prompt for this option is
    disabled when the other parameter is given.

    Source: https://stackoverflow.com/a/44349292
    """
    def __init__(self, *args, **kwargs):
        # Name of the other parameter (e.g. 'token') that makes this one optional.
        # Popped before calling click.Option, which does not know this keyword.
        self.not_required_if = kwargs.pop('not_required_if', None)
        if not self.not_required_if:
            # Explicit check instead of ``assert``: it must also fire under ``python -O``
            # and when the keyword is missing altogether.
            raise TypeError("'not_required_if' parameter required")
        # ``help`` may be absent or explicitly None; treat both as empty text.
        kwargs['help'] = (
            (kwargs.get('help') or '') +
            f' NOTE: This argument is mutually exclusive with {self.not_required_if}'
        ).strip()
        super().__init__(*args, **kwargs)

    def handle_parse_result(self, ctx, opts, args):
        """
        Enforce mutual exclusivity before click processes this option's value.

        ``opts`` holds the values parsed from the command line, keyed by
        parameter name.

        Raises
        ------
        click.UsageError
            If both this option and ``not_required_if`` were given.
        """
        we_are_present = self.name in opts
        other_present = self.not_required_if in opts

        if other_present:
            if we_are_present:
                raise click.UsageError(
                    f"Illegal usage: `{self.name}` is mutually exclusive with "
                    f"`{self.not_required_if}`"
                )
            # The other parameter was given, so never prompt for this one
            self.prompt = None

        return super().handle_parse_result(ctx, opts, args)


def get_file_manager(ctx):
    """Connect to the Sites hub and instantiate a file manager for chosen space and site.

    Authentication uses the API token if one was given. Otherwise it uses
    username, password and TOTP, and prompts interactively for any of them
    that are missing. Prompted values are stored back into ``ctx.obj``.

    Parameters
    ----------
    ctx : click.Context
        Context whose ``obj`` dict was populated by :func:`cli`.

    Returns
    -------
    The content manager returned by ``SitesClient.content`` for the selected site.

    Raises
    ------
    click.UsageError
        If no space or no site name was provided.
    """
    if not ctx.obj['space']:
        raise click.UsageError('No space given. Provide "--space" before the command.')
    if not ctx.obj['name']:
        raise click.UsageError('No site name given. Provide "--name" before the command.')

    # Authenticate against Site
    if ctx.obj['token']:
        sites_authenticator = Authenticator.from_token(token=ctx.obj['token'])
    else:
        # The username defaults to $USER, which may be unset (the option then
        # falls back to ''); ask rather than authenticate with an empty name.
        if not ctx.obj['username']:
            ctx.obj['username'] = click.prompt('Username')
        if not ctx.obj['password']:
            ctx.obj['password'] = click.prompt('Password', hide_input=True)
        if not ctx.obj['totp']:
            ctx.obj['totp'] = click.prompt('TOTP', hide_input=False)
        sites_authenticator = Authenticator.from_credentials(
            username=ctx.obj['username'],
            password=ctx.obj['password'],
            otp=ctx.obj['totp']
        )

    # Connect to the selected site...
    site = Site.from_space_and_name(
        space=ctx.obj['space'], name=ctx.obj['name']
    )
    client = SitesClient(authenticator=sites_authenticator)

    # ...and create a file manager
    site_content_manager = client.content(site=site)

    return site_content_manager


@click.group()
@click.option('--space', help='Name of the space in the Sites hub')
@click.option('--name', help='Name of the site in the space')
@click.option('--token', help='API authentication token')
@click.option('--username', help='Username for authentication',
              default=lambda: os.environ.get('USER', ''), show_default='current user',
              cls=NotRequiredIf, not_required_if='token')
@click.option('--password', help='Password for authentication')
@click.option('--totp', help='One time password for authentication')
@click.pass_context
def cli(ctx, space, name, token, username, password, totp):
    """Interact with the Sites hub. """
    # Create ctx.obj if the caller did not supply one (e.g. ``cli()`` invoked
    # without ``obj={}``); otherwise the item assignments would fail on None.
    ctx.ensure_object(dict)
    # Store the global options for the subcommands, which read them via get_file_manager.
    ctx.obj.update(
        space=space,
        name=name,
        token=token,
        username=username,
        password=password,
        totp=totp,
    )


@cli.command(short_help='Upload a local path')
@click.argument('local_path')
@click.argument('remote-path', required=False, default='')
@click.option('--clean/--no-clean', help='Clean upload path first', default=False)
@click.pass_context
def upload(ctx, local_path, remote_path, clean):
    """Upload LOCAL_PATH to REMOTE_PATH.

    LOCAL_PATH is a local file or directory that is uploaded into REMOTE_PATH
    in the sites hub. If no REMOTE_PATH is given, LOCAL_PATH is uploaded into
    the root folder of the site.
    """
    manager = get_file_manager(ctx)
    # Both the optional clean-up and the upload act recursively on the same remote target
    remote_target = {'remote_path': remote_path, 'recursive': True}
    if clean:
        try:
            manager.delete(**remote_target)
        except ApiCallException:
            # Best-effort clean-up: API errors are ignored (presumably the remote
            # path may not exist yet) and the upload proceeds regardless.
            pass
    manager.upload(local_path=local_path, **remote_target)


@cli.command(short_help='Download a remote path')
@click.argument('remote-path')
@click.pass_context
def download(ctx, remote_path):
    """Download REMOTE_PATH."""
    # Connect to the configured site and fetch REMOTE_PATH in one go
    get_file_manager(ctx).download(remote_path=remote_path)


@cli.command(short_help='Delete a remote path')
@click.argument('remote-path')
@click.pass_context
def delete(ctx, remote_path):
    """Delete REMOTE_PATH."""
    # Connect to the configured site and remove REMOTE_PATH, including any contents
    get_file_manager(ctx).delete(remote_path=remote_path, recursive=True)


if __name__ == "__main__":
    # Seed the click context object with an empty dict; ``cli`` fills it with the
    # global options that the subcommands read through ``get_file_manager``.
    # click injects the parameters at runtime, hence the pylint suppression.
    cli(obj={})  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
loki-ecmwf-0.3.6/docs/source/0000775000175000017500000000000015167130205016171 5ustar  alastairalastairloki-ecmwf-0.3.6/docs/source/backends.rst0000664000175000017500000000367715167130205020512 0ustar  alastairalastair======================
Generating source code
======================

.. important::
    Loki is still under active development and has not yet seen a stable
    release. Interfaces can change at any time, objects may be renamed, or
    concepts may be re-thought. Make sure to sync your work to the current
    release frequently by rebasing feature branches and upstreaming
    more generally applicable work in the form of pull requests.


At the end of a source-to-source translation process the output source code
needs to be generated. Loki provides a number of different backends depending
on the target language, which, once again, are :doc:`Visitors <visitors>`.

All backends are subclasses of :class:`Stringifier` that convert the internal
representation to a string representation in the syntax of the target language.

.. autosummary::

   loki.ir.pprint.Stringifier
   loki.ir.pprint.pprint

Typically, this also includes a custom mapper for expression trees as a
subclass of :any:`LokiStringifyMapper`. For convenience, each of these
visitors is wrapped in a corresponding utility routine that can generate
code for any IR object via a simple function call, for example:

.. code-block:: python

   routine = Subroutine(...)
   ...
   fcode = fgen(routine)

Currently, Loki has backends to generate Fortran, CUDA Fortran, C, Python,
DaCe, and Maxeler MaxJ (a Java dialect that targets FPGAs).

.. autosummary::

   loki.backend.fgen.fgen
   loki.backend.cufgen.cufgen
   loki.backend.cgen.cgen
   loki.backend.pygen.pygen
   loki.backend.dacegen.dacegen
   loki.backend.maxgen.maxjgen

.. warning::
   Backends do not make sure that the internal representation is
   compatible with the target language. Adapting the IR to the desired output
   format needs to be done before calling the relevant code generation routine.
   For language transpilation (e.g., Fortran to C), corresponding
   :doc:`transformations <transform>` must be applied
   (e.g., :any:`FortranCTransformation`).
loki-ecmwf-0.3.6/docs/source/loki_api.rst0000664000175000017500000000017215167130205020512 0ustar  alastairalastair=============
API reference
=============

.. autosummary::
   :toctree:
   :recursive:

   loki
   scripts
   lint_rules
loki-ecmwf-0.3.6/docs/source/using_loki.rst0000664000175000017500000000046015167130205021066 0ustar  alastairalastair==========
Using Loki
==========

The following pages describe in more detail the fundamental concepts and API
design choices underpinning Loki:

.. toctree::
   :maxdepth: 1

   Internal representation <internal_representation>
   visitors
   frontends
   backends
   transform
   utils
   loki_scripts
loki-ecmwf-0.3.6/docs/source/conf.py0000664000175000017500000001122115167130205017465 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# pylint: disable=invalid-name,redefined-builtin

# -- Path setup --------------------------------------------------------------

# No sys.path manipulation is done here: the package version below is read
# from the installed distribution metadata, so Loki must be installed in the
# environment used to build the documentation.
#
from importlib import metadata


# -- Project information -----------------------------------------------------

project = 'Loki'
copyright = '2018- European Centre for Medium-Range Weather Forecasts (ECMWF)'
author = 'Michael Lange, Balthasar Reuter'

# The full version, including alpha/beta/rc tags.
release = metadata.version('loki')
# The short X.Y version.
version = release


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',  # create documentation from docstrings
    'sphinx.ext.napoleon',  # understand docstrings also in other formats
    'sphinx.ext.autosummary',  # automatically compile lists of classes/functions
    'sphinx.ext.intersphinx',  # link to docs of other projects
    'sphinx.ext.autosectionlabel',  # allows to refer to sections using their title
#    'recommonmark',  # read markdown
    'sphinx_rtd_theme',  # read the docs theme
    'myst_parser',  # parse markdown files
    'nbsphinx',  # parse Jupyter notebooks
    'sphinx_design',  # cards, panels and dropdown content
]

autosummary_generate = True  # Turn on sphinx.ext.autosummary
html_show_sourcelink = False  # Remove 'view source code' from top of page (for html, not python)
add_module_names = False # Remove namespaces from class/method signatures

intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'pymbolic': ('https://documen.tician.de/pymbolic/', None),
    'fparser': ('https://fparser.readthedocs.io/en/latest/', None)
}

# The file extensions of source files. Sphinx considers the files with
# this suffix as sources. The value can be a dictionary mapping file
# extensions to file types.
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown'
}

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['**/tests/']

# Prefix each section label with the document it is in, followed by a colon
autosectionlabel_prefix_document = True

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []

html_theme_options = {
    'logo_only': False,
    'display_version': True,
    'prev_next_buttons_location': 'bottom',
    'style_external_links': False,
    'vcs_pageview_mode': 'view',
    'style_nav_header_background': '',
    # Toc options
    'collapse_navigation': True,
    'sticky_navigation': True,
    'navigation_depth': 4,
    'includehidden': True,
    'titles_only': False,
}

# -- Options for todo extension ----------------------------------------------

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True


# -- Options for autodoc extension -------------------------------------------

autodoc_default_options = {
    'members': True,  # include members in the documentation
    'member-order': 'bysource',  # members in the order they appear in source
    'show-inheritance': True,  # list base classes
    'undoc-members': True,  # show also undocumented members
}
loki-ecmwf-0.3.6/docs/source/transform.rst0000664000175000017500000003275615167130205020753 0ustar  alastairalastair.. _transformations:

========================
Transformation pipelines
========================

.. important::
    Loki is still under active development and has not yet seen a stable
    release. Interfaces can change at any time, objects may be renamed, or
    concepts may be re-thought. Make sure to sync your work to the current
    release frequently by rebasing feature branches and upstreaming
    more general applicable work in the form of pull requests.

.. contents:: Contents
   :local:


Transformations
===============

Transformations are the building blocks of a transformation pipeline in Loki.
They encode the workflow of converting a :any:`Sourcefile` or an individual
program unit (such as :any:`Module` or :any:`Subroutine`) to the desired
output format.

A transformation can encode a single modification, combine multiple steps,
or call other transformations to create complex changes. If a transformation
depends on another transformation, inheritance can be used to combine them.

Every transformation in a pipeline should implement the interface defined by
:any:`Transformation`. It provides generic entry points for transforming
different objects and thus allows for batch processing. To implement a new
transformation, only one or all of the relevant methods
:any:`Transformation.transform_subroutine`,
:any:`Transformation.transform_module`, or :any:`Transformation.transform_file`
need to be implemented.

*Example*: A transformation that inserts a comment at the beginning of every
module and subroutine:

.. code-block:: python

    class InsertCommentTransformation(Transformation):

        def _insert_comment(self, program_unit):
            program_unit.spec.prepend(Comment(text='! Processed by Loki'))

        def transform_subroutine(self, routine, **kwargs):
            self._insert_comment(routine)

        def transform_module(self, module, **kwargs):
            self._insert_comment(module)

The transformation can be applied by calling :meth:`apply` with
the relevant object.

.. code-block:: python

   source = Sourcefile(...)  # may contain modules and subroutines
   transformation = InsertCommentTransformation()
   for module in source.modules:
       transformation.apply(module)
   for routine in source.all_subroutines:
       transformation.apply(routine)

Note that we have to apply the transformation separately for every
relevant :any:`ProgramUnit`. The transformation can also be modified
such that it is automatically applied to all program units in a file,
despite only implementing logic for transforming modules and subroutines:

.. code-block:: python

    class InsertCommentTransformation(Transformation):

        # When called on a Sourcefile, automatically apply this to all modules
        # in the file
        recurse_to_modules = True

        # When called on a Sourcefile or Module, automatically apply this to all
        # Subroutines in the file or module
        recurse_to_procedures = True

        def _insert_comment(self, program_unit):
            program_unit.spec.prepend(Comment(text='! Processed by Loki'))

        def transform_subroutine(self, routine, **kwargs):
            self._insert_comment(routine)

        def transform_module(self, module, **kwargs):
            self._insert_comment(module)

With these two attributes added, we can now apply the transformation to all
modules and procedures in a single command:

.. code-block:: python

   source = Sourcefile(...)  # may contain modules and subroutines
   transformation = InsertCommentTransformation()
   transformation.apply(source)

Most transformations, however, will only require modifying those parts of a file
that are part of the call tree that is to be transformed to avoid unexpected
side-effects.

Typically, transformations should be implemented by users to encode the
transformation pipeline for their individual use-case. However, Loki comes
with a growing number of built-in transformations that are implemented in
the :mod:`loki.transformations` namespace:

.. autosummary::

   loki.transformations

This includes also a number of tools for common transformation tasks that
are provided as functions that can be readily used when implementing new
transformations.

Batch processing large source trees
===================================

Transformations can be applied over source trees using the :any:`Scheduler`.
It is a work queue manager that automatically discovers source files in a list
of paths and builds a dependency graph from a given starting point.
This dependency graph includes all called procedures and imported modules.

Calling :any:`Scheduler.process` on a source tree and providing it with a
:any:`Transformation` applies this transformation to all files, modules, or
routines that appear in the dependency graph. The exact traversal
behaviour can be parameterized in the implementation of the :any:`Transformation`.
The behaviour modifications include:

* limiting the processing only to specific node types in the dependency graph
* reversing the traversal direction, i.e., called routines or imported
  modules are processed before their caller, such that the starting point/root
  of the dependency is processed last
* traversing the file graph, i.e., processing full source files rather than
  individual routines or modules
* automatic recursion into contained program units, e.g., processing also all
  procedures in a module after the module has been processed

When applying the transformation to an item in the source tree, the scheduler
provides certain information about the item to the transformation:

* the transformation mode (provided in the scheduler's config),
* the item's role (e.g., ``'driver'`` or ``'kernel'``, configurable via the
  scheduler's config), and
* targets (dependencies that are depended on by the currently processed item,
  and are included in the scheduler's tree, i.e., are processed, too).

.. note::
   The scheduler's dependency graph will include all dependency types it discovers.
   This includes not only control-flow dependencies via procedure calls, but also
   dependencies on other modules via the import of global variables, or dependencies
   on derived type definitions.

   However, for backwards-compatibility with the original scheduler implementation,
   only control-flow dependencies are followed and processed by default, and reported
   as ``items`` in :any:`Scheduler.items`. To remove this limitation, which is required
   e.g., for the :any:`GlobalVarOffloadTransformation`, the ``enable_imports`` option
   can be set to ``True``. This can be done in the ``[default]`` block of the config,
   or as a constructor argument in the :any:`Scheduler`.

The Scheduler's dependency graph
--------------------------------

The :any:`Scheduler` builds a dependency graph consisting of :any:`Item`
instances as nodes. Every item corresponds to a specific node in Loki's
internal representation.

The name of an item refers to a symbol using a fully-qualified name in the
format: ``#``. The ```` corresponds to
a Fortran module, in which a subroutine, interface or derived type is
declared. That declaration's name (e.g., the name of the subroutine)
constitutes the ```` part. For subroutines that are not embedded
into a module, the ```` is empty, i.e., the item's name starts with
a dash (``#``).

In most cases these IR nodes are scopes and the entry points for transformations:

* :any:`FileItem` corresponds to :any:`Sourcefile`
* :any:`ModuleItem` corresponds to :any:`Module`
* :any:`ProcedureItem` corresponds to :any:`Subroutine`

The remaining cases are items corresponding to IR nodes that constitute some
form of intermediate dependency, which are required to resolve the indirection
to the scope node:

* :any:`InterfaceItem` corresponding to :any:`Interface`, i.e., providing a
  callable target that resolves to one or multiple procedures that are defined
  in the interface.
* :any:`ProcedureBindingItem` corresponding to the :any:`ProcedureSymbol`
  that is declared in a :any:`Declaration` in a derived type. Similarly to
  interfaces, these resolve to one or multiple procedures that are defined in
  the procedure binding inside the derived type.
* :any:`TypeDefItem` corresponding to :any:`TypeDef`, which does not introduce
  a control flow dependency but is crucial to capture as a dependency to enable
  annotating type information for inter-procedural analysis.

Finally, :any:`ExternalItem` denotes items that the scheduler was unable to discover.
The expected item type of the missing item is stored in :any:`ExternalItem.origin_cls`.
When batch processing a transformation, the external items are ignored, unless the
config option ``strict=True`` is enabled. In that case, an error will be issued when
an external item is encountered that matches the ``item_filter`` that is provided by
the transformation's manifest (in :any:`Transformation.item_filter`).

To facilitate the creation of the dependency tree, every :any:`Item`
provides two key properties:

* :any:`Item.definitions`: A list of all IR nodes that constitute symbols/names
  that are made available by an item. For a :any:`FileItem`, this typically consists
  of all modules and procedures in that sourcefile, and for a :any:`ModuleItem` it
  comprises of procedures, interfaces, global variables and derived type definitions.
* :any:`Item.dependencies`: A list of all IR nodes that introduce a dependency
  on other items, e.g., :any:`CallStatement` or :any:`Import`.

This information is used to populate the scheduler's dependency graph, which is
constructed by the :any:`SGraph` class. Importantly, to improve processing speed
and limit parsing to the minimum of required files, this relies on incremental
parsing using the :any:`REGEX` frontend. Starting with only the top-level program
units in every discovered source file and a specified seed, the dependencies of each
item are used to determine the next set of items, which are generated on-demand
from the enclosing scope via partial re-parses. This may incur incremental parsing
with additional :any:`RegexParserClass` enabled to discover definitions or dependencies
as required. Only once the full dependency graph has been generated, a full parse
of the source files in the graph is performed, providing the complete internal
representation and automatically enriching type information with inter-procedural annotations.

Pruning the dependency graph
----------------------------

If the intention is not to process some items it is recommended to not
leave them dangling as :any:`ExternalItem`. Instead, they should be explicitly
excluded from the dependency graph and the ``strict`` mode enabled.
To exclude specific items, any of the following annotations can be used, resulting in
different behaviour:

* ``disable``: Dependency items matching an entry in this list are treated as if they
  don't exist, and their definitions are not searched for or parsed. This is useful, e.g.,
  to exclude frequently used utility routines or modules (such as the
  `yomhook module in IFS `_),
  which are not to be transformed.
* ``block``: Dependency items matching an entry in this list are not parsed or added to
  the dependency graph, and therefore excluded from transformations. They are, however,
  included for reference in the dependency graph visualization produced by
  :any:`Scheduler.callgraph`.
* ``ignore``: Dependency items matching an entry in this list are parsed and added to the
  dependency graph. This makes their definitions available for enrichment but they are
  not processed *by default*. Transformations can include them during batch processing
  by enabling the :any:`Transformation.process_ignored_items` option. A typical use case
  for this are dependencies that are part of a separate compilation target (and therefore
  transformed separately), but analysis passes may need to collect information across an
  entire call tree (e.g., use of temporary arrays).

These three lists can be supplied globally in the ``[default]`` section of the scheduler
config file, or per routine. The matching of items against entries in these lists is
supports basic patterns (via :any:`fnmatch`), and is also effective for entire scopes.
For example, a subroutine ``my_routine`` that is defined in a module ``my_mod`` would be
matched by any of the following:

* ``my_routine``
* ``my_mod``
* ``my_mod#my_routine``
* ``*_routine``

By default, all items are expanded during dependency discovery, i.e., for every item
all dependencies are added to the graph, and then dependencies of these dependencies are
added as well. This procedure continues until all dependencies have been included.
For individual items, this expansion can be disabled by setting ``expand=False`` for
them in the scheduler config.


Filtering graph traversals
--------------------------

Often, only specific item types are of interest when traversing the dependency graph.
For that purpose, the :any:`SFilter` class provides an iterator for an :any:`SGraph`,
which allows specifying an ``item_filter`` or reversing the direction of traversals.
Other traversal modes may be added in the future.



.. autosummary::

   loki.batch.scheduler.Scheduler
   loki.batch.scheduler.SGraph
   loki.batch.scheduler.SFilter
   loki.batch.configure.SchedulerConfig
   loki.batch.configure.TransformationConfig
   loki.batch.configure.ItemConfig
   loki.batch.item.Item
   loki.batch.item.FileItem
   loki.batch.item.ModuleItem
   loki.batch.item.ProcedureItem
   loki.batch.item.TypeDefItem
   loki.batch.item.ProcedureBindingItem
   loki.batch.item.InterfaceItem
   loki.batch.item.ItemFactory
loki-ecmwf-0.3.6/docs/source/utils.rst0000664000175000017500000000736615167130205020077 0ustar  alastairalastair=========
Utilities
=========

.. important::
    Loki is still under active development and has not yet seen a stable
    release. Interfaces can change at any time, objects may be renamed, or
    concepts may be re-thought. Make sure to sync your work to the current
    release frequently by rebasing feature branches and upstreaming
    more general applicable work in the form of pull requests.

.. contents:: Contents
   :local:


To assist the development of custom transformations, a number of useful tools
for recurring tasks or house keeping are included with Loki.


Pragma utilities
================

An easy way of injecting information at specific locations in source files
is to insert pragmas. This allows to annotate declarations or loops, mark
source code regions or specify locations. Pragmas are represented by a unique
node type :any:`Pragma` and thus can be picked out easily during a
transformation.

A number of utility routines and context manager are available that allow
for easy parsing of pragmas, can attach pragmas to other nodes (such as
:any:`Loop`), or extract pragma regions and wrap them in a dedicated internal
node type :any:`PragmaRegion`:

.. autosummary::

   loki.pragma_utils.is_loki_pragma
   loki.pragma_utils.get_pragma_parameters
   loki.pragma_utils.pragmas_attached
   loki.pragma_utils.pragma_regions_attached


Dataflow analysis
=================

Rudimentary dataflow analysis utilities are included with Loki that determine
for each IR node what symbols it defines (i.e., assigns a value), reads
(i.e., uses before defining it), and which symbols are live (i.e., have been
defined before) entering the IR node in the control flow.

.. autosummary::

   loki.analyse.analyse_dataflow.dataflow_analysis_attached


Dimensions
==========

With the modification of data layouts and iteration spaces as one of the core
tasks in many transformation pipelines in mind, Loki has a :any:`Dimension`
class to define such a one-dimensional space.

.. autosummary::

   loki.dimension.Dimension


Python utilities
================

Some convenience utility routines, e.g., to simplify working with strings
or files are included in :mod:`loki.tools`:

.. autosummary::

   loki.tools.files
   loki.tools.strings
   loki.tools.util

A notable example is :any:`CaseInsensitiveDict`, a `dict` with strings as keys
for which the case is ignored. It is repeatedly used in other Loki data
structures when mapping symbol names that stem from Fortran source code. Since
Fortran is not case-sensitive, these names can potentially appear with mixed
case yet all refer to the same symbol and :any:`CaseInsensitiveDict` makes sure
no problems arise due to that.

Other frequently used utilities for working with lists and tuples are
:any:`as_tuple`, :any:`is_iterable` and :any:`flatten`.


Loki house keeping
==================

For internal purposes exist a global configuration
:class:`loki.config.Configuration` and logging functionality.

.. autosummary::

   loki.config
   loki.logging


Build subpackage
================

As part of Loki's test suite but also useful as a standalone package are the
build utilities :mod:`loki.build`:

.. autosummary::

   loki.build.binary.Binary
   loki.build.header.Header
   loki.build.lib.Lib
   loki.build.obj.Obj
   loki.build.builder.Builder
   loki.build.compiler
   loki.build.max_compiler
   loki.build.workqueue


Linting functionality
=====================

The source analysis capabilities of Loki can be used to build a static source
code analysis tool for Fortran. This is being developed as a standalone script
:doc:`loki-lint ` and includes a few data structures for the linter
mechanics in :mod:`loki.lint`:

.. autosummary::

   loki.lint.linter.Linter
   loki.lint.reporter
   loki.lint.rules.GenericRule
   loki.lint.utils.Fixer
loki-ecmwf-0.3.6/docs/source/loki_scripts.rst0000664000175000017500000000012115167130205021422 0ustar  alastairalastairScripts
=======

.. toctree::
   :maxdepth: 1

   loki_lint
   loki_transform
..
loki-ecmwf-0.3.6/docs/source/index.md0000664000175000017500000000102015167130205017613 0ustar  alastairalastair```{toctree}
:hidden:

Home page 
getting_started
using_loki
loki_api
```

```{important}
Loki is still under active development and has not yet seen a stable
release. Interfaces can change at any time, objects may be renamed, or
concepts may be re-thought. Make sure to sync your work to the current
release frequently by rebasing feature branches and upstreaming
more general applicable work in the form of pull requests.
```

```{include} ../../README.md
```

## Indices and tables

- {ref}`genindex`
- {ref}`modindex`
loki-ecmwf-0.3.6/docs/source/programming_models.rst0000664000175000017500000000156415167130205022616 0ustar  alastairalastair==================
Programming models
==================

Loki directives
---------------

Loki uses an internal set of directives as an intermediate annotation for data movement
and parallelisation concepts. Transformations, such as the :any:`SCCAnnotateTransformation`,
insert these directives, or they can be written into the original Fortran source code.
The :any:`PragmaModelTransformation` should be used, as one of the final steps in a processing
pipeline, to translate these directives to the corresponding instructions for the chosen
programming model.

Currently, Loki supports OpenACC and some OpenMP. The following table gives a summary of how
Loki directives are translated to the corresponding pragmas in either programming model:

.. csv-table:: Loki generic pragmas to pragma model mapping
   :file: /loki_pragma_model.csv
   :widths: 100, 100, 100
   :header-rows: 1
loki-ecmwf-0.3.6/docs/source/frontends.rst0000664000175000017500000001224415167130205020730 0ustar  alastairalastair===========================
Reading Fortran source code
===========================

.. important::
    Loki is still under active development and has not yet seen a stable
    release. Interfaces can change at any time, objects may be renamed, or
    concepts may be re-thought. Make sure to sync your work to the current
    release frequently by rebasing feature branches and upstreaming
    more general applicable work in the form of pull requests.

.. contents:: Contents
   :local:


The first step in a transformation pipeline is reading Fortran source code
and converting it to :doc:`internal_representation`.


Parsing a file or string
========================

Typically, one has a source file that contains modules, functions and/or
subroutines. In Loki, this will be represented by a :any:`Sourcefile` that
stores the individual program units.
Reading source code from file is done via :any:`Sourcefile.from_file` and
requires only specifying the path to that file:

.. code-block:: python

   source = Sourcefile.from_file('/path/to/source/file.f90')

Optionally, a number of parameters can or sometimes should be supplied:

* If symbols from other modules are imported in the source file
  and type or procedure information is required (e.g., to inline constant
  parameters), a list of :any:`Module` objects can be provided via
  :data:`definitions`.
* When there are C-preprocessor macros or includes used in the source file,
  a C-preprocessor (`pcpp `_) can be applied
  to the file before reading it. For that, include paths and macro definitions
  can be specified.
* Choosing a different frontend (see below).

See the description of :any:`Sourcefile.from_file` for a description of all
available options.

As an alternative to reading source files one can also parse a Python string
directly using :any:`Sourcefile.from_source`. With that, it is also possible to
directly create modules or subroutines using the common parent routine
:any:`ProgramUnit.from_source`.  This is particularly useful for writing tests,
avoiding the detour via an external file.

.. code-block:: python

   fcode = """
   subroutine axpy(a, x, y)
     real, intent(in) :: a, x(:)
     real, intent(inout) :: y(:)

     y(:) = a * x(:) + y(:)
   end subroutine axpy
   """.strip()
   routine = Subroutine.from_source(fcode)

In rare cases (e.g., when parsing a pragma annotation), parsing a standalone
expression may be required. Experimental support for that is provided via
the utility function :any:`parse_fparser_expression`.

Frontends
=========

Three different externally developed frontends are currently supported, each
of them with individual advantages and shortcomings:

* `Fparser 2 `_, developed by STFC as a
  rewrite of the original fparser that is included in
  `f2py `_, (now a part of numpy).
  It is written in pure Python, supports Fortran 2003 and some Fortran 2008,
  and is actively maintained. The default frontend in Loki.
* `Omni Compiler Frontend `_, developed in the
  Omni Compiler Project. It is written in Java, supports Fortran 2008 and
  is also used in the `CLAW compiler `_.
  Compared to the other frontends, OMNI performs a lot of transformations
  internally (unifies case, propagates constants, inlines statement
  functions, etc.), thus prevents string reproducibility. Biggest drawback
  is the very rigorous dependency chasing (with custom ``.xmod`` files), that
  disallows dangling symbol definitions via imports and therefore prevents
  partial source tree processing.
* `Open Fortran Parser `_
  with a customized
  `Python wrapper `_.
  It is written in Java, claims Fortran 2008 support, and is also part of the
  `ROSE Compiler framwork `_. It is lacking support
  for some Fortran features, notably slower than the other frontends and not
  actively developed at the moment.

.. important::
   By default, Loki uses Fparser 2.

.. autosummary::

   loki.frontend.util.Frontend

When invoked, every frontend produces an abstract syntax tree that is then
transformed to Loki's own internal representation.

.. autosummary::

   loki.frontend.fparser
   loki.frontend.omni
   loki.frontend.ofp


Preprocessing
=============

When reading a source file, a C99-preprocessor can be applied to the file
before passing it to the frontend. This can be enabled by specifying
:data:`preprocess` when calling `Sourcefile.from_file`. The corresponding
routine carrying out the preprocessing is :any:`preprocess_cpp`.

Source sanitization
===================

Internally, Loki performs also another kind of preprocessing to work around
known shortcomings in frontends. This is done via a regex-based replacement
of known incompatibilities that are later-on reinserted into the Loki IR.
This preprocessing step is applied automatically and does not require any
user intervention.

.. autosummary::

   loki.frontend.preprocessing.sanitize_input
   loki.frontend.preprocessing.sanitize_registry
   loki.frontend.preprocessing.PPRule
loki-ecmwf-0.3.6/docs/source/example0000777000175000017500000000000015167130205021546 2../../exampleustar  alastairalastairloki-ecmwf-0.3.6/docs/source/visitors.rst0000664000175000017500000002207615167130205020614 0ustar  alastairalastair===================
Working with the IR
===================

.. important::
    Loki is still under active development and has not yet seen a stable
    release. Interfaces can change at any time, objects may be renamed, or
    concepts may be re-thought. Make sure to sync your work to the current
    release frequently by rebasing feature branches and upstreaming
    more general applicable work in the form of pull requests.

.. contents:: Contents
   :local:

The most important tool for working with
:doc:`Loki's internal representation ` are utilities
that traverse the IR to find specific nodes or patterns, to modify or
replace subtrees, or to annotate the tree. In Loki there exist two types of
tree traversal tools, depending on which level of the IR they operate on:

* *Visitors* that traverse the tree of control flow nodes;
* *Mappers* (following Pymbolic's :py:mod:`pymbolic.mapper` naming
  convention) that traverse expression trees.

Visitors
========

Loki's visitors work by inspecting the type of each IR node they encounter and
then selecting the best matching handler method for that node type. This allows
implementing visitors that perform tasks either for very specific node types
or generally applicable for any node type, depending on the handler's name.

Loki includes a range of ready-to-use and configurable visitors for many common
use cases, such as discovering certain node types, modifying or replacing
nodes in the tree, or creating a string representation of the tree.
For some use cases it may be easier to implement new visitors tailored to the
task.


Searching the tree
------------------

The first category of visitors traverses the IR and collects a list of results
subject to certain criteria. In almost all cases :any:`FindNodes` is the tool
for that job, with some bespoke variants for specific use cases.

.. autosummary::

   loki.ir.find.FindNodes
   loki.ir.find.FindScopes
   loki.ir.find.SequenceFinder
   loki.ir.find.PatternFinder

A common pattern for using :any:`FindNodes` is the following:

.. code-block:: python

   for loop in FindNodes((Loop, WhileLoop)).visit(routine.body):
       # ...do something with loop...

There are additional visitors that search all expression trees embedded in the
control flow IR, which are explained further down.


Transforming the tree
---------------------

A core feature of Loki is the ability to transform the IR, which is done using
the :any:`Transformer`. It is a visitor that rebuilds the tree and replaces
nodes according to a mapper.

.. autosummary::

   loki.ir.transformer.Transformer
   loki.ir.transformer.NestedTransformer
   loki.ir.transformer.MaskedTransformer
   loki.ir.transformer.NestedMaskedTransformer

:any:`Transformer` is commonly used in conjunction with :any:`FindNodes`, with
the latter being used to build the mapper for the former. The following example
removes all loops over the horizontal dimension and replaces them with
their body. This code snippet is a simplified version of a transformation used
in :any:`ExtractSCATransformation`:

.. code-block:: python

   routine = Subroutine(...)
   horizontal = Dimension(...)

   ...

   loop_map = {}
   for loop in FindNodes(Loop).visit(routine.body):
       if loop.variable == horizontal.variable:
           loop_map[loop] = loop.body
   routine.body = Transformer(loop_map).visit(routine.body)


Converting the tree to string
-----------------------------

The last step in a transformation pipeline is usually to write the transformed
IR to a file. This is a task for :doc:`Loki's backends ` which
themselves are subclasses of :class:`loki.ir.pprint.Stringifier`, yet
another visitor. :class:`loki.ir.pprint.Stringifier` doubles as a
pretty-printer for the IR that is useful for debugging.

.. autosummary::

   loki.ir.pprint.Stringifier
   loki.ir.pprint

Implementing new visitors
-------------------------

Any new visitor should subclass :any:`Visitor` (or any of its subclasses).

The common base class for all visitors is :any:`GenericVisitor`, declared in
:py:mod:`loki.visitors` that provides the basic functionality for matching
objects to their handler methods. Derived from that is :any:`Visitor` which
adds a default handler :data:`visit_Node` (for :any:`Node`) and functionality
to recurse for all items in a list or tuple and return the combined result.

To define handlers in new visitors, they should define :data:`visit_Foo`
methods for each class :data:`Foo` they want to handle.
If a specific method for a class :data:`Foo` is not found, the MRO
of the class is walked in order until a matching method is found (all the
way until, for example, :any:`Visitor.visit_Node` applies).
The method signature is:

.. code-block:: python

   def visit_Foo(self, o, [*args, **kwargs]):
       pass

The handler is responsible for visiting the children (if any) of
the node :data:`o`.  :data:`*args` and :data:`**kwargs` may be
used to pass information up and down the call stack.  You can also
pass named keyword arguments, e.g.:

.. code-block:: python

    def visit_Foo(self, o, parent=None, *args, **kwargs):
        pass

Mappers
=======

Mappers are visitors that traverse
:ref:`expression trees `.

They are built upon :py:mod:`pymbolic.mapper` classes and for that reason use
a slightly different way of determining the handler methods: each expression
tree node (:class:`pymbolic.primitives.Expression`) holds a class
attribute :attr:`mapper_method` with the name of the relevant method.

Loki provides, similarly to control flow tree visitors, ready-to-use mappers
for searching or transforming expression trees, all of which are implemented
in :mod:`loki.expression.mappers`. In addition,
:mod:`loki.expression.expr_visitors` provides visitors that apply the same mapper
to all expression trees in the IR.


Searching in expression trees
-----------------------------

The equivalent to :any:`FindNodes` for expression trees is
:any:`ExpressionRetriever`. Using a generic function handle, (almost) arbitrary
conditions can be used as a query that decides whether to include a given node
into the list of results.

.. autosummary::

   loki.expression.mappers.ExpressionRetriever

Note that mappers operate only on expression trees, i.e. using them directly
is only useful when working with a single property of a control flow node,
such as :attr:`loki.ir.Assignment.rhs`. If one wanted to search for expression
nodes in all expression trees in the IR, e.g. to find all variables, bespoke
visitors exist that apply :any:`ExpressionRetriever` to all expression trees.

.. autosummary::

   loki.expression.expr_visitors.ExpressionFinder
   loki.expression.expr_visitors.FindExpressions
   loki.expression.expr_visitors.FindTypedSymbols
   loki.expression.expr_visitors.FindVariables
   loki.expression.expr_visitors.FindInlineCalls
   loki.expression.expr_visitors.FindLiterals

For example, the following finds all function calls embedded in expressions
(:any:`InlineCall`, as opposed to subroutine calls in :any:`CallStatement`):

.. code-block:: python

   for call in FindInlineCalls().visit(routine.body):
       # ...do something with call...


Transforming expression trees
-----------------------------

Transformations of the expression tree are done very similarly to
:any:`Transformer`, using the mapper :any:`SubstituteExpressionsMapper` that
is given a map to replace matching expression nodes.

.. autosummary::

   loki.expression.mappers.LokiIdentityMapper
   loki.expression.mappers.SubstituteExpressionsMapper

In the same way that searching can be done on all expression trees in the IR,
transformations can be applied to all expression trees at the same time using
:any:`SubstituteExpressions`:

.. autosummary::

   loki.expression.expr_visitors.SubstituteExpressions

The following example shows how expression node discovery and substitution can
be combined to replace all occurrences of intrinsic function calls.
(The code snippet is taken from :any:`replace_intrinsics`, where two dictionaries,
:data:`function_map` and :data:`symbol_map`, provide the mapping to rename
certain function calls that appear in :data:`routine`.)

.. code-block:: python

   from loki.expression import symbols as sym

   callmap = {}
   for c in FindInlineCalls(unique=False).visit(routine.body):
       cname = c.name.lower()

       if cname in symbol_map:
           callmap[c] = sym.Variable(name=symbol_map[cname], scope=routine.scope)

       if cname in function_map:
           fct_symbol = sym.ProcedureSymbol(function_map[cname], scope=routine.scope)
           callmap[c] = sym.InlineCall(fct_symbol, parameters=c.parameters,
                                       kw_parameters=c.kw_parameters)

   routine.body = SubstituteExpressions(callmap).visit(routine.body)


Converting expressions to string
--------------------------------

Every backend has its own mapper to convert expressions to a source
code string, according to the corresponding language specification.
All build on a common base class :any:`LokiStringifyMapper`, which is
also called automatically when converting any expression node to string.

.. autosummary::

   loki.expression.mappers.LokiStringifyMapper
loki-ecmwf-0.3.6/docs/source/loki_lint.rst0000664000175000017500000004362715167130205020723 0ustar  alastairalastair=========
loki-lint
=========

.. contents:: Contents
   :local:

Loki's ability to parse Fortran source files into an internal representation
that is easy to work with has been used to create a prototype implementation of
a static source code analysis tool for Fortran. The intention is to verify
compliance with the
`IFS coding standard `_
but the tool itself is generic enough to be used for other use cases, too.


Installation
============

A generic Loki installation as described in the
:doc:`installation instructions <INSTALL>` also installs the linting script.
However, it requires linter rules
to do anything useful. A basic set of rules for IFS is provided via the
``lint_rules`` module that can be optionally included in the installation as
described in `INSTALL.md`.

Basic usage
===========

The basic command for loki-lint is

.. code-block:: bash

   loki-lint.py [--log <log_file>] check [--basedir <basedir>] [--include <pattern>] [--exclude <pattern>]

The most important option is ``--include`` that specifies the pattern of file
names that should be checked. For example: ``--include **/*.F90`` checks all
files with suffix ``.F90`` in all subdirectories of the current working
directory. This option can be specified multiple times to add multiple files
or more than one directory (sub-)tree.

Patterns can be absolute paths or relative to the current working directory.
Optionally, a different base directory relative to which patterns should be
interpreted can be specified with ``--basedir``.

An optional exclusion pattern can be given with ``--exclude`` that skips files
with names matching that pattern.

Progress and warnings are reported on the command line. Optionally, an
additional log file is written with the name given via ``--log``.

Examples
--------

.. dropdown:: Minimal example

   This checks only the ``cloudsc.F90`` file from the
   `dwarf-p-cloudsc <https://github.com/ecmwf-ifs/dwarf-p-cloudsc>`_ mini-app:

   .. code-block:: bash

      $~> $ loki-lint.py check --include src/cloudsc_fortran/cloudsc.F90
      Base directory: 
      Include patterns:
        - src/cloudsc_fortran/cloudsc.F90
      Exclude patterns:

      1 files selected for checking (0 files excluded).

      Using 4 worker.
      10 rules available.
      Checking against 10 rules.

      [1.3] CodeBodyRule: src/cloudsc_fortran/cloudsc.F90 (ll. 1833-1837) - Nesting of conditionals exceeds limit of 3
      [1.9] DrHookRule: src/cloudsc_fortran/cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - First executable statement must be call to DR_HOOK
      [1.9] DrHookRule: src/cloudsc_fortran/cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - Last executable statement must be call to DR_HOOK
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (ll. 2046-2050) - 0.8 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (l. 2380) - 1.0 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (l. 2380) - 0.5 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (l. 2381) - 273.0 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (l. 2381) - 1.5 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (l. 2381) - 393.0 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (l. 2381) - 120.0 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (ll. 2387-2388) - 0.65 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (ll. 2387-2388) - 0.5 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (ll. 2387-2388) - 0.5 used without explicit KIND
      [4.7] ExplicitKindRule: src/cloudsc_fortran/cloudsc.F90 (ll. 2387-2388) - 0.5 used without explicit KIND
      [2.2] LimitSubroutineStatementsRule: src/cloudsc_fortran/cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - Subroutine has 604 executable statements (should not have more than 300)
      [3.6] MaxDummyArgsRule: src/cloudsc_fortran/cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - Subroutine has 54 dummy arguments (should not have more than 50)

      1 files parsed successfully

.. dropdown:: Minimal example with a different ``--basedir``

   This checks only the ``cloudsc.F90`` file but specifies a different base
   directory. Note the difference in output:

   .. code-block:: bash

      $~> $ loki-lint.py check --basedir src/cloudsc_fortran --include cloudsc.F90
      Base directory: src/cloudsc_fortran
      Include patterns:
        - cloudsc.F90
      Exclude patterns:

      1 files selected for checking (0 files excluded).

      Using 4 worker.
      10 rules available.
      Checking against 10 rules.

      [1.3] CodeBodyRule: cloudsc.F90 (ll. 1833-1837) - Nesting of conditionals exceeds limit of 3
      [1.9] DrHookRule: cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - First executable statement must be call to DR_HOOK
      [1.9] DrHookRule: cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - Last executable statement must be call to DR_HOOK
      [4.7] ExplicitKindRule: cloudsc.F90 (ll. 2046-2050) - 0.8 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (l. 2380) - 1.0 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (l. 2380) - 0.5 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (l. 2381) - 273.0 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (l. 2381) - 1.5 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (l. 2381) - 393.0 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (l. 2381) - 120.0 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (ll. 2387-2388) - 0.65 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (ll. 2387-2388) - 0.5 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (ll. 2387-2388) - 0.5 used without explicit KIND
      [4.7] ExplicitKindRule: cloudsc.F90 (ll. 2387-2388) - 0.5 used without explicit KIND
      [2.2] LimitSubroutineStatementsRule: cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - Subroutine has 604 executable statements (should not have more than 300)
      [3.6] MaxDummyArgsRule: cloudsc.F90 (ll. 10-2867) in routine "CLOUDSC" - Subroutine has 54 dummy arguments (should not have more than 50)

      1 files parsed successfully


.. dropdown:: Example for a complete command line

   This specifies a custom path relative to which the patterns are to be
   interpreted and includes all F90-files in the ``phys_ec`` and ``module``
   directories. Note that single quotes may be necessary to ensure the shell
   does not expand the pattern. Output is written to a log file with current
   date and time in the file name.

   .. code-block:: bash

      loki-lint.py --log ifs_$(date +"%Y%m%d-%H%M").log check --basedir /path/to/ifs-source/branch/ifs --include 'phys_ec/*.F90' --include 'module/*.F90'


Help command
------------

loki-lint has a built-in help output detailing the use of the application. Run

.. code-block:: bash

   loki-lint.py --help

to display the generic help text, and

.. code-block:: bash

   loki-lint.py check --help

gives some advice about the usage of the source file checker and its options.
This includes some advanced options not mentioned here.

The list of available rules that source files are tested against can be
displayed by running (optionally with their ID and a short description for
each rule):

.. code-block:: bash

   loki-lint.py rules [--with-title]


Configuration
=============

The behaviour of Loki-lint and its rules can be configured using a YAML
configuration file. Currently, this allows changing settings for individual
rules as well as the list of rules to be checked.

For that, simply provide the config file in the command line like this:

.. code-block:: bash

   loki-lint.py check --config <config_file>

The default configuration can be displayed (and optionally written to file)
using:

.. code-block:: bash

   loki-lint.py default-config [--output-file <file>]

This default configuration can then be used as a template for creating an
individual configuration file. Any options not specified explicitly in the
configuration file fall back to their default values.

Rules-module
------------

The rules against which Loki-lint performs checks can be configured as follows:

.. code-block:: bash

   loki-lint.py --rules-module <module> check [options/arguments]

If a rules-module is not specified, then the default :mod:`lint_rules.ifs_coding_standards_2011`
is used.

Implementing own rules
======================

All rules are implemented in :mod:`lint_rules`. Currently, this includes:

#. :mod:`lint_rules.ifs_coding_standards_2011` - A (small) subset of the rules defined in the IFS coding standards document.
#. :mod:`lint_rules.debug_rules` - A set of rules to identify common mistakes/anti-patterns:

   * :any:`ArgSizeMismatchRule` - Check for argument/dummy-argument size consistency.
   * :any:`DynamicUboundCheckRule` - Check if run-time bounds checking is used rather than compile-time bounds checking.

To write your own rules, a rudimentary understanding of
:doc:`internal_representation` is helpful.

Each rule is represented by a subclass of :any:`GenericRule` with the
following structure:

.. code-block:: python

   class MyOwnRule(GenericRule):

       type = RuleType.WARN

       docs = {
           'id': '13.37',
           'title': 'Scientists should write {what_now}.',
       }

       config = {
           'some_option': 'some value',
           'what_now': 'sensible code',
           'another_option': ['a', 'list', 'of', 'values']
       }

       fixable = True

       @classmethod
       def check_module(cls, module, rule_report, config):
           # Implement checks on module level here
           rule_report.add("Problem in this module", module)

       @classmethod
       def check_subroutine(cls, subroutine, rule_report, config):
           # Implement checks on subroutine level here
           rule_report.add("Problem in this subroutine", subroutine)

       @classmethod
       def check_file(cls, sourcefile, rule_report, config):
           # Implement checks on source file level here
           rule_report.add("Problem in this file", sourcefile)

       @classmethod
       def fix_subroutine(cls, subroutine, rule_report, config):
           # Implements logic that attempts to fix the problems that
           # were flagged in rule_report


Properties of a rule
--------------------

* :attr:`type` : The type, category or severity of that rule. Available types
  are defined in :any:`RuleType` and currently comprise :attr:`INFO`,
  :attr:`WARN`, :attr:`SERIOUS`, and :attr:`ERROR` (in order of increasing severity).

* :attr:`docs` : A short description of that rule. At the moment, this includes
  by default

   * :attr:`id` : The rule number according to the IFS Coding standards

   * :attr:`title` : A short description of that rule. It may contain placeholder
     values (such as ``{what_now}``) that are replaced by the corresponding
     value from the config when displaying the rules (see example above).

* :attr:`config` : A dictionary that allows parametrizing the rule, with given
  default values. These options are exposed via the config file mentioned
  above, where defaults can be overwritten.

* :attr:`fixable` : `True`/`False` to indicate if the rule has a method
  :meth:`fix_*` that can be used to attempt to automatically fix
  the problems the corresponding :meth:`check_*` method reported. Defaults to
  `False`.

.. note::
   Automatic fixing of rules is currently in prototype stage and the API may
   change in the future.

Further **properties for future use** are already implemented but not currently
used:

* :attr:`deprecated` : `True`/`False` to indicate when a rule has been
  superseded by other rules (e.g., due to a new revision of the Coding
  Standards). Defaults to `False`.
* :attr:`replaced_by` : A tuple that can be used to specify the rule(s) that
  replaced this rule when it became deprecated.


Methods of a rule
-----------------

The core of a rule is its :meth:`check*` methods, which implement its behaviour.
Depending on the nature of a rule, it may require checks to be carried out on
different levels in the hierarchy of a source file (the :any:`Sourcefile` itself
or :any:`Module` or :any:`Subroutine` that are contained in it). For that reason,
there are multiple entry points that a rule can implement, depending on the
specific needs. Any function that is not required can simply be left out. The
driver of loki-lint calls each of the following routines for every entity in a
source file:

* :meth:`check_file` once for the file (:any:`Sourcefile`),
* :meth:`check_module` for every module (:any:`Module`) in that file, and
* :meth:`check_subroutine` for every subroutine (:any:`Subroutine`) in that
  file and for every subroutine contained in a module in that file, and for
  every subroutine contained in a subroutine in that file, etc.

**Arguments** given to each of those routines are

* A :any:`Sourcefile`, :any:`Module` or :any:`Subroutine` object;
* The reporter (:any:`RuleReport`) for this rule, to which detected problems
  can be reported (see below);
* A `dict` holding the configuration values (defaults or from the config file).


Reporting of problems
---------------------

Problems detected by a rule are reported by calling
``rule_report.add(message, location)``. Here, :data:`message` is an arbitrary
string describing the problem, and :data:`location` can be an arbitrary node of
the internal representation in which the problem occurred. This parameter will
later be used to provide information about the location of the problem (e.g.,
line number).


Example of a rule
-----------------

To illustrate the use of :doc:`internal_representation` and how a rule is
implemented with that, consider the following example:

.. code-block:: python

   class MplCdstringRule(GenericRule):  # Coding standards 3.12

       type = RuleType.SERIOUS

       docs = {
           'id': '3.12',
           'title': 'Calls to MPL subroutines should provide a "CDSTRING" identifying the caller.',
       }

       @classmethod
       def check_subroutine(cls, subroutine, rule_report, config):
           '''Check all calls to MPL subroutines for a CDSTRING.'''
           for call in FindNodes(ir.CallStatement).visit(subroutine.ir):
               if call.name.upper().startswith('MPL_'):
                   for kw, _ in call.kwarguments:
                       if kw.upper() == 'CDSTRING':
                           break
                   else:
                       fmt_string = 'No "CDSTRING" provided in call to {}'
                       msg = fmt_string.format(call.name)
                       rule_report.add(msg, call)

This rule checks all calls to ``MPL_`` subroutines for the presence of a
keyword-argument ``CDSTRING`` that should provide identification of the
caller. Note the following implementation details of the class:

* The rule is categorized as :data:`SERIOUS`.
* Documentation contains its ID (3.12) and title (here, providing the full
  wording from the coding standards document).
* There is no config that modifies the behaviour of the rule.
* There is a single entry point to that rule: Only the method
  :meth:`check_subroutine` is implemented that is called for all subroutines
  in a source file (irrespective of whether it is a free function in the file,
  or contained in a module or subroutine).

The implementation of :meth:`check_subroutine` features the following details:

* It uses the :doc:`visitor <visitors>` :any:`FindNodes` to find all
  :any:`CallStatement` nodes; this visitor is applied to the subroutine's IR,
  which is available via the attribute :any:`Subroutine.ir`.
* For every ``call`` node, it takes the name of the called routine
  (available as property :attr:`name` and converted to uppercase as Fortran is
  case-insensitive) and checks if it starts with ``MPL_``.
  For each such call node, it looks at all keyword arguments (available as list
  of ``(keyword, value)``-tuples in the property :attr:`kwarguments`).

  * If keyword ``CDSTRING`` is found, the search loop is stopped (with
    ``break``) and the outer visitor loop continues with the next call node;
  * if the loop terminates normally (i.e., break was not invoked) then no such
    keyword argument was found and the loop's ``else`` block is executed (this
    is a Python-specific feature allowing to execute a block of code only if a
    loop was not terminated "abnormally"). There, a message text is formed by
    inserting the name of the called routine into the ``fmt_string``. Then,
    this is reported to ``rule_report`` together with the problematic IR node
    ``call``. Later, the output handler will use this node to determine the
    exact position in the source file (e.g., to report the line number).

Note that this rule does not report anything if no problematic calls are present.

An example output of this rule looks as follows:

.. code-block:: text

  [3.12] MplCdstringRule: cma2odb/distio_mix.F90 (l. 821) - No "CDSTRING" provided in call to MPL_BROADCAST


Known issues
============

In general, bugs and open questions are collected in Loki's issue tracker
and this is also the best place to report any problems.

One important limitation is that loki-lint currently does not invoke a
C-preprocessor. Although Loki now has a built-in
:ref:`preprocessor `, this is not currently used in
loki-lint. Therefore, preprocessor directives are not interpreted but
essentially treated as comments. Thus, if code does not reduce to
(syntactically) valid Fortran when PP directives are ignored, parsing that
file will fail (e.g., because each branch of an ``#ifdef ... #else ... #endif``
construct provides a different ``IF`` statement for a common ``ENDIF``).

For other limitations of the frontends or the IR, Loki has a built-in sanitizer for
input files that works around some of these deficiencies.
loki-ecmwf-0.3.6/docs/source/INSTALL.md0000664000175000017500000000004215167130205017615 0ustar  alastairalastair```{include} ../../INSTALL.md
```
loki-ecmwf-0.3.6/docs/source/notebooks.rst0000664000175000017500000000070215167130205020725 0ustar  alastairalastair=================
Example notebooks
=================

The following Jupyter notebooks can be found in the ``example`` directory.
They provide an introduction to Loki's core functionality and should help
to get familiar with the API and usage.

.. toctree::
   :maxdepth: 1

   example/01_reading_and_writing_files
   example/02_working_with_the_ir
   example/03_loop_fusion
   example/04_creating_new_visitors
   example/05_argument_intent_linter
loki-ecmwf-0.3.6/docs/source/internal_representation.rst0000664000175000017500000003404415167130205023666 0ustar  alastairalastair===================================
Loki's internal representation (IR)
===================================

.. important::
    Loki is still under active development and has not yet seen a stable
    release. Interfaces can change at any time, objects may be renamed, or
    concepts may be re-thought. Make sure to sync your work to the current
    release frequently by rebasing feature branches and upstreaming
    more generally applicable work in the form of pull requests.

.. contents:: Contents
   :local:

Loki's internal representation aims to achieve a balance between usability
and general applicability. This means that in places there may be shortcuts
taken to ease its use in the context of a source-to-source translation
utility but may break with established practices in compiler theory.
The IR was developed with Fortran source code in mind and that shows. Where
there exist similar concepts in other languages, things are transferable.
In other places, Fortran-specific annotations are included for the sole purpose
of enabling string reproducibility.

The internal representation is vertically divided into different layers,
roughly aligned with high level concepts found in Fortran and other
programming languages:

.. contents::
   :local:
   :depth: 1


Container data structures
=========================

Outermost are container data structures that conceptually translate to
Fortran's `program-units`, such as modules and subprograms.

Fortran modules are represented by :any:`Module` objects which comprise
a specification part (:py:attr:`Module.spec`) and a list of :any:`Subroutine`
objects contained in the module.

Subroutines and functions are represented by :any:`Subroutine` objects that
in turn have their own docstring (:py:attr:`Subroutine.docstring`),
specification part (:py:attr:`Subroutine.spec`), execution part
(:py:attr:`Subroutine.body`), and contained subprograms
(:py:attr:`Subroutine.members`).

To map these programming language concepts to source files and ease input or
output operations, any number of these container data structures can be
combined in a :any:`Sourcefile`.

Available container classes
---------------------------

.. autosummary::

   loki.sourcefile.Sourcefile
   loki.module.Module
   loki.subroutine.Subroutine


Control flow tree
=================

Specification and execution parts of (sub)programs and modules are the central
components of container data structures. Each of them is represented by a tree
of control flow nodes, with a :any:`Section` as root node. This tree resembles
to some extent a hierarchical control flow graph where each node can have
control flow and expression nodes as children. Consequently, this separation on
node level is reflected in the internal representation, splitting the tree into
two levels:

1. :ref:`Control flow `
   (e.g., loops, conditionals, assignments, etc.);
   the corresponding classes are declared in :py:mod:`loki.ir` and described
   in this section.
2. :ref:`Expressions `
   (e.g., scalar/array variables, literals, operators, etc.);
   this is based on `Pymbolic <https://github.com/inducer/pymbolic>`__ with
   encapsulating classes declared in :py:mod:`loki.expression.symbols` and
   described below.

The split in IR levels is meant to control the complexity of
individual tree traversals, and separate the symbolic expression layer
from the more Fortran-specific elements of the IR tree. This loosely
follows the principles outlined in
`Luporini et al. `__.

All control flow nodes implement the common base class :any:`Node` and
can have an arbitrary number of children that are either control flow nodes
or expression nodes. Thus, any control flow node looks in principle like the
following:

.. code-block:: none

                      Node
                      / | \
              +------+  |  +---+
             /          |       \
            /           |        \
      Expression   Expression   Node   ...

As an example, consider a basic Fortran ``DO i=1,n`` loop: it defines a loop
variable (``i``), a loop range (``1:n``) and a loop body. The body can be
one/multiple statements or other control flow structures and therefore is a
subtree of control flow nodes. Loop variable and range, however, are
expression nodes.

All control flow nodes fall into one of two categories:

* :any:`InternalNode`: nodes that have a :py:attr:`body` and therefore
  have other control flow nodes as children.
* :any:`LeafNode`: nodes that (generally) do not have any other
  control flow nodes as children.

Note that :any:`InternalNode` can have other properties than
:py:attr:`body` in which control flow nodes are contained as children
(for example, :py:attr:`else_body` in :any:`Conditional`).
All :any:`Node` may, however, have one or multiple expression trees
as children.

.. note:: All actual control flow nodes are implementations of one of the two
          base classes. Two notable exceptions to the above are the following:

          * :any:`MultiConditional` (for example, Fortran's ``SELECT CASE``):
            It has multiple bodies and thus does not fit the above framework.
            Conceptually, these could be converted into nested
            :any:`Conditional` but it would break string reproducibility.
            For that reason they are retained as a :any:`LeafNode` for the
            time being.
          * :any:`TypeDef`: This defines a new scope for symbols, which
            does not include symbols from the enclosing scope. Thus, it behaves
            like a leaf node although it technically has control flow nodes as
            children. It is therefore also implemented as a :any:`LeafNode`.

With this separation into two types of nodes, the schematics of the control flow
layer of the internal representation are as follows:

.. code-block:: none

                        InternalNode
                             |
                            body
                           /|||\
          +---------------+ /|\ +-------------+
         /          +------+ | +-----+         \
        /          /         |        \         \
    LeafNode InternalNode LeafNode LeafNode InternalNode ...
                  |                              |
                 body                           body
                /    \                         /    \
               /      \                         ....
         LeafNode  InternalNode
                        |
                       ...


Available control flow nodes
----------------------------

Abstract base classes
^^^^^^^^^^^^^^^^^^^^^

.. autosummary::

   loki.ir.Node
   loki.ir.InternalNode
   loki.ir.LeafNode

Internal node classes
^^^^^^^^^^^^^^^^^^^^^

.. autosummary::

   loki.ir.Section
   loki.ir.Associate
   loki.ir.Loop
   loki.ir.WhileLoop
   loki.ir.Conditional
   loki.ir.PragmaRegion
   loki.ir.Interface

Leaf node classes
^^^^^^^^^^^^^^^^^

.. autosummary::

   loki.ir.Assignment
   loki.ir.ConditionalAssignment
   loki.ir.CallStatement
   loki.ir.Allocation
   loki.ir.Deallocation
   loki.ir.Nullify
   loki.ir.Comment
   loki.ir.CommentBlock
   loki.ir.Pragma
   loki.ir.PreprocessorDirective
   loki.ir.Import
   loki.ir.VariableDeclaration
   loki.ir.ProcedureDeclaration
   loki.ir.DataDeclaration
   loki.ir.StatementFunction
   loki.ir.TypeDef
   loki.ir.MultiConditional
   loki.ir.TypeConditional
   loki.ir.MaskedStatement
   loki.ir.Intrinsic
   loki.ir.Enumeration


Expression tree
===============

Many control flow nodes contain one or multiple expressions, such as the
right-hand side of an assignment (:py:attr:`loki.ir.Assignment.rhs`) or the
condition of an ``IF`` statement (:py:attr:`loki.ir.Conditional.condition`).
Such expressions are represented by expression trees, comprising a single
node (e.g., the left-hand side of an assignment may be just a scalar variable)
or a large expression tree consisting of multiple nested sub-expressions.

Loki's expression representation is based on
`Pymbolic <https://github.com/inducer/pymbolic>`__ but encapsulates all
classes with bespoke implementations. This allows enriching expression
nodes by attaching custom metadata, implementing bespoke comparison operators,
or storing type information.

The base class for all expression nodes is :any:`pymbolic.primitives.Expression`.

Available expression tree nodes
-------------------------------

Typed symbol nodes
^^^^^^^^^^^^^^^^^^

.. autosummary::

   loki.expression.symbols.TypedSymbol
   loki.expression.symbols.Variable
   loki.expression.symbols.DeferredTypeSymbol
   loki.expression.symbols.Scalar
   loki.expression.symbols.Array
   loki.expression.symbols.ProcedureSymbol

Literals
^^^^^^^^

.. autosummary::

   loki.expression.symbols.Literal
   loki.expression.symbols.FloatLiteral
   loki.expression.symbols.IntLiteral
   loki.expression.symbols.LogicLiteral
   loki.expression.symbols.StringLiteral
   loki.expression.symbols.IntrinsicLiteral
   loki.expression.symbols.LiteralList

Mix-ins
^^^^^^^

.. autosummary::

   loki.expression.symbols.StrCompareMixin

Expression modules
^^^^^^^^^^^^^^^^^^

.. autosummary::

   loki.expression.expr_visitors
   loki.expression.mappers
   loki.expression.operations
   loki.expression.symbolic
   loki.expression.symbols


Type information and scopes
===========================

Every symbol in an expression tree (:any:`TypedSymbol`, such as :any:`Scalar`,
:any:`Array`, :any:`ProcedureSymbol`) has a type (represented by a
:any:`DataType`) and, possibly, other attributes associated with it.
Type and attributes are stored together in a :any:`SymbolAttributes`
object, which is essentially a `dict`.

.. note::
   *Example:* An array variable ``VAR`` may be declared in Fortran as a subroutine
   argument in the following way:

   .. code-block:: none

      INTEGER(4), INTENT(INOUT) :: VAR(10)

   This variable has type :any:`BasicType.INTEGER` and the following
   additional attributes:

   * ``KIND=4``
   * ``INTENT=INOUT``
   * ``SHAPE=(10,)``

   The corresponding :any:`SymbolAttributes` object can be created as

   .. code-block::

      SymbolAttributes(BasicType.INTEGER, kind=Literal(4), intent='inout', shape=(Literal(10),))

If the variable object is associated with a :any:`Scope`, then its
:any:`SymbolAttributes` object is stored in the relevant :any:`SymbolTable`.
All expression nodes that represent a use of the associated symbol
(i.e., the variable object and any others with the same name) query the type
information from there. This means that changing the declared attributes of a
symbol applies the change to all instances of that symbol.

If the variable is not associated with a :any:`Scope`, then its
:any:`SymbolAttributes` object is stored locally and not shared by any other
variable objects.

.. warning::
   Loki allows changes to be applied very freely, which means changing symbol
   attributes can lead to invalid states.

   For example, removing the ``shape`` property from the :any:`SymbolAttributes`
   object in a symbol table converts the corresponding :any:`Array` to
   a :any:`Scalar` variable. But at this point all expression tree nodes will
   still be :any:`Array`, possibly also with subscript operations (represented
   by the ``dimensions`` property).

   For plain :any:`Array` nodes (without subscript), rebuilding the IR will
   automatically take care of instantiating these objects as :any:`Scalar` but
   removing ``dimensions`` properties must be done explicitly.

Every object that defines a new scope (e.g., :any:`Subroutine`,
:any:`Module`, implementing :any:`Scope`) has an associated symbol table
(:any:`SymbolTable`). The :any:`SymbolAttributes` of a symbol declared or
imported in a scope are stored in the symbol table of that scope.
These symbol tables/scopes are organized in a hierarchical fashion, i.e., they
are aware of their enclosing scope and allow entries to be looked up recursively.

The overall schematics of the scope and type representation are depicted in the
following diagram:

.. code-block:: none

      Subroutine | Module | TypeDef | ...
              \      |      /
               \     |     /   
                \    |    /
                   Scope
                     |
                     | 
                     |
                SymbolTable  - - - - - - - - - - - - TypedSymbol
                     |
                     |  
                     |
              SymbolAttributes
           /     |       |      \
          /      |       |       \  
         /       |       |        \
   DataType | (kind) | (intent) | (...)


Available data types
--------------------

The :any:`DataType` of a symbol can be one of

* :any:`BasicType`: intrinsic types, such as ``INTEGER``, ``REAL``, etc.
* :any:`DerivedType`: derived types defined somewhere
* :any:`ProcedureType`: any subroutines or functions declared or imported

Note that this is different from the understanding of types in the Fortran
standard, where only intrinsic types and derived types are considered a
type. Treating also procedures as types allows us to treat them uniformly
when considering external subprograms, procedure pointers and type bound
procedures.

.. code-block:: none

   BasicType | DerivedType | ProcedureType
            \       |       /
             \      |      /    
              \     |     /
                 DataType


Derived types
-------------

Derived type definitions (via :any:`TypeDef`) create entries in the scope's
symbol table in which they are defined to make the type definition available
to declarations.

Imports and deferred type
-------------------------

For imported symbols (via :any:`Import`) the source module may not be
available, in which case no information about the symbol is known. This is indicated by
:any:`BasicType.DEFERRED`. This is also applied to any variable that is
instantiated without providing a type and where no type information can
be found in the scope's symbol table (either because no information has
been provided previously or because no scope is attached).

.. autosummary::

   loki.scope.Scope
   loki.scope.SymbolTable
   loki.types.SymbolAttributes
   loki.types.DataType
   loki.types.BasicType
   loki.types.DerivedType
   loki.types.ProcedureType
loki-ecmwf-0.3.6/docs/source/loki_pragma_model.csv0000664000175000017500000000315115167130205022353 0ustar  alastairalastairLoki,OpenACC,OMP-GPU
``create device(...)``,``declare create(...)``,``declare target(...)``
``update device(...) host(...)``,``update device(...) self(...)``,
``unstructured-data in(...) create(...) attach(...)``,``enter data copyin(...) create(...) attach(...)``,``target enter data map(to: ...) map(alloc: ...)``
``exit unstructured-data out(...) delete(...) detach(...) [finalize]``,``exit data copyout(...) delete(...) detach(...) [finalize]``,``target exit data map(from: ...) map(delete: ...) map(release: ... ???)``
``structured-data inout(...) in(...) out(...) create(...) present(...)``,``data copy(...) copyin(...) copyout(...) create(...) present(...)``,``target data map(tofrom: ...) map(to: ...) map(from: ...) map(to: ...)``
``end structured-data inout(...) in(...) out(...)``,``end data``,``end target data``
``loop gang private(...) vlength(...)``,``parallel loop gang private(...) vector_length(...)``,``target teams distribute thread_limit(...) ???``
``end loop gang``,``end parallel loop``,``end target teams distribute``
``loop vector private(...) reduction(...)``,``loop vector private(...) reduction(...)``,``parallel do``
``end loop vector``,,``end parallel do``
``loop seq``,``loop seq``,
``end loop seq``,,
``routine vector``,``routine vector``,
``routine seq``,``routine seq``,``declare target``
``data device-present vars(...)``,``data present(...)``,
``device-present vars (...)``,``data present(...)``,
``end device-present vars(...)``,``end data``,
``device-ptr vars (...)``,``data deviceptr(...)``,
``end device-ptr vars(...)``,``end data``,
``omp-update-global-vars in(...)``,,``target enter data map(to: ...)``
loki-ecmwf-0.3.6/docs/source/getting_started.rst0000664000175000017500000001171515167130205022117 0ustar  alastairalastair===============
Getting started
===============

.. toctree::
   :hidden:

   INSTALL
   notebooks


Core concepts (the philosophical bit)
=====================================

On a fundamental level, converting between different programming
styles in a low-level compiled language like Fortran or C/C++ typically
requires assumptions to be made that are specific to the data and algorithm
and do not generalize to the entire language. This is why Loki provides a
programmable interface rather than a push-button solution, leaving it
up to developers to decide which assumptions about the original source
code can be used and how.

For example, converting large amounts of IFS physics code to a "single column"
format (see below) requires the explicit knowledge of which index variables
typically represent the parallel dependency-free horizontal dimension that
is to be lifted.

The aim of Loki is therefore to give developers all the tools to encode their
own code transformation in an elegant, pythonic fashion. The core concepts
provided for this are:

* :any:`Module` and :any:`Subroutine` classes (kernels) that each provide an
  :doc:`Intermediate Representation (IR) `
  of their source code, as well as
  utilities to inspect and transform the underlying IR nodes.
* Expressions contained in IR nodes, such as :any:`Assignment`, :any:`Loop`,
  and :any:`Conditional`, are represented as independent sub-trees, based on the
  `Pymbolic <https://github.com/inducer/pymbolic>`__ infrastructure.
* Three frontends are supported that are used to parse Fortran code
  either from source files or strings into the Loki IR trees. Multiple
  backends are provided to generate Fortran or (experimentally) C or (even more
  experimentally) Python code from the combined IR and expression trees.
* A :any:`Transformation` class is provided that allows users to encode
  individual code changes based on the abstract representation
  provided by Loki's IR and expression objects and can be applied
  to individual :any:`Subroutine` and :any:`Module` objects - much like simple
  compiler passes.
* A :any:`Scheduler` class that provides bulk processing
  and inter-procedural analysis (IPA) tools to apply individual changes
  over large numbers of files while honoring the call-tree that
  connects them.

Example transformations and current features
============================================

Loki is primarily an API and toolbox, allowing developers to create their
own head scripts and to create and invoke source-to-source translation toolchains.
In addition, a set of supported transformations is provided by the
package itself in :mod:`loki.transformations`. These range from utilities
that can be used with generic Fortran codes to highly bespoke transformations
for generating GPU code based on highly model-specific assumptions.

The ``loki-transform.py`` script is provided by the Loki installation. The primary
transformation passes provided by these example transformations are:

* **Idempotence (Idem)** - A simple transformation that performs a
  neutral parse-unparse cycle on a kernel.
* **Single column abstraction (SCA)** - Transforms a set of kernels
  into single-column format by removing the specified horizontal
  iteration dimension. This transformation has a "driver" and a
  "kernel" mode, as it potentially changes the subroutine's call
  signature to remove derived types (structs do not expose
  dimensions).
* **Single column coalesced (SCC)** - Transforms a set of kernels
  from CPU-style (SIMD) vectorization format to a GPU-style (SIMT)
  loop layout. It removes the specified horizontal iteration
  dimension and re-inserts it outermost. Optionally, the horizontal
  loop can be stripped from kernels and re-inserted in the driver, to
  allow hoisting the allocation of temporary arrays to driver level (SCCH).
* **C transpilation** - A dedicated Fortran-to-C transpilation
  pipeline that converts Fortran source code into (column major,
  1-indexed) C kernel code. The transformation pipeline also creates
  the necessary header and `ISOC` wrappers to integrate this C kernel
  with a Fortran driver layer, as demonstrated with the
  `CLOUDSC ESCAPE dwarf `_.

First steps
===========

To start using Loki, follow the :doc:`installation instructions <INSTALL>`.
We recommend studying the :doc:`Jupyter notebooks <notebooks>` in the ``example``
directory to get familiar with the basic API of Loki. The
:doc:`Using Loki ` section provides more details on the inner
workings and underpinning concepts.

Contributions
=============

Contributions to Loki are welcome. In order to do so, please open an
issue in the `Github repository `__
where a feature request or bug can be discussed.
Then create a pull request with your contribution. We require you to read and sign the
`contributor license agreement (CLA) `__
before your contribution can be reviewed and merged.
loki-ecmwf-0.3.6/docs/source/loki_transform.rst0000664000175000017500000000035215167130205021754 0ustar  alastairalastair==============
loki-transform
==============

.. contents:: Contents
   :local:

The ``loki-transform.py`` command-line utility is the main entry point
for batch processing a set of source files with source-to-source transformations.
loki-ecmwf-0.3.6/docs/Makefile0000664000175000017500000000135715167130205016337 0ustar  alastairalastair# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Executable used by the "apidoc" target; like the variables above it can be
# overridden from the command line or (because of ?=) from the environment.
SPHINXAPIDOC ?= sphinx-apidoc

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile apidoc

# Regenerate the API reference stubs for the ../loki package into
# $(SOURCEDIR): one page per module (-e), overwriting existing files (-f),
# and without generating a modules table-of-contents file (--no-toc).
apidoc:
	@$(SPHINXAPIDOC) -o "$(SOURCEDIR)" -e --no-toc -f ../loki

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
# Since "apidoc" is phony and a prerequisite here, every Sphinx build
# (e.g. "make html") regenerates the API stubs first.
%: Makefile apidoc
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
loki-ecmwf-0.3.6/docs/make.bat0000664000175000017500000000137415167130205016303 0ustar  alastairalastair@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation
REM
REM Usage: make.bat BUILDER, e.g. "make.bat html". Without an argument the
REM Sphinx help (list of available builders) is printed. The pushd above
REM makes the relative SOURCEDIR/BUILDDIR resolve against this script's
REM directory; every exit path must popd to restore the caller's directory.
REM
REM NOTE(review): unlike docs/Makefile, this script does not run
REM sphinx-apidoc before building, so the generated API stub pages are not
REM refreshed here. Confirm whether that is intended.

REM Default to sphinx-build on PATH unless SPHINXBUILD is already set.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

REM Probe for the Sphinx executable; errorlevel 9009 means the command
REM could not be found. On failure, popd before exiting so the caller is
REM not left inside the docs directory.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	popd
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
loki-ecmwf-0.3.6/.gitignore0000664000175000017500000000074515167130205015737 0ustar  alastairalastair*~
*.pyc
*.swp

# NFS artifacts and core dumps
.nfs*
core.*

loki.egg-info/*
lint_rules/lint_rules.egg-info/*
.eggs/*
.cache/*
.pytest_cache/*
.dacecache/*
/build
/lint_rules/build
/artifacts

# Installation artifacts
loki_env/*
loki-activate

# Docs
docs/build
docs/source/loki.*rst
docs/source/scripts.*rst
docs/source/loki-*.rst
docs/source/lint_rules.*
docs/source/raps_deps.rst
docs/source/transformations.*

# notebooks
example/.ipynb_checkpoints
example/my_module.F90
.DS_Store
loki-ecmwf-0.3.6/LICENSE0000664000175000017500000002477715167130205014767 0ustar  alastairalastair                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   Copyright 2018- ECMWF

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
loki-ecmwf-0.3.6/cmake/0000775000175000017500000000000015167130205015021 5ustar  alastairalastairloki-ecmwf-0.3.6/cmake/loki_transform.cmake0000664000175000017500000004035115167130205021057 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

include( loki_transform_helpers )

##############################################################################
# .rst:
#
# loki_transform
# ==============
#
# Invoke loki-transform.py using the given options.::
#
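#   loki_transform(
#       COMMAND <command>
#       OUTPUT <file1> [<file2> ...]
#       DEPENDS <dependency1> [<dependency2> ...]
#       MODE <mode>
#       CONFIG <config-file>
#       [CPP]
#       [FRONTEND <frontend>]
#       [BUILDDIR <build-path>]
#       [SOURCES <source1> [<source2> ...]]
#       [HEADERS <header1> [<header2> ...]]
#       [INCLUDES <include1> [<include2> ...]]
#       [DEFINITIONS <define1> [<define2> ...]]
#       [OMNI_INCLUDE <path1> [<path2> ...]]
#       [XMOD <path1> [<path2> ...]]
#   )
#
# Call ``loki-transform.py <command> ...`` with the provided arguments.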
# See ``loki-transform.py`` for a description of all options.
#
# Options
# -------
#
# :OUTPUT:     The output files generated by Loki. Providing them here allows
#              to declare dependencies on this command later.
# :DEPENDS:    The input files or targets this call depends on.
#
##############################################################################

function( loki_transform )

    # Keywords accepted by this function. _loki_transform_parse_args() is
    # called without arguments and presumably reads the resulting _PAR_*
    # variables and appends to _ARGS, so both names must stay as they are.
    set( options        CPP )
    set( oneValueArgs   COMMAND MODE FRONTEND CONFIG BUILDDIR )
    set( multiValueArgs OUTPUT DEPENDS SOURCES HEADERS INCLUDES DEFINITIONS OMNI_INCLUDE XMOD )
    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    # Validate the call: stray arguments first, then the mandatory keywords
    # COMMAND, OUTPUT and DEPENDS (ecbuild_critical aborts configuration,
    # so only the first problem found is reported)
    if( _PAR_UNPARSED_ARGUMENTS )
        ecbuild_critical( "Unknown keywords given to loki_transform(): \"${_PAR_UNPARSED_ARGUMENTS}\"")
    elseif( NOT _PAR_COMMAND )
        ecbuild_critical( "No COMMAND specified for loki_transform()" )
    elseif( NOT _PAR_OUTPUT )
        ecbuild_critical( "No OUTPUT specified for loki_transform()" )
    elseif( NOT _PAR_DEPENDS )
        ecbuild_critical( "No DEPENDS specified for loki_transform()" )
    endif()

    # Assemble the command line: the sub-command comes first, followed by the
    # options derived from the keyword arguments
    set( _ARGS ${_PAR_COMMAND} )
    _loki_transform_parse_args()
    if( _PAR_CPP )
        list( APPEND _ARGS --cpp )
    endif()

    # Make the transformation script and its environment available; this is
    # expected to provide _LOKI_TRANSFORM and _LOKI_TRANSFORM_DEPENDENCY
    _loki_transform_env_setup()

    ecbuild_debug( "COMMAND ${_LOKI_TRANSFORM} ${_ARGS}" )

    # Build-time rule that regenerates OUTPUT whenever any of DEPENDS changes
    add_custom_command(
        OUTPUT  ${_PAR_OUTPUT}
        COMMAND ${_LOKI_TRANSFORM} ${_ARGS}
        DEPENDS ${_PAR_DEPENDS} ${_LOKI_TRANSFORM_DEPENDENCY}
        COMMENT "[Loki] Pre-processing: command=${_PAR_COMMAND} mode=${_PAR_MODE} frontend=${_PAR_FRONTEND}"
    )

endfunction()

##############################################################################
# .rst:
#
# loki_transform_plan
# ===================
#
# Run Loki bulk transformation in plan mode.::
#
#   loki_transform_plan(
#       MODE <mode>
#       FRONTEND <frontend>
#       PLAN <plan-file>
#       [CPP]
#       [CONFIG <config-file>]
#       [BUILDDIR <build-path>]
#       [NO_SOURCEDIR | SOURCEDIR <source-path>]
#       [CALLGRAPH <callgraph-path>]
#       [SOURCES <source1> [<source2> ...]]
#       [HEADERS <header1> [<header2> ...]]
#   )
#
# Call ``loki-transform.py plan ...`` with the provided arguments.
# See ``loki-transform.py`` for a description of all options.
#
##############################################################################

function( loki_transform_plan )

    # Keywords accepted by this function
    set( options        NO_SOURCEDIR CPP )
    set( oneValueArgs   MODE FRONTEND CONFIG BUILDDIR SOURCEDIR CALLGRAPH PLAN )
    set( multiValueArgs SOURCES HEADERS )
    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    if( _PAR_UNPARSED_ARGUMENTS )
        ecbuild_critical( "Unknown keywords given to loki_transform_plan(): \"${_PAR_UNPARSED_ARGUMENTS}\"")
    endif()

    # Start from an empty argument list; the "plan" sub-command is inserted
    # directly into the command line further down
    set( _ARGS )
    _loki_transform_parse_args()

    if( _PAR_CPP )
        list( APPEND _ARGS --cpp )
    endif()

    # A source root is mandatory unless it is explicitly waived via NO_SOURCEDIR
    if( NOT _PAR_NO_SOURCEDIR AND NOT _PAR_SOURCEDIR )
        ecbuild_critical( "No SOURCEDIR specified for loki_transform_plan()" )
    elseif( NOT _PAR_NO_SOURCEDIR )
        list( APPEND _ARGS --root ${_PAR_SOURCEDIR} )
    endif()

    if( _PAR_CALLGRAPH )
        list( APPEND _ARGS --callgraph ${_PAR_CALLGRAPH} )
    endif()

    # The plan file is mandatory (ecbuild_critical aborts configuration)
    if( NOT _PAR_PLAN )
        ecbuild_critical( "No PLAN file specified for loki_transform_plan()" )
    endif()
    list( APPEND _ARGS --plan-file ${_PAR_PLAN} )

    _loki_transform_env_setup()

    # The plan is produced at configure time so that CMake learns which files
    # are affected before generating the build system; a non-zero exit status
    # of the command aborts the configuration
    ecbuild_info( "[Loki] Creating plan: mode=${_PAR_MODE} frontend=${_PAR_FRONTEND} config=${_PAR_CONFIG}" )
    ecbuild_debug( "COMMAND ${_LOKI_TRANSFORM_EXECUTABLE} plan ${_ARGS}" )

    execute_process(
        COMMAND                ${_LOKI_TRANSFORM_EXECUTABLE} plan ${_ARGS}
        WORKING_DIRECTORY      ${CMAKE_CURRENT_BINARY_DIR}
        COMMAND_ERROR_IS_FATAL ANY
        ECHO_ERROR_VARIABLE
    )

endfunction()

##############################################################################
# .rst:
#
# loki_update_target_sources
# ==========================
#
# A reusable utility to update the sources of a CMake target.::
#
#   loki_update_target_sources(
#      TARGET            <target>
#      REMOVE_SOURCES    <source1> [<source2> ...]
#      TRANSFORM_SOURCES <source1> [<source2> ...]
#      APPEND_SOURCES    <source1> [<source2> ...]
#      [COPY_UNMODIFIED]
#   )
#
##############################################################################

function( loki_update_target_sources )

    set( options COPY_UNMODIFIED )
    set( single_value_args TARGET )
    set( multi_value_args REMOVE_SOURCES TRANSFORM_SOURCES APPEND_SOURCES )

    cmake_parse_arguments( _PAR_LUTS "${options}" "${single_value_args}" "${multi_value_args}" ${ARGN} )

    if( _PAR_LUTS_UNPARSED_ARGUMENTS )
        ecbuild_critical( "Unknown keywords given to loki_update_target_sources(): \"${_PAR_LUTS_UNPARSED_ARGUMENTS}\"")
    endif()

    get_target_property( _target_sources ${_PAR_LUTS_TARGET} SOURCES )
    if( NOT _target_sources )
        # get_target_property() yields "<var>-NOTFOUND" for a target without
        # sources; treat that as an empty source list instead
        set( _target_sources "" )
    endif()

    # Exclude source files that Loki has re-generated.
    # Note, this is done explicitly here because the HEADER_FILE_ONLY
    # property is not always being honoured by CMake.
    # Each path is matched literally (regex metacharacters such as '.' or '+'
    # are escaped) and only as a complete trailing path, so that removing
    # 'foo.F90' removes 'foo.F90' and 'src/foo.F90' but not 'my_foo.F90'.
    foreach( source ${_PAR_LUTS_REMOVE_SOURCES} )
        string( REGEX REPLACE "([][+.*?()|^$])" "\\\\\\1" _source_regex "${source}" )
        list( FILTER _target_sources EXCLUDE REGEX "(^|/)${_source_regex}$" )
    endforeach()

    if( NOT _PAR_LUTS_COPY_UNMODIFIED )
        # Update the target source list
        set_property( TARGET ${_PAR_LUTS_TARGET} PROPERTY SOURCES ${_target_sources} )
    else()
        # Copy the unmodified source files to the build dir
        set( _target_sources_copy "" )
        foreach( source ${_target_sources} )
            get_filename_component( _source_name ${source} NAME )
            list( APPEND _target_sources_copy ${CMAKE_CURRENT_BINARY_DIR}/${_source_name} )
            ecbuild_debug( "[Loki] copy: ${source} -> ${CMAKE_CURRENT_BINARY_DIR}/${_source_name}" )
        endforeach()
        file( COPY ${_target_sources} DESTINATION ${CMAKE_CURRENT_BINARY_DIR} )

        # Mark the copied files as build-time generated
        set_source_files_properties( ${_target_sources_copy} PROPERTIES GENERATED TRUE )

        # Update the target source list
        set_property( TARGET ${_PAR_LUTS_TARGET} PROPERTY SOURCES ${_target_sources_copy} )
    endif()

    # Register the Loki-generated sources with the target. The guard checks
    # the list that is actually used below, so that nothing is registered
    # (and no empty file list is passed on) when there is nothing to append.
    list( LENGTH _PAR_LUTS_APPEND_SOURCES LOKI_APPEND_LENGTH )
    if ( LOKI_APPEND_LENGTH GREATER 0 )
        # Mark the generated stuff as build-time generated
        set_source_files_properties( ${_PAR_LUTS_APPEND_SOURCES} PROPERTIES GENERATED TRUE )

        # Add the Loki-generated sources to our target (CLAW is not called)
        target_sources( ${_PAR_LUTS_TARGET} PRIVATE ${_PAR_LUTS_APPEND_SOURCES} )
    endif()

    # Copy over compile flags for generated source. Note that this assumes
    # matching indexes between LOKI_SOURCES_TO_TRANSFORM and LOKI_SOURCES_TO_APPEND
    # to encode the source-to-source mapping. This matching is strictly enforced
    # in the `CMakePlannerTransformation`.
    loki_copy_compile_flags(
        ORIG_LIST ${_PAR_LUTS_TRANSFORM_SOURCES}
        NEW_LIST ${_PAR_LUTS_APPEND_SOURCES}
    )

    if( _PAR_LUTS_COPY_UNMODIFIED )
        loki_copy_compile_flags(
            ORIG_LIST ${_target_sources}
            NEW_LIST ${_target_sources_copy}
        )
    endif()

endfunction()

##############################################################################
# .rst:
#
# loki_transform_target
# ======================
#
# Apply Loki source transformations to sources in a CMake target.::
#
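#   loki_transform_target(
#       TARGET <target1> [<target2> ...]
#       [COMMAND <command>]
#       MODE <mode>
#       CONFIG <config-file>
#       PLAN <plan-file>
#       [CPP] [CPP_PLAN]
#       [FRONTEND <frontend>]
#       [SOURCES <source1> [<source2> ...]]
#       [HEADERS <header1> [<header2> ...]]
#       [DEFINITIONS <define1> [<define2> ...]]
#       [INCLUDES <include1> [<include2> ...]]
#       [NO_PLAN_SOURCEDIR] [COPY_UNMODIFIED]
#   )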
#
# Applies a Loki bulk transformation to the source files belonging to a particular
# CMake target according to the specified entry points in the ``config-file``.
#
# This is done via a call to ``loki-transform.py plan ...`` during configure,
# from which the specific additions and deletions of source objects within the
# target are derived. See ``loki_transform_plan`` for more details.
#
# Subsequently, the actual bulk transformation of source files is scheduled
# via ``loki-transform.py <command>``, where ``<command>`` is provided via ``COMMAND``.
# If none is given, this defaults to ``convert``.
#
# Preprocessing of source files during plan or transformation stage can be
# enabled using ``CPP_PLAN`` and ``CPP`` options, respectively.
#
# ``NO_PLAN_SOURCEDIR`` can optionally be specified to call the plan stage without
# an explicit root directory. That means, Loki will generate absolute paths in the
# CMake plan file. This requires the ``SOURCES`` of the target to transform also
# to be given with absolute paths, otherwise the file list update mechanism will not
# work as expected.
#
# See ``loki-transform.py`` for a description of all options.
#
##############################################################################

function( loki_transform_target )

    set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN )
    set( single_value_args COMMAND MODE FRONTEND CONFIG PLAN )
    set( multi_value_args TARGET SOURCES HEADERS DEFINITIONS INCLUDES )

    cmake_parse_arguments( _PAR_T "${options}" "${single_value_args}" "${multi_value_args}" ${ARGN} )

    if( _PAR_T_UNPARSED_ARGUMENTS )
        ecbuild_critical( "Unknown keywords given to loki_transform_target(): \"${_PAR_T_UNPARSED_ARGUMENTS}\"")
    endif()

    if( NOT _PAR_T_TARGET )
        ecbuild_critical( "The call to loki_transform_target() doesn't specify the TARGET." )
    endif()

    # Default transformation sub-command
    if( NOT _PAR_T_COMMAND )
        set( _PAR_T_COMMAND "convert" )
    endif()

    if( NOT _PAR_T_PLAN )
        ecbuild_critical( "No PLAN specified for loki_transform_target()" )
    endif()

    # The config file is copied with configure_file() below, which cannot work
    # without one, so fail early with a clear message instead
    if( NOT _PAR_T_CONFIG )
        ecbuild_critical( "No CONFIG specified for loki_transform_target()" )
    endif()

    ecbuild_info( "[Loki] Loki scheduler:: target=${_PAR_T_TARGET} mode=${_PAR_T_MODE} frontend=${_PAR_T_FRONTEND}")

    # Ensure that changes to the config file trigger the planning stage:
    # configure_file() registers its input as a configure dependency. The copies
    # are not used by this function, so COPYONLY avoids expanding any ${VAR} or
    # @VAR@ references that the config file may contain.
    foreach( target ${_PAR_T_TARGET} )
        configure_file( ${_PAR_T_CONFIG} ${CMAKE_CURRENT_BINARY_DIR}/loki_${target}.config COPYONLY )
    endforeach()

    # Create the bulk-transformation plan
    set( _PLAN_OPTIONS "" )
    if( _PAR_T_CPP_PLAN )
        list( APPEND _PLAN_OPTIONS CPP )
    endif()
    if( _PAR_T_NO_PLAN_SOURCEDIR )
        list( APPEND _PLAN_OPTIONS NO_SOURCEDIR )
    endif()

    # One callgraph per (list of) target(s), e.g. callgraph_a_b for TARGET a b
    string(REPLACE ";" "_" CALLGRAPH_NAME "${_PAR_T_TARGET}")

    loki_transform_plan(
        MODE      ${_PAR_T_MODE}
        CONFIG    ${_PAR_T_CONFIG}
        FRONTEND  ${_PAR_T_FRONTEND}
        SOURCES   ${_PAR_T_SOURCES}
        PLAN      ${_PAR_T_PLAN}
        CALLGRAPH ${CMAKE_CURRENT_BINARY_DIR}/callgraph_${CALLGRAPH_NAME}
        BUILDDIR  ${CMAKE_CURRENT_BINARY_DIR}
        SOURCEDIR ${CMAKE_CURRENT_SOURCE_DIR}
        ${_PLAN_OPTIONS}
    )

    # The plan file defines the LOKI_SOURCES_TO_{TRANSFORM,APPEND,REMOVE} lists
    include("${_PAR_T_PLAN}")
    ecbuild_info( "[Loki] Imported transformation plan: ${_PAR_T_PLAN}" )
    ecbuild_debug( "[Loki] Loki all transform: ${LOKI_SOURCES_TO_TRANSFORM}" )
    ecbuild_debug( "[Loki] Loki all append: ${LOKI_SOURCES_TO_APPEND}" )
    ecbuild_debug( "[Loki] Loki all remove: ${LOKI_SOURCES_TO_REMOVE}" )

    # Schedule the source-to-source transformation on the source files from the schedule
    list( LENGTH LOKI_SOURCES_TO_TRANSFORM LOKI_APPEND_LENGTH )
    if ( LOKI_APPEND_LENGTH GREATER 0 )

        # Apply the bulk-transformation according to the plan
        set( _TRANSFORM_OPTIONS "" )
        if( _PAR_T_CPP )
            list( APPEND _TRANSFORM_OPTIONS CPP )
        endif()

        loki_transform(
            COMMAND     ${_PAR_T_COMMAND}
            OUTPUT      ${LOKI_SOURCES_TO_APPEND}
            MODE        ${_PAR_T_MODE}
            CONFIG      ${_PAR_T_CONFIG}
            FRONTEND    ${_PAR_T_FRONTEND}
            BUILDDIR    ${CMAKE_CURRENT_BINARY_DIR}
            SOURCES     ${_PAR_T_SOURCES}
            HEADERS     ${_PAR_T_HEADERS}
            DEFINITIONS ${_PAR_T_DEFINITIONS}
            INCLUDES    ${_PAR_T_INCLUDES}
            DEPENDS     ${LOKI_SOURCES_TO_TRANSFORM} ${_PAR_T_HEADERS} ${_PAR_T_CONFIG}
            ${_TRANSFORM_OPTIONS}
        )
    endif()

    # With several targets, the plan holds per-target lists suffixed with
    # "_<sanitized target name>"; with a single target the postfix list stays
    # empty, so ZIP_LISTS below yields an empty postfix and the unsuffixed
    # lists are used.
    set(_TARGETS_POSTFIX "")
    list( LENGTH _PAR_T_TARGET TARGETS_LENGTH )
    if (TARGETS_LENGTH GREATER 1)
        foreach(_target ${_PAR_T_TARGET})
            # sanitize target name, i.e., replace '.' with '_' (the same happens within Loki)
            string(REPLACE "." "_" _sanitized_target ${_target})
            list( APPEND _TARGETS_POSTFIX "_${_sanitized_target}")
        endforeach()
    endif()

    unset( _UPDATE_TARGET_SOURCES_OPTIONS )
    if( _PAR_T_COPY_UNMODIFIED )
       list( APPEND _UPDATE_TARGET_SOURCES_OPTIONS COPY_UNMODIFIED )
    endif()

    foreach(_target _postfix IN ZIP_LISTS _PAR_T_TARGET _TARGETS_POSTFIX)
        ecbuild_debug( "[Loki] Loki ${_target} transform: ${LOKI_SOURCES_TO_TRANSFORM${_postfix}}")
        ecbuild_debug( "[Loki] Loki ${_target} append: ${LOKI_SOURCES_TO_APPEND${_postfix}}" )
        ecbuild_debug( "[Loki] Loki ${_target} remove: ${LOKI_SOURCES_TO_REMOVE${_postfix}}" )

        # update target sources using the plan
        loki_update_target_sources(
           TARGET            ${_target}
           REMOVE_SOURCES    ${LOKI_SOURCES_TO_REMOVE${_postfix}}
           TRANSFORM_SOURCES ${LOKI_SOURCES_TO_TRANSFORM${_postfix}}
           APPEND_SOURCES    ${LOKI_SOURCES_TO_APPEND${_postfix}}
           ${_UPDATE_TARGET_SOURCES_OPTIONS}
        )

    endforeach()
endfunction()


##############################################################################
# .rst:
#
# generate_xmod
# =============
#
# Call OMNI's F_Front on a file to generate its xml-parse tree and, as a
# side effect, xmod-file.::
#
#   generate_xmod(
#       OUTPUT <xml-file>
#       SOURCE <source>
#       [XMOD <xmod-dir1> [<xmod-dir2> ...]]
#       [DEPENDS <dependency1> [<dependency2> ...]]
#   )
#
# Note that the xmod-file will be located in the first path given to ``XMOD``.
#
##############################################################################
function( generate_xmod )

    set( options )
    set( oneValueArgs SOURCE OUTPUT )
    set( multiValueArgs XMOD DEPENDS )

    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    # Reject stray arguments instead of silently ignoring them, consistent
    # with the other Loki functions
    if( _PAR_UNPARSED_ARGUMENTS )
        ecbuild_critical( "Unknown keywords given to generate_xmod(): \"${_PAR_UNPARSED_ARGUMENTS}\"")
    endif()

    if( NOT _PAR_OUTPUT )
        ecbuild_critical( "No OUTPUT given for generate_xmod()" )
    endif()

    if( NOT _PAR_SOURCE )
        ecbuild_critical( "No SOURCE given for generate_xmod()" )
    endif()

    # -fleave-comment is always passed to F_Front
    set( _ARGS -fleave-comment )

    # Every XMOD path is passed to F_Front via -M; the xmod-file is placed in
    # the first of these paths
    foreach( _xmod_path ${_PAR_XMOD} )
        list( APPEND _ARGS -M ${_xmod_path} )
    endforeach()

    set( _F_FRONT_EXECUTABLE F_Front )

    # VERBATIM makes argument escaping independent of the generator/platform
    add_custom_command(
        OUTPUT ${_PAR_OUTPUT}
        COMMAND ${_F_FRONT_EXECUTABLE} ${_ARGS} -o ${_PAR_OUTPUT} ${_PAR_SOURCE}
        DEPENDS ${_PAR_SOURCE} ${_PAR_DEPENDS}
        COMMENT "[OMNI] Pre-processing: ${_PAR_SOURCE}"
        VERBATIM
    )

endfunction()
loki-ecmwf-0.3.6/cmake/loki_find_executables.cmake0000664000175000017500000000771715167130205022361 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

##############################################################################
#.rst:
#
# loki_find_executables
# =====================
#
# Find Loki's executable frontend scripts and make them available as
# (imported) targets. ::
#
#   loki_find_executables()
#
# It adds all scripts in the list `LOKI_EXECUTABLES` using `add_executable`,
# either by setting explicitly the path to the installed scripts or by
# searching for them using `find_program` if Loki is not being installed by CMake.
#
# In addition, if Loki is installed by CMake, an install rule is added that
# creates a symbolic link for each executable in ``${CMAKE_INSTALL_PREFIX}/bin``.
# Note that ``clawfc`` is not searched for or added by this macro.
#
# Input variables
# ---------------
#
# :LOKI_EXECUTABLES:    The names of all Loki executables.
# :loki_HAVE_NO_INSTALL: If True, Loki is considered not to be installed by
#                       CMake and all executables are searched for using
#                       `find_program`.
# :Python3_VENV_BIN:    The `bin` directory path of Loki's virtual environment.
#                       Executable scripts are used from this folder if
#                       `loki_HAVE_NO_INSTALL` is false.
# :Python3_VENV_NAME:   The name of the virtual environment, used to locate it
#                       under ``${CMAKE_INSTALL_PREFIX}/var`` upon installation.
#
##############################################################################
macro( loki_find_executables )

    ecbuild_debug( "LOKI_EXECUTABLES=${LOKI_EXECUTABLES}" )

    # Make Loki executables available as imported executable targets
    # (this is required for the macros in loki_transform to set up their environment)
    if( loki_HAVE_NO_INSTALL )

        # Make CLI executables available in add_custom_command by searching
        # for them on the $PATH using find_program
        foreach( _exe_name IN LISTS LOKI_EXECUTABLES )
            if( NOT TARGET ${_exe_name} )
                find_program( _exe_program NAMES ${_exe_name} )
                if( NOT _exe_program )
                    # The target is still created (as before), but report the
                    # problem now rather than as an obscure failure at build time
                    ecbuild_warn( "[Loki] Could not find executable ${_exe_name} using find_program" )
                endif()
                add_executable( ${_exe_name} IMPORTED GLOBAL )
                set_property( TARGET ${_exe_name} PROPERTY IMPORTED_LOCATION ${_exe_program} )
                ecbuild_debug( "Adding executable ${_exe_name} from ${_exe_program}" )
                # find_program caches its result; clear it so that the next
                # iteration actually searches for the next executable
                unset( _exe_program CACHE )
            endif()
        endforeach()

    else()

        # Loki is installed by CMake: the executables live in the virtual
        # environment's bin folder. Upon installation, create a bin directory
        # in the install prefix to hold symbolic links to them.
        install( CODE "
            file( MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/bin\" )
        ")

        # Make CLI executables available in add_custom_command by setting
        # their location to the virtual environment's bin folder
        foreach( _exe_name IN LISTS LOKI_EXECUTABLES )
            if( NOT TARGET ${_exe_name} )
                add_executable( ${_exe_name} IMPORTED GLOBAL )
                set_property( TARGET ${_exe_name} PROPERTY IMPORTED_LOCATION ${Python3_VENV_BIN}/${_exe_name} )
                ecbuild_debug( "Adding executable ${_exe_name} from ${Python3_VENV_BIN}/${_exe_name}" )
            endif()

            # Create symlinks for frontend scripts when actually installing Loki (in the CMake sense)
            install( CODE "
                file( REAL_PATH \${CMAKE_INSTALL_PREFIX}/var/${Python3_VENV_NAME}/bin _venv_bin )
                file( CREATE_LINK
                    \${_venv_bin}/${_exe_name}
                    \${CMAKE_INSTALL_PREFIX}/bin/${_exe_name}
                    SYMBOLIC
                )
            ")
        endforeach()

    endif()

endmacro()
loki-ecmwf-0.3.6/cmake/loki_python_macros.cmake0000664000175000017500000005301415167130205021731 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

##############################################################################
#.rst:
#
# loki_find_python_venv
# =====================
#
# Call ``find_package( Python3 )``, making sure to discover a specific
# virtual environment at the given location ``VENV_PATH``::
#
#   loki_find_python_venv( VENV_PATH <path> [ PYTHON_VERSION <version> ] )
#
# Options
# -------
# :VENV_PATH: The path to the virtual environment
# :PYTHON_VERSION: Optional specification of permissible Python versions for find_package
#
# Output variables
# ----------------
# :Python3_FOUND:       Exported into parent scope from FindPython3
# :Python3_EXECUTABLE:  Exported into parent scope from FindPython3
# :Python3_VENV_BIN:    The path to the virtual environment's `bin` directory
# :ENV{VIRTUAL_ENV}:    Environment variable with the virtual environment directory,
#                       emulating the activate script
#
##############################################################################

function( loki_find_python_venv )

    # Only single-value keywords are understood by this function
    set( options "" )
    set( oneValueArgs VENV_PATH PYTHON_VERSION )
    set( multiValueArgs "" )
    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    # Validate the call (FATAL_ERROR stops processing at the first problem)
    if( _PAR_UNPARSED_ARGUMENTS )
        message( FATAL_ERROR "Unknown keywords given to loki_find_python_venv(): \"${_PAR_UNPARSED_ARGUMENTS}\"" )
    elseif( NOT _PAR_VENV_PATH )
        message( FATAL_ERROR "No VENV_PATH provided to loki_find_python_venv()" )
    endif()

    # Mimic the venv's activate script by exporting VIRTUAL_ENV
    set( ENV{VIRTUAL_ENV} "${_PAR_VENV_PATH}" )

    # Restrict the FindPython3 search to the virtual environment:
    #  - only accept a virtual environment,
    #  - forget any previously discovered interpreter, since Python3_EXECUTABLE
    #    is also an input variable (see
    #    https://cmake.org/cmake/help/latest/module/FindPython.html#artifacts-specification),
    #  - point the search root at the venv, in case Python3_ROOT_DIR was
    #    passed as an argument at build-time
    set( Python3_FIND_VIRTUALENV ONLY )
    unset( Python3_EXECUTABLE )
    set( Python3_ROOT_DIR "${_PAR_VENV_PATH}" )

    find_package( Python3 ${_PAR_PYTHON_VERSION} COMPONENTS Interpreter REQUIRED )

    # Let the discovered interpreter report its own directory, which is the
    # virtual environment's bin folder
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -c "import sys; import os.path; print(os.path.dirname(sys.executable), end='')"
        OUTPUT_VARIABLE Python3_VENV_BIN
    )

    # Hand the results back to the caller
    set( Python3_FOUND ${Python3_FOUND} PARENT_SCOPE )
    set( Python3_EXECUTABLE ${Python3_EXECUTABLE} PARENT_SCOPE )
    set( Python3_VENV_BIN ${Python3_VENV_BIN} PARENT_SCOPE )

endfunction()

##############################################################################
#.rst:
#
# loki_create_python_venv
# =======================
#
# Discover a Python 3 interpreter and create a virtual environment named
# ``VENV_NAME`` in ``CMAKE_CURRENT_BINARY_DIR``. ::
#
#   loki_create_python_venv( VENV_NAME <name> [ PYTHON_VERSION <version> ] [ INSTALL_VENV ] )
#
# Installation procedure
# ----------------------
#
# Create a virtual environment at ``${CMAKE_CURRENT_BINARY_DIR}/<name>``
#
# Options
# -------
#
# :VENV_NAME: The name of the virtual environment directory, which is created in
#             ``${CMAKE_CURRENT_BINARY_DIR}``
# :PYTHON_VERSION: Optional specification of permissible Python versions for find_package
# :INSTALL_VENV: If provided, an equivalent virtual environment will also be created in
#                ``${CMAKE_INSTALL_PREFIX}/var/${VENV_NAME}`` upon installation
#
# Output variables
# ----------------
# :Python3_VENV_NAME:    The name of the virtual environment
# :Python3_INSTALL_VENV: Will be set with the value TRUE if INSTALL_VENV has been provided.
#
##############################################################################

function( loki_create_python_venv )

    set( options INSTALL_VENV )
    set( oneValueArgs VENV_NAME PYTHON_VERSION )
    set( multiValueArgs "" )

    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    if( _PAR_UNPARSED_ARGUMENTS )
        message( FATAL_ERROR "Unknown keywords given to loki_create_python_venv(): \"${_PAR_UNPARSED_ARGUMENTS}\"" )
    endif()

    if( NOT _PAR_VENV_NAME )
        message( FATAL_ERROR "No VENV_NAME provided to loki_create_python_venv()" )
    endif()

    # The virtual environment is always created in the current binary directory
    set( VENV_PATH "${CMAKE_CURRENT_BINARY_DIR}/${_PAR_VENV_NAME}" )

    # Discover only system install Python 3
    set( Python3_FIND_VIRTUALENV STANDARD )
    find_package( Python3 ${_PAR_PYTHON_VERSION} COMPONENTS Interpreter REQUIRED )

    # Ensure the activate script is writable in case the venv exists already
    if( EXISTS "${VENV_PATH}/bin/activate" )
        file( CHMOD "${VENV_PATH}/bin/activate" FILE_PERMISSIONS OWNER_READ OWNER_WRITE )
    endif()

    # Create a virtualenv. A failure is reported as a warning rather than an
    # error, so that configuration can still proceed when a usable venv exists
    # already; a venv without an interpreter is caught later by the REQUIRED
    # search in loki_find_python_venv().
    message( STATUS "Create Python virtual environment ${VENV_PATH}" )
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -m venv "${VENV_PATH}"
        RESULT_VARIABLE _VENV_RESULT
    )
    if( NOT _VENV_RESULT EQUAL 0 )
        message( WARNING "Creating Python virtual environment ${VENV_PATH} failed: ${_VENV_RESULT}" )
    endif()
    set( Python3_VENV_NAME "${_PAR_VENV_NAME}" PARENT_SCOPE )

    # Upon installation, we create an equivalent Python venv in the installation
    # directory, and report (without aborting the install) if that fails
    if( _PAR_INSTALL_VENV )
        install(
            CODE
                "execute_process( COMMAND ${Python3_EXECUTABLE} -m venv \${CMAKE_INSTALL_PREFIX}/var/${_PAR_VENV_NAME} RESULT_VARIABLE _RET )
                if( NOT _RET EQUAL 0 )
                    message( WARNING \"Creating Python virtual environment \${CMAKE_INSTALL_PREFIX}/var/${_PAR_VENV_NAME} failed: \${_RET}\" )
                endif()"
        )
        set( Python3_INSTALL_VENV TRUE PARENT_SCOPE )
    endif()

endfunction()

##############################################################################
#.rst:
#
# loki_setup_python_venv
# =======================
#
# Find Python 3, create a virtual environment and make it available. ::
#
#   loki_setup_python_venv( VENV_NAME <name> [ PYTHON_VERSION <version> ] [ INSTALL_VENV ] )
#
# It combines calls to ``loki_create_python_venv`` and ``loki_find_python_venv``
#
# Options
# -------
#
# :VENV_NAME: The name of the virtual environment directory, which is created in
#             ``${CMAKE_CURRENT_BINARY_DIR}``
# :PYTHON_VERSION: Optional specification of permissible Python versions for find_package
# :INSTALL_VENV: If provided, an equivalent virtual environment will also be created in
#                ``${CMAKE_INSTALL_PREFIX}/var/${VENV_NAME}`` upon installation
#
# Output variables
# ----------------
# :Python3_FOUND:        Exported into parent scope from FindPython3
# :Python3_EXECUTABLE:   Exported into parent scope from FindPython3
# :Python3_VENV_BIN:     The path to the virtual environment's `bin` directory
# :Python3_VENV_NAME:    The name of the virtual environment
# :Python3_INSTALL_VENV: Will be set with the value TRUE if INSTALL_VENV has been provided.
# :ENV{VIRTUAL_ENV}:     Environment variable with the virtual environment directory,
#                        emulating the activate script
#
##############################################################################

function( loki_setup_python_venv )

    # Keywords mirror those of loki_create_python_venv()
    set( options INSTALL_VENV )
    set( oneValueArgs VENV_NAME PYTHON_VERSION )
    set( multiValueArgs "" )
    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    # Validate the call (FATAL_ERROR stops processing at the first problem)
    if( _PAR_UNPARSED_ARGUMENTS )
        message( FATAL_ERROR "Unknown keywords given to loki_setup_python_venv(): \"${_PAR_UNPARSED_ARGUMENTS}\"" )
    elseif( NOT _PAR_VENV_NAME )
        message( FATAL_ERROR "No VENV_NAME provided to loki_setup_python_venv()" )
    endif()

    # The optional version constraint is forwarded to both steps below
    unset( _version_args )
    if( _PAR_PYTHON_VERSION )
        set( _version_args PYTHON_VERSION "${_PAR_PYTHON_VERSION}" )
    endif()

    # Step 1: create the virtual environment in the current binary directory
    set( _create_args VENV_NAME "${_PAR_VENV_NAME}" ${_version_args} )
    if( _PAR_INSTALL_VENV )
        list( APPEND _create_args INSTALL_VENV )
    endif()
    loki_create_python_venv( ${_create_args} )

    # Pass on what loki_create_python_venv() reported to this scope
    set( Python3_VENV_NAME "${Python3_VENV_NAME}" PARENT_SCOPE )
    if( DEFINED Python3_INSTALL_VENV )
        set( Python3_INSTALL_VENV "${Python3_INSTALL_VENV}" PARENT_SCOPE )
    endif()

    # Step 2: discover the interpreter inside the newly created environment
    set( _find_args VENV_PATH "${CMAKE_CURRENT_BINARY_DIR}/${_PAR_VENV_NAME}" ${_version_args} )
    loki_find_python_venv( ${_find_args} )

    # Pass on what loki_find_python_venv() reported to this scope
    set( Python3_FOUND ${Python3_FOUND} PARENT_SCOPE )
    set( Python3_EXECUTABLE ${Python3_EXECUTABLE} PARENT_SCOPE )
    set( Python3_VENV_BIN ${Python3_VENV_BIN} PARENT_SCOPE )

endfunction()

##############################################################################
#.rst:
#
# loki_download_python_wheels
# ===========================
#
# Download all dependencies for the given ``REQUIREMENT_SPEC`` and cache them in a
# wheelhouse at ``WHEELS_DIR``
#
#   loki_download_python_wheels(
#       REQUIREMENT_SPEC <spec>
#       [ WHEELS_DIR <path> ]
#       [ WHEEL_ARCH <arch> ]
#       [ WHEEL_PYTHON_VERSION <version> ]
#       [ PYTHON_VERSION <version> ] )
#
# Implementation note
# -------------------
#
# This function intentionally does not expose all PIP options directly, because the PIP command line
# interface allows option values to be specified via environment variables. These can therefore be used
# to further control the PIP behaviour, see https://pip.pypa.io/en/stable/cli/pip_download/
#
# Because PIP does not provide a mechanism for downloading PEP 518 build dependencies,
# this function also builds the wheel for the provided REQUIREMENT_SPEC instead of only downloading
# the required dependencies. See https://github.com/pypa/pip/issues/7863 for details.
# To provide a sane minimum, the setuptools and wheel packages are always downloaded.
#
# The provided PYTHON_VERSION is used to discover a Python interpreter matching the version
# specification when calling pip. To download wheels for specific platforms or Python versions,
# use the PIP_PLATFORM, PIP_PYTHON_VERSION, PIP_IMPLEMENTATION, or PIP_ABI environment variables.
#
# It is safe to call this function during an offline build, as long as all wheels are already
# available in the wheelhouse. A dry-run call to ``pip install`` is used to determine the need
# for any wheel downloads before executing the ``pip download`` command.
#
# Options
# -------
#
# :REQUIREMENT_SPEC: The requirement spec as given to ``pip download`` and ``pip wheel``
# :WHEELS_DIR: The path of the wheelhouse directory to cache the wheels. Defaults to
#              ``${CMAKE_CURRENT_BINARY_DIR}/wheelhouse``
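# :WHEEL_ARCH: Optional specification of architecture for which to download non-pure Python wheels
# :WHEEL_PYTHON_VERSION: Optional Python version for which to download wheels (passed to pip as
#                        ``--python-version``)
# :PYTHON_VERSION: Optional specification of permissible Python versions for find_package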
#
##############################################################################

function( loki_download_python_wheels )

    set( options "" )
    set( oneValueArgs REQUIREMENT_SPEC WHEELS_DIR WHEEL_ARCH WHEEL_PYTHON_VERSION PYTHON_VERSION )
    set( multiValueArgs "" )

    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    if( _PAR_UNPARSED_ARGUMENTS )
        message( FATAL_ERROR "Unknown keywords given to loki_download_python_wheels(): \"${_PAR_UNPARSED_ARGUMENTS}\"" )
    endif()

    if( NOT _PAR_REQUIREMENT_SPEC )
        message( FATAL_ERROR "No REQUIREMENT_SPEC provided to loki_download_python_wheels()" )
    endif()

    # Check for a suitable python interpreter
    find_package( Python3 ${_PAR_PYTHON_VERSION} COMPONENTS Interpreter REQUIRED QUIET )

    # If no wheelhouse dir is given, create one in the current binary directory
    if( _PAR_WHEELS_DIR )
        set( WHEELS_DIR "${_PAR_WHEELS_DIR}" )
    else()
        set( WHEELS_DIR "${CMAKE_CURRENT_BINARY_DIR}/wheelhouse" )
    endif()
    file( MAKE_DIRECTORY "${WHEELS_DIR}" )

    # Report the wheelhouse only once WHEELS_DIR has been resolved above
    message( STATUS "Checking for cached wheels in ${WHEELS_DIR}" )

    # Optional pip options to target a different platform/Python version.
    # Empty values and None/NONE mean "use the defaults of the current interpreter".
    unset( PIP_OPTIONS )
    if( _PAR_WHEEL_ARCH AND NOT _PAR_WHEEL_ARCH MATCHES None|NONE )
        # PIP does not recognize the platform tag anymore if it is enclosed
        # by quotes, thus we need to strip any spurious quotes from the value
        string( REPLACE "\"" "" _ARCH "${_PAR_WHEEL_ARCH}" )
        list( APPEND PIP_OPTIONS --platform=${_ARCH} )
    endif()
    if( _PAR_WHEEL_PYTHON_VERSION AND NOT _PAR_WHEEL_PYTHON_VERSION MATCHES None|NONE )
        # PIP does not recognize the Python version anymore if it is enclosed
        # by quotes, thus we need to strip any spurious quotes from the version
        string( REPLACE "\"" "" _PYTHON_VERSION "${_PAR_WHEEL_PYTHON_VERSION}" )
        list( APPEND PIP_OPTIONS --python-version=${_PYTHON_VERSION} )
    endif()
    if( PIP_OPTIONS )
        # pip only accepts --platform/--python-version together with --no-deps
        # (or --only-binary=:all:)
        list( APPEND PIP_OPTIONS --no-deps )
    endif()

    # We use a dry-run installation to check if all dependencies have already been downloaded
    set( _CMD
        ${Python3_EXECUTABLE} -m pip install
            --dry-run --break-system-packages
            --no-index --find-links "${WHEELS_DIR}" --only-binary :all:
            ${PIP_OPTIONS} ${_PAR_REQUIREMENT_SPEC}
    )
    execute_process(
        COMMAND ${_CMD}
        OUTPUT_QUIET ERROR_QUIET
        RESULT_VARIABLE _RET_VAL
    )

    if( "${_RET_VAL}" EQUAL "0" )

        message( STATUS "All dependency wheels for ${_PAR_REQUIREMENT_SPEC} found in cache" )

    else()

        message( STATUS "Downloading dependency wheels for ${_PAR_REQUIREMENT_SPEC} to ${WHEELS_DIR}" )

        # Download typical build dependencies for wheels: setuptools and wheel
        set( _CMD
            ${Python3_EXECUTABLE} -m pip download
            --disable-pip-version-check --dest "${WHEELS_DIR}"
            ${PIP_OPTIONS} setuptools>=75.0.0 wheel
        )
        execute_process(
            COMMAND ${_CMD}
            OUTPUT_QUIET
        )

        # Download dependencies for the specified REQUIREMENT_SPEC
        set( _CMD
            ${Python3_EXECUTABLE} -m pip download
            --disable-pip-version-check --dest "${WHEELS_DIR}"
            ${PIP_OPTIONS} ${_PAR_REQUIREMENT_SPEC}
        )
        execute_process(
            COMMAND ${_CMD}
            OUTPUT_QUIET
            RESULT_VARIABLE _RET_VAL
        )
        if( NOT "${_RET_VAL}" EQUAL "0" )
            message( WARNING "Downloading dependency wheels for ${_PAR_REQUIREMENT_SPEC} failed (pip returned ${_RET_VAL})" )
        endif()

        # Here we _build_ the package instead of just downloading its build dependencies. Sadly, this is necessary because
        # PIP does not yet provide a mechanism to download the build dependencies for PEP 518 packages.
        # See https://github.com/pypa/pip/issues/7863 for details.
        # When this is resolved, we should instead download only build dependencies here, which will defer the actual wheel
        # building to the `build_python_wheel` function
        execute_process(
            COMMAND
                ${Python3_EXECUTABLE} -m pip wheel
                    --disable-pip-version-check --wheel-dir "${WHEELS_DIR}"
                    ${_PAR_REQUIREMENT_SPEC}
            OUTPUT_QUIET
            RESULT_VARIABLE _RET_VAL
        )
        if( NOT "${_RET_VAL}" EQUAL "0" )
            message( WARNING "Building wheel for ${_PAR_REQUIREMENT_SPEC} failed (pip returned ${_RET_VAL})" )
        endif()

    endif()

endfunction()

##############################################################################
#.rst:
#
# loki_build_python_wheels
# ========================
#
# Build a Python wheel for the given ``REQUIREMENT_SPEC`` and store it in the
# specified ``BUILD_DIR``. This uses no online sources to download packages,
# any required dependencies must be available in ``WHEELS_DIR``.
# Use ``loki_download_python_wheels`` to make them available if necessary.
#
#   loki_build_python_wheels( REQUIREMENT_SPEC <spec> [ BUILD_DIR <path> ] [ WHEELS_DIR <path> ] [ PYTHON_VERSION <version> ] )
#
# Implementation note
# -------------------
#
# This function intentionally does not expose all PIP options directly, because the PIP command line
# interface allows specifying option values via environment variables. These can therefore be used
# to further control the PIP behaviour, see https://pip.pypa.io/en/stable/cli/pip_download/
#
# The provided PYTHON_VERSION is used to discover a Python interpreter matching the version
# specification when calling pip. To build wheels for specific platforms or Python versions,
# use the PIP_PLATFORM, PIP_PYTHON_VERSION, PIP_IMPLEMENTATION, or PIP_ABI environment variables.
#
# Options
# -------
#
# :REQUIREMENT_SPEC: The requirement spec as given to ``pip download`` and ``pip wheel``
# :BUILD_DIR: The path to store the built wheel. Defaults to ``${CMAKE_CURRENT_BINARY_DIR}/wheelhouse``
# :WHEELS_DIR: The path of the wheelhouse directory where to look for cached wheels. Defaults to ``BUILD_DIR``
# :PYTHON_VERSION: Optional specification of permissible Python versions for find_package
#
##############################################################################

function( loki_build_python_wheels )

    # PYTHON_VERSION is documented and forwarded to find_package below, so it
    # must be a recognised keyword
    set( options "" )
    set( oneValueArgs REQUIREMENT_SPEC WHEELS_DIR BUILD_DIR PYTHON_VERSION )
    set( multiValueArgs "" )

    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    if( _PAR_UNPARSED_ARGUMENTS )
        message( FATAL_ERROR "Unknown keywords given to loki_build_python_wheels(): \"${_PAR_UNPARSED_ARGUMENTS}\"" )
    endif()

    if( NOT _PAR_REQUIREMENT_SPEC )
        message( FATAL_ERROR "No REQUIREMENT_SPEC provided to loki_build_python_wheels()" )
    endif()

    message( STATUS "Building wheel for ${_PAR_REQUIREMENT_SPEC}" )

    # Check for a suitable python interpreter
    find_package( Python3 ${_PAR_PYTHON_VERSION} COMPONENTS Interpreter REQUIRED QUIET )

    # If no build dir is given, create one in the current binary directory
    if( _PAR_BUILD_DIR )
        set( BUILD_DIR "${_PAR_BUILD_DIR}" )
    else()
        set( BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/wheelhouse" )
    endif()
    file( MAKE_DIRECTORY "${BUILD_DIR}" )

    # If no wheelhouse is given, use the build directory
    if( _PAR_WHEELS_DIR )
        set( WHEELS_DIR "${_PAR_WHEELS_DIR}" )
    else()
        set( WHEELS_DIR "${BUILD_DIR}" )
    endif()
    file( MAKE_DIRECTORY "${WHEELS_DIR}" )

    # Offline build: dependencies are only taken from the wheelhouse
    execute_process(
        COMMAND
            ${Python3_EXECUTABLE} -m pip wheel
                --no-index --find-links "${WHEELS_DIR}" --wheel-dir "${BUILD_DIR}"
                ${_PAR_REQUIREMENT_SPEC}
    )

endfunction()

##############################################################################
#.rst:
#
# loki_install_python_package
# ===========================
#
# Install a Python package using the provided ``REQUIREMENT_SPEC``.
#
#   loki_install_python_package( REQUIREMENT_SPEC <spec> [ WHEELS_DIR <path> ] [ EDITABLE ] )
#
# This assumes that the ``Python3_EXECUTABLE`` has been made available to use, e.g.,
# via a ``find_package( Python3 )``, ``loki_find_python_venv()`` or ``loki_setup_python_venv()``.
#
# By default this will search for the package and its dependencies in the
# standard package index.
# Providing a wheelhouse ``WHEELS_DIR`` ensures that this installation is
# an offline operation, taking wheels only from the provided path.
# If required, these can be fetched explicitly via ``loki_download_python_wheels``.
#
# If Python3_INSTALL_VENV variable is set, the package will also be installed
# into the virtual environment at installation time.
#
# Implementation note
# -------------------
#
# This function intentionally does not expose all PIP options directly, because the PIP command line
# interface allows specifying option values via environment variables. These can therefore be used
# to further control the PIP behaviour, see https://pip.pypa.io/en/stable/cli/pip_download/
#
# The provided PYTHON_VERSION is used to discover a Python interpreter matching the version
# specification when calling pip. To build wheels for specific platforms or Python versions,
# use the PIP_PLATFORM, PIP_PYTHON_VERSION, PIP_IMPLEMENTATION, or PIP_ABI environment variables.
#
# Options
# -------
#
# :REQUIREMENT_SPEC: The requirement spec as given to ``pip download`` and ``pip wheel``
# :WHEELS_DIR: The path of the wheelhouse directory in which to look for cached wheels.
# :EDITABLE: Install the package in editable mode (passes ``-e`` to ``pip install``)
#
##############################################################################
function( loki_install_python_package )

    # PYTHON_VERSION is forwarded to find_package below, so accept it as a keyword
    set( options EDITABLE )
    set( oneValueArgs REQUIREMENT_SPEC WHEELS_DIR PYTHON_VERSION )
    set( multiValueArgs "" )

    cmake_parse_arguments( _PAR "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

    if( _PAR_UNPARSED_ARGUMENTS )
        message( FATAL_ERROR "Unknown keywords given to loki_install_python_package(): \"${_PAR_UNPARSED_ARGUMENTS}\"" )
    endif()

    if( NOT _PAR_REQUIREMENT_SPEC )
        message( FATAL_ERROR "No REQUIREMENT_SPEC provided to loki_install_python_package()" )
    endif()

    # Check for a suitable python interpreter
    find_package( Python3 ${_PAR_PYTHON_VERSION} COMPONENTS Interpreter REQUIRED QUIET )

    if( _PAR_WHEELS_DIR )
        # Force installation from provided wheelhouse
        set( INSTALL_OPTS --no-index "--find-links=${_PAR_WHEELS_DIR}" )
    else()
        # Default pip install
        set( INSTALL_OPTS --disable-pip-version-check )
    endif()

    if( _PAR_EDITABLE )
        # Must come last: pip's -e takes the requirement spec appended below as its value
        list( APPEND INSTALL_OPTS -e )
    endif()

    message( STATUS "Installing Python package ${_PAR_REQUIREMENT_SPEC}" )

    # Capture pip's output; echo it (and the command) only for verbose builds
    set( OUTPUT_OPTIONS OUTPUT_VARIABLE _OUTPUT ERROR_VARIABLE _OUTPUT )
    if( CMAKE_VERBOSE_MAKEFILE )
        list(
            APPEND
                OUTPUT_OPTIONS
            ECHO_OUTPUT_VARIABLE
            ECHO_ERROR_VARIABLE
            COMMAND_ECHO STDOUT
        )
    endif()

    # Install package
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -m pip install ${INSTALL_OPTS} ${_PAR_REQUIREMENT_SPEC}
        COMMAND_ERROR_IS_FATAL ANY
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        ${OUTPUT_OPTIONS}
    )

    # At install time, repeat the installation inside the installed virtual environment
    if( Python3_INSTALL_VENV )
        if( DEFINED ENV{SETUPTOOLS_SCM_PRETEND_VERSION} )
            install(CODE "set( ENV{SETUPTOOLS_SCM_PRETEND_VERSION} $ENV{SETUPTOOLS_SCM_PRETEND_VERSION})")
        endif()
        install(
            CODE
                "execute_process( COMMAND \${CMAKE_INSTALL_PREFIX}/var/${Python3_VENV_NAME}/bin/python -m pip install ${INSTALL_OPTS} ${_PAR_REQUIREMENT_SPEC} )"
        )
    endif()

endfunction()
loki-ecmwf-0.3.6/cmake/omni_compiler.cmake0000664000175000017500000000516615167130205020667 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

##############################################################################
# .rst:
#
# install_omni_compiler
# =====================
#
# Download and install OMNI Compiler. ::
#
#   install_omni_compiler(VERSION)
#
# Installation procedure
# ----------------------
#
# OMNI will be installed during the build step into the folder
# ``omni-compiler`` in the current binary directory ``${CMAKE_CURRENT_BINARY_DIR}``.
#
# Options
# -------
#
# :VERSION:     The git branch or tag to download
#
# Output variables
# ----------------
# :OMNI_DIR:    The directory into which OMNI has been installed
#
##############################################################################

include( FetchContent )
include( ExternalProject )

function(install_omni_compiler VERSION)

    # In-build installation prefix for OMNI. Set unconditionally so that the
    # ExternalProject install dir and the imported F_Front location are valid
    # even if the omni_compiler content has already been populated elsewhere.
    set( OMNI_DIR ${CMAKE_CURRENT_BINARY_DIR}/omni-compiler )
    message( STATUS "Downloading OMNI Compiler")

    # Bootstrap OpenJDK and Ant, if necessary
    add_subdirectory( cmake/cmake-jdk-ant )

    # Build OMNI Compiler
    FetchContent_Declare(
        omni_compiler
        GIT_REPOSITORY  https://github.com/omni-compiler/xcodeml-tools.git
        GIT_TAG         ${VERSION}
        GIT_SHALLOW     ON
    )

    # Need to fetch manually to be able to do an "in-build installation"
    FetchContent_GetProperties( omni_compiler )
    if( NOT omni_compiler_POPULATED )
        FetchContent_Populate( omni_compiler )
    endif()

    find_program(MAKE_EXECUTABLE NAMES gmake make mingw32-make REQUIRED)

    ExternalProject_Add(
        omni
        SOURCE_DIR ${omni_compiler_SOURCE_DIR}
        BINARY_DIR ${omni_compiler_BINARY_DIR}
        INSTALL_DIR ${OMNI_DIR}

        # Can skip this as FetchContent will take care of it at configure time
        DOWNLOAD_COMMAND ""
        UPDATE_COMMAND ""
        PATCH_COMMAND ""

        # Specify in-build installation target
        CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${OMNI_DIR} -DJAVA_HOME=${JAVA_HOME}
    )

    # Expose the OMNI frontend as an imported executable that depends on the build
    add_executable( F_Front IMPORTED GLOBAL )
    set_property( TARGET F_Front PROPERTY IMPORTED_LOCATION ${OMNI_DIR}/bin/F_Front )
    add_dependencies( F_Front omni )

    # Forward the installation directory to the parent scope
    set( OMNI_DIR ${OMNI_DIR} PARENT_SCOPE )

endfunction()
loki-ecmwf-0.3.6/cmake/loki_get_python_wheels.cmake0000664000175000017500000000172515167130205022575 0ustar  alastairalastair# (C) Copyright 2024- ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
cmake_minimum_required( VERSION 3.19 FATAL_ERROR )

# Make the Loki CMake macros located next to this script available
list( APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}" )
include( loki_python_macros )

# Inputs; being cache entries, they can be overridden with -D<VAR>=<value>
set( WHEELS_DIR "${CMAKE_CURRENT_BINARY_DIR}/wheels" CACHE PATH "Directory in which to store the downloaded wheels" )
set( REQUIREMENT_SPEC "${CMAKE_CURRENT_LIST_DIR}/.." CACHE STRING "Requirement spec for which to download wheels (default: the directory above this script)" )
set( LOKI_WHEEL_ARCH NONE CACHE STRING "Target platform for non-pure Python wheels (NONE: current platform)" )
set( LOKI_WHEEL_PYTHON_VERSION "" CACHE STRING "Target Python version for the wheels (empty: current interpreter)" )

loki_download_python_wheels(
    REQUIREMENT_SPEC        ${REQUIREMENT_SPEC}
    WHEELS_DIR              ${WHEELS_DIR}
    WHEEL_ARCH              ${LOKI_WHEEL_ARCH}
    WHEEL_PYTHON_VERSION    ${LOKI_WHEEL_PYTHON_VERSION}
)
loki-ecmwf-0.3.6/cmake/loki_transform_helpers.cmake0000664000175000017500000001556115167130205022606 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

#
# Utility macro to translate single value and multi value arguments
# in loki_transform to command line arguments for loki-transform.py
#
macro( _loki_transform_parse_args )

    # Appends command line options for loki-transform.py to the list _ARGS in
    # the calling scope, based on the caller's _PAR_* variables (as produced by
    # cmake_parse_arguments). Being a macro, all helper variables also land in
    # the caller's scope.

    if( _PAR_MODE )
        list( APPEND _ARGS --mode ${_PAR_MODE} )
    endif()

    if( _PAR_CONFIG )
        list( APPEND _ARGS --config ${_PAR_CONFIG} )
    endif()

    if( _PAR_BUILDDIR )
        file( MAKE_DIRECTORY ${_PAR_BUILDDIR} )
        list( APPEND _ARGS --build ${_PAR_BUILDDIR} )
    endif()

    if( _PAR_DIRECTIVE )
        list( APPEND _ARGS --directive ${_PAR_DIRECTIVE} )
    endif()

    if( _PAR_FRONTEND )
        list( APPEND _ARGS --frontend ${_PAR_FRONTEND} )
    endif()

    # Multi-value options are repeated once per entry
    if( _PAR_SOURCES )
        foreach( _SOURCE ${_PAR_SOURCES} )
            list( APPEND _ARGS --source ${_SOURCE} )
        endforeach()
    endif()

    if( _PAR_HEADERS )
        foreach( _HEADER ${_PAR_HEADERS} )
            list( APPEND _ARGS --header ${_HEADER} )
        endforeach()
    endif()

    if( _PAR_INCLUDES )
        foreach( _INCLUDE ${_PAR_INCLUDES} )
            list( APPEND _ARGS --include ${_INCLUDE} )
        endforeach()
    endif()

    if( _PAR_DEFINITIONS )
        foreach( _DEFINE ${_PAR_DEFINITIONS} )
            list( APPEND _ARGS --define ${_DEFINE} )
        endforeach()
    endif()

    if( _PAR_OMNI_INCLUDE )
        foreach( _OMNI_INCLUDE ${_PAR_OMNI_INCLUDE} )
            list( APPEND _ARGS --omni-include ${_OMNI_INCLUDE} )
        endforeach()
    endif()

    if( _PAR_XMOD )
        foreach( _XMOD ${_PAR_XMOD} )
            # Make sure each xmod directory handed to loki-transform.py exists,
            # mirroring the BUILDDIR handling above
            file( MAKE_DIRECTORY "${_XMOD}" )
            list( APPEND _ARGS --xmod ${_XMOD} )
        endforeach()
    endif()

endmacro()


##############################################################################

macro( _loki_transform_env_setup )

    # Determines how loki-transform.py is invoked from add_custom_command. Sets, in
    # the calling scope: _LOKI_TRANSFORM (the command to run),
    # _LOKI_TRANSFORM_DEPENDENCY (an explicit dependency, or empty), and the
    # intermediate lists _LOKI_TRANSFORM_ENV / _LOKI_TRANSFORM_PATH.

    # The full path of the loki-transform.py executable, if it is declared as a target.
    # Querying a non-existent target is a hard error, so guard the lookup; this keeps
    # the "already on the PATH" fallback below reachable.
    set( _LOKI_TRANSFORM_EXECUTABLE "" )
    if( TARGET loki-transform.py )
        get_target_property( _LOKI_TRANSFORM_EXECUTABLE loki-transform.py IMPORTED_LOCATION )
    endif()

    set( _LOKI_TRANSFORM_ENV )
    set( _LOKI_TRANSFORM_PATH )

    if( TARGET clawfc AND "${_PAR_FRONTEND}" STREQUAL "omni" )
        # Ugly hack but I don't have a better solution: We need to add F_FRONT
        # (which is installed in the same directory as clawfc) to the PATH, if
        # OMNI is used as a frontend. Hence we have to update the environment in the below
        # add_custom_command calls to loki-transform.py.
        get_target_property( _CLAWFC_EXECUTABLE clawfc IMPORTED_LOCATION )
        get_filename_component( _CLAWFC_LOCATION ${_CLAWFC_EXECUTABLE} DIRECTORY )
        list( APPEND _LOKI_TRANSFORM_PATH ${_CLAWFC_LOCATION} )
    endif()

    if( _PAR_OUTPATH AND ("${_PAR_FRONTEND}" STREQUAL "omni" OR "${_PAR_FRONTEND}" STREQUAL "ofp") )
        # With pre-processing, we may end up having a race condition on the preprocessed
        # source files in parallel builds. Ensuring we use the outpath of the call to Loki
        # should ensure in most cases that parallel builds write to different directories
        # Note: this does not affect Fparser as we don't have to write preprocessed files
        # to disk there
        list( APPEND _LOKI_TRANSFORM_ENV LOKI_TMP_DIR=${_PAR_OUTPATH} )
    endif()

    if( _LOKI_TRANSFORM_ENV OR _LOKI_TRANSFORM_PATH )
        if( TARGET loki-transform.py AND _LOKI_TRANSFORM_EXECUTABLE )
            # Unfortunately, an environment update breaks the CMake feature of recognizing
            # the executable in add_custom_command as a previously declared target, which would
            # enable choosing the correct path automatically. Therefore, we have to insert also
            # loki-transform.py into the PATH variable.
            get_filename_component( _LOKI_TRANSFORM_LOCATION ${_LOKI_TRANSFORM_EXECUTABLE} DIRECTORY )
            list( APPEND _LOKI_TRANSFORM_PATH ${_LOKI_TRANSFORM_LOCATION} )
        endif()

        # Prepend all declared paths to the configure-time PATH. Without any extra
        # paths, avoid a leading empty component (":..."), which PATH lookup treats
        # as the current working directory.
        if( _LOKI_TRANSFORM_PATH )
            string( REPLACE ";" ":" _LOKI_TRANSFORM_PATH "${_LOKI_TRANSFORM_PATH}" )
            list( APPEND _LOKI_TRANSFORM_ENV PATH=${_LOKI_TRANSFORM_PATH}:$ENV{PATH} )
        else()
            list( APPEND _LOKI_TRANSFORM_ENV PATH=$ENV{PATH} )
        endif()

        # Run loki-transform.py via the CMake ENV wrapper
        set( _LOKI_TRANSFORM ${CMAKE_COMMAND} -E env ${_LOKI_TRANSFORM_ENV} loki-transform.py )

        # Also, now it breaks the dependency chain and we have to declare manual dependencies on
        # loki-transform.py...
        set( _LOKI_TRANSFORM_DEPENDENCY loki-transform.py )
    else()
        # This is how it is meant to be: We can rely on CMake's ability to set the correct
        # path of loki-transform.py if it was declared as an executable before (otherwise it
        # will assume it has been already on the path when CMake was called)
        set( _LOKI_TRANSFORM loki-transform.py )
        set( _LOKI_TRANSFORM_DEPENDENCY "" )
    endif()

endmacro()

##############################################################################
# .rst:
#
# loki_copy_compile_flags
# =======================
#
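# Copy per-source compile flags from the files in ``ORIG_LIST`` to the
# corresponding files in ``NEW_LIST``::
#
#   loki_copy_compile_flags( ORIG_LIST <sources> NEW_LIST <sources> )
#
# ``ORIG_LIST`` and ``NEW_LIST`` must have the same length. Compile flags are
# copied per entry, i.e., the i-th entry of ``NEW_LIST`` receives the
# ``COMPILE_FLAGS`` and ``OVERRIDE_COMPILE_FLAGS`` properties (including their
# build-type specific variants) of the i-th entry of ``ORIG_LIST``.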
#
##############################################################################
function( loki_copy_compile_flags )

    set( options "" )
    set( single_value_args "" )
    set( multi_value_args ORIG_LIST NEW_LIST )

    cmake_parse_arguments( _PAR "${options}" "${single_value_args}" "${multi_value_args}" ${ARGN} )

    # Copy over compile flags for generated source. Note that this assumes
    # matching indexes between ORIG_LIST and NEW_LIST to encode the source-to-source mapping.
    list( LENGTH _PAR_ORIG_LIST nsources )
    list( LENGTH _PAR_NEW_LIST nnew )

    # Every original source needs a counterpart; report this clearly instead of
    # failing later with an out-of-range list(GET) error
    if( nnew LESS nsources )
        message( FATAL_ERROR "loki_copy_compile_flags(): NEW_LIST has ${nnew} entries but ORIG_LIST has ${nsources}" )
    endif()

    if( nsources GREATER 0 )
        math( EXPR maxidx "${nsources} - 1" )
        foreach( idx RANGE 0 ${maxidx} )
            list( GET _PAR_ORIG_LIST ${idx} orig )
            list( GET _PAR_NEW_LIST ${idx} newsrc )

            ecbuild_debug( "[Loki] loki_copy_compile_flags: ${orig} -> ${newsrc}" )

            foreach( _prop COMPILE_FLAGS
                     COMPILE_FLAGS_${CMAKE_BUILD_TYPE_CAPS}
                     OVERRIDE_COMPILE_FLAGS
                     OVERRIDE_COMPILE_FLAGS_${CMAKE_BUILD_TYPE_CAPS} )

                get_source_file_property( _flags ${orig} ${_prop} )
                if( _flags )
                    # Quoted so that a value containing ';' is copied as a single property value
                    set_source_files_properties( ${newsrc} PROPERTIES ${_prop} "${_flags}" )
                endif()
            endforeach()
        endforeach()
    endif()

endfunction()

##############################################################################
loki-ecmwf-0.3.6/cmake/cmake-jdk-ant/0000775000175000017500000000000015167130205017427 5ustar  alastairalastairloki-ecmwf-0.3.6/cmake/cmake-jdk-ant/README.md0000664000175000017500000000265315167130205020714 0ustar  alastairalastair# cmake-jdk-ant

Contact: Balthasar Reuter (balthasar.reuter@ecmwf.int)

A CMake project to bootstrap OpenJDK and Ant during the configuration phase.

Variables influencing the behaviour of this CMake configuration:

* `MINIMUM_JAVA_VERSION`: The minimum JDK and JRE version that should be available. If no Java is found, or the available version is too old, OpenJDK will be bootstrapped. Default: `11`
* `MINIMUM_ANT_VERSION`: The minimum Ant version that should be available. If no Ant is found, or the available version is too old, Ant will be bootstrapped. Default: `1.10`
* `FORCE_OPEN_JDK_INSTALL`: Force bootstrapping of OpenJDK, regardless of any available version. Default: `OFF`
* `FORCE_ANT_INSTALL`: Force bootstrapping of Ant, regardless of any available version. Default: `OFF`
* `OPEN_JDK_INSTALL_VERSION`: The OpenJDK version to install. Default: `11.0.2`
* `ANT_INSTALL_VERSION`: The Ant version to install. Default: `1.10.15`
* `OPEN_JDK_MIRROR`: Alternative mirror from which to download OpenJDK. Default: `https://download.java.net/java/GA/jdk11/9/GPL/`
* `ANT_MIRROR`: Alternative mirror from which to download Ant. Default: `https://archive.apache.org/dist/ant/binaries/`

The purpose of this is to provide a way of on-the-fly installation of Java/Ant toolchain dependencies on systems where no usable setup is available.

## Example:

In a project that requires Java and Ant, add `cmake-jdk-ant` as a subdirectory:

```cmake
...
add_subdirectory( cmake-jdk-ant )
...
```

Subsequently, any calls to `find_package( Java )` will yield the bootstrapped toolchain.
loki-ecmwf-0.3.6/cmake/cmake-jdk-ant/cmake/0000775000175000017500000000000015167130205020507 5ustar  alastairalastairloki-ecmwf-0.3.6/cmake/cmake-jdk-ant/cmake/module/0000775000175000017500000000000015167130205021774 5ustar  alastairalastairloki-ecmwf-0.3.6/cmake/cmake-jdk-ant/cmake/module/FindAnt.cmake0000664000175000017500000000545115167130205024326 0ustar  alastairalastair#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

# This file has been adapted from the FindAnt.cmake module
# of CLAW compiler:
# https://github.com/claw-project/claw-compiler/blob/master/cmake/module/FindAnt.cmake


#  ANT_FOUND - system has Ant
#  Ant_EXECUTABLE - the Ant executable
#  Ant_VERSION - the Ant version
#
# It searches the ANT_HOME CMake variable or, if that does not name a directory,
# the ANT_HOME environment variable

include(FindPackageHandleStandardArgs)

# Determine the Ant installation to search: an ANT_HOME CMake variable takes
# precedence over the ANT_HOME environment variable; both must name a directory
set ( _ANT_HOME "" )
if ( ANT_HOME AND IS_DIRECTORY "${ANT_HOME}" )
    set ( _ANT_HOME "${ANT_HOME}" )
elseif ( DEFINED ENV{ANT_HOME} )
    file ( TO_CMAKE_PATH "$ENV{ANT_HOME}" _ENV_ANT_HOME )
    if ( _ENV_ANT_HOME AND IS_DIRECTORY "${_ENV_ANT_HOME}" )
        set ( _ANT_HOME "${_ENV_ANT_HOME}" )
    endif ()
    unset ( _ENV_ANT_HOME )
endif ()

# Only pass a hint when an Ant home was found; an empty _ANT_HOME would
# otherwise turn into the hint "/bin", searched before PATH
if ( _ANT_HOME )
    find_program(Ant_EXECUTABLE NAMES ant HINTS "${_ANT_HOME}/bin")
else ()
    find_program(Ant_EXECUTABLE NAMES ant)
endif ()

unset ( _ANT_HOME )

if(Ant_EXECUTABLE)

    # Try to determine Ant version. Find modules run in the caller's scope,
    # hence the module-private variable names.
    execute_process(COMMAND ${Ant_EXECUTABLE} -version
        RESULT_VARIABLE _ANT_VERSION_RESULT
        OUTPUT_VARIABLE _ANT_VERSION_OUTPUT
        ERROR_VARIABLE _ANT_VERSION_OUTPUT
        OUTPUT_STRIP_TRAILING_WHITESPACE
        ERROR_STRIP_TRAILING_WHITESPACE
    )

    if( _ANT_VERSION_RESULT )
        message( STATUS "Warning, could not run ant -version")
        unset(Ant_EXECUTABLE CACHE)
        unset(Ant_VERSION)
    elseif( _ANT_VERSION_OUTPUT MATCHES "Apache Ant(.*)version ([0-9]+\\.[0-9]+\\.[0-9_.]+)(.*)" )
        # extract major/minor version and patch level from "ant -version" output
        set(Ant_VERSION_STRING "${CMAKE_MATCH_2}")
        string( REGEX REPLACE "([0-9]+).*" "\\1" Ant_VERSION_MAJOR "${Ant_VERSION_STRING}" )
        string( REGEX REPLACE "[0-9]+\\.([0-9]+).*" "\\1" Ant_VERSION_MINOR "${Ant_VERSION_STRING}" )
        string( REGEX REPLACE "[0-9]+\\.[0-9]+\\.([0-9]+).*" "\\1" Ant_VERSION_PATCH "${Ant_VERSION_STRING}" )
        set(Ant_VERSION ${Ant_VERSION_MAJOR}.${Ant_VERSION_MINOR}.${Ant_VERSION_PATCH})
    else()
        # Unrecognised output: report an empty version instead of a malformed ".."
        message( STATUS "Warning, could not parse ant -version output")
        set(Ant_VERSION "")
    endif()

    unset(_ANT_VERSION_RESULT)
    unset(_ANT_VERSION_OUTPUT)

endif()

find_package_handle_standard_args(Ant REQUIRED_VARS Ant_EXECUTABLE VERSION_VAR Ant_VERSION)
mark_as_advanced(Ant_EXECUTABLE)
loki-ecmwf-0.3.6/cmake/cmake-jdk-ant/CMakeLists.txt0000664000175000017500000001234015167130205022167 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# CMake 3.10 required to parse Java version
# CMake 3.14 required to use FetchContent_MakeAvailable
cmake_minimum_required( VERSION 3.14 )
project( cmake-jdk-ant VERSION 0.1 LANGUAGES NONE )

set( MINIMUM_JAVA_VERSION 11 CACHE STRING "Minimum Java version required" )
set( MINIMUM_ANT_VERSION 1.10 CACHE STRING "Minimum ant version required" )

set( FORCE_OPEN_JDK_INSTALL OFF CACHE BOOL "Force installation of OpenJDK" )
set( FORCE_ANT_INSTALL OFF CACHE BOOL "Force installation of Ant" )

set( OPEN_JDK_INSTALL_VERSION 11.0.2 CACHE STRING "OpenJDK version to install if Java >= ${MINIMUM_JAVA_VERSION} not found" )
set( OPEN_JDK_MIRROR https://download.java.net/java/GA/jdk11/9/GPL/ CACHE STRING "OpenJDK download mirror" )

set( ANT_INSTALL_VERSION 1.10.15 CACHE STRING "ant version to install if Ant >= ${MINIMUM_ANT_VERSION} not found" )
set( ANT_MIRROR https://archive.apache.org/dist/ant/binaries/ CACHE STRING "ant download mirror" )

list( APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/module" )

include( FetchContent )

# Names of the components to bootstrap via FetchContent. Start from an empty
# list so that a FETCH_CONTENT_LIST inherited from a parent scope cannot leak in.
set( FETCH_CONTENT_LIST "" )

find_package( Java ${MINIMUM_JAVA_VERSION} COMPONENTS Development )

if( Java_FOUND AND NOT FORCE_OPEN_JDK_INSTALL )

    # A suitable Java is available: record the environment's JAVA_HOME (if any)
    # as cache entry (no-op when JAVA_HOME is already cached)
    if( DEFINED ENV{JAVA_HOME} )
        set( JAVA_HOME $ENV{JAVA_HOME} CACHE STRING "" )
    endif()

else()

    # Bootstrap OpenJDK: drop stale cache entries first, so that the later
    # find_package( Java ) discovers the bootstrapped JDK
    unset( JAVA_HOME CACHE )
    unset( Java_JAVA_EXECUTABLE CACHE )

    # FetchContent downloads at configure time (ExternalProject_Add would defer this to build time)
    FetchContent_Declare(
        OpenJDK
        URL ${OPEN_JDK_MIRROR}/openjdk-${OPEN_JDK_INSTALL_VERSION}_linux-x64_bin.tar.gz
    )

    list( APPEND FETCH_CONTENT_LIST "OpenJDK" )
    message( STATUS "OpenJDK: Download and install version ${OPEN_JDK_INSTALL_VERSION}" )

endif()

find_package( Ant ${MINIMUM_ANT_VERSION} )

# Bootstrap Ant when a fresh install is forced or no suitable Ant was found
if( FORCE_ANT_INSTALL OR NOT Ant_FOUND )

    # Drop stale cache entries so that the later find_package( Ant ) picks up
    # the bootstrapped installation
    unset( ANT_HOME CACHE )
    unset( Ant_EXECUTABLE CACHE )

    list( APPEND FETCH_CONTENT_LIST "Ant" )
    message( STATUS "Ant: Download and install version ${ANT_INSTALL_VERSION}" )

    # FetchContent (rather than ExternalProject_Add) makes Ant available at configure time
    FetchContent_Declare(
        Ant
        URL ${ANT_MIRROR}/apache-ant-${ANT_INSTALL_VERSION}-bin.tar.gz
    )

endif()

if( FETCH_CONTENT_LIST )

    # Trigger the actual downloads
    FetchContent_MakeAvailable ( ${FETCH_CONTENT_LIST} )

    # Make the bootstrapped OpenJDK discoverable
    if( "OpenJDK" IN_LIST FETCH_CONTENT_LIST )
        FetchContent_GetProperties( OpenJDK SOURCE_DIR OPEN_JDK_SOURCE_DIR BINARY_DIR OPEN_JDK_BINARY_DIR )

        # Create wrapper scripts that set JAVA_HOME for Java binaries
        file( MAKE_DIRECTORY "${OPEN_JDK_BINARY_DIR}/bin" )
        foreach( _JAVA_BINARY java javac javah jar javadoc )
            file(
                WRITE "${OPEN_JDK_SOURCE_DIR}/${_JAVA_BINARY}"
                "#!/bin/bash
                JAVA_HOME=${OPEN_JDK_SOURCE_DIR} ${OPEN_JDK_SOURCE_DIR}/bin/${_JAVA_BINARY} \"$@\""
            )
            # file(WRITE) cannot set permissions, hence copy to set the executable bit
            file(
                COPY "${OPEN_JDK_SOURCE_DIR}/${_JAVA_BINARY}"
                DESTINATION "${OPEN_JDK_BINARY_DIR}/bin"
                FILE_PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ
            )
        endforeach()

        # Re-discover Java
        set( JAVA_HOME ${OPEN_JDK_BINARY_DIR} CACHE STRING "" )
        find_package( Java ${OPEN_JDK_INSTALL_VERSION} EXACT REQUIRED COMPONENTS Development )
    endif()

    # Re-discover Ant and fetch dependencies
    if( "Ant" IN_LIST FETCH_CONTENT_LIST )
        FetchContent_GetProperties( Ant SOURCE_DIR ANT_SOURCE_DIR BINARY_DIR ANT_BINARY_DIR )

        # Pass on the JAVA_HOME to use (cache/variable first, then environment)
        if( JAVA_HOME )
            set( _JAVA_HOME "JAVA_HOME=${JAVA_HOME}" )
        elseif( DEFINED ENV{JAVA_HOME} )
            set( _JAVA_HOME "JAVA_HOME=$ENV{JAVA_HOME}" )
        else()
            set( _JAVA_HOME "" )
        endif()

        # Create a wrapper script that sets ANT_HOME
        file(
            WRITE "${ANT_SOURCE_DIR}/ant"
            "#!/bin/bash
            ANT_HOME=${ANT_SOURCE_DIR} ${_JAVA_HOME} ${ANT_SOURCE_DIR}/bin/ant \"$@\""
        )
        file( MAKE_DIRECTORY "${ANT_BINARY_DIR}/bin" )
        file(
            COPY "${ANT_SOURCE_DIR}/ant"
            DESTINATION "${ANT_BINARY_DIR}/bin"
            FILE_PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ
        )

        # Download optional dependencies; failure is not fatal, but report it
        message( STATUS "Ant: Fetch dependencies" )
        execute_process(
            COMMAND "${ANT_BINARY_DIR}/bin/ant" -f "${ANT_SOURCE_DIR}/fetch.xml" -Ddest=optional
            OUTPUT_QUIET
            RESULT_VARIABLE _ANT_FETCH_RESULT
        )
        if( NOT "${_ANT_FETCH_RESULT}" EQUAL "0" )
            message( WARNING "Ant: Fetching optional dependencies failed (${_ANT_FETCH_RESULT})" )
        endif()

        # Re-discover ant
        set( ANT_HOME ${ANT_BINARY_DIR} CACHE STRING "" )
        find_package( Ant ${ANT_INSTALL_VERSION} EXACT REQUIRED )
    endif()

endif()

message ( VERBOSE "JAVA_HOME=\"${JAVA_HOME}\"" )
message ( VERBOSE "ANT_HOME=\"${ANT_HOME}\"" )
loki-ecmwf-0.3.6/cmake/CMakeLists.txt0000664000175000017500000000110115167130205017552 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# Collect every file below this directory (as paths relative to it) and
# register them as resources via ecbuild_add_resources
file(
    GLOB_RECURSE loki_support_files
    LIST_DIRECTORIES false
    RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    "*"
)

ecbuild_add_resources(
    TARGET ${PROJECT_NAME}_loki_support_files
    SOURCES_PACK ${loki_support_files}
)
loki-ecmwf-0.3.6/install0000775000175000017500000002337415167130205015346 0ustar  alastairalastair#!/usr/bin/env bash

# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

set -euo pipefail

hpc2020_java_version=11.0.6
hpc2020_python_version=3.11.10-01
hpc2020_cmake_version=3.28.3
hpc2020_meson_version=1.2.1
hpc2020_ninja_version=1.11.1

# Determine base path for loki: the directory containing this install script,
# which sits at the root of the Loki source tree.
# The git top-level is deliberately not used: when Loki lives inside a larger
# repository (e.g. vendored into a bundle) it would point at the wrong directory.
loki_path=$(realpath "$(dirname "$0")")

# Configuration default values
verbose=false
is_hpc2020=false
venv_path=
with_jdk=false
with_omni=false
with_docs=false
with_dace=false
with_tests=true
with_examples=true

# Helper functions
# Print the one-line usage synopsis to stderr
print_usage() {
  echo "Usage: $0 [-v] [--hpc2020] [--use-venv[=]<path>] [--with-*] [...]" >&2
}

# Print the full help text listing all options and their defaults.
# The defaults stated here must match the "Configuration default values" above.
print_usage_with_options() {
  echo "Loki install script. This installs Loki and selected dependencies."
  echo
  print_usage
  echo
  echo "Available options:"
  echo "  -h / --help                  Display this help message"
  echo "  -v                           Enable verbose output"
  echo "  --hpc2020                    Load HPC2020 (Atos) specific modules and settings"
  echo "  --use-venv[=]<path>          Use existing virtual environment at <path>"
  echo "  --with[out]-jdk              Install JDK instead of using system version (default: use system version)"
  echo "  --with[out]-omni             Install OMNI Compiler (default: disabled)"
  echo "  --with[out]-dace             Install DaCe (default: disabled)"
  echo "  --with[out]-tests            Install dependencies to run tests (default: enabled)"
  echo "  --with[out]-docs             Install dependencies to generate documentation (default: disabled)"
  echo "  --with[out]-examples         Install dependencies to run the example notebooks (default: enabled)"
}

# Print a banner announcing the installation step given as first argument
print_step() {
  echo "------------------------------------------------------"
  echo "  $1"
  echo "------------------------------------------------------"
}

# Parse arguments
# (see https://stackoverflow.com/a/7680682)
optspec=":hv-:"
while getopts "$optspec" optchar; do
  case "${optchar}" in
    -)
      case "${OPTARG}" in
        hpc2020)     # Load ECMWF HPC2020 (Atos) specific modules and settings
          is_hpc2020=true
          ;;
        use-venv)    # Specify existing virtual environment (path in the next argument)
          # Report a missing path clearly instead of dying with an
          # "unbound variable" error from `set -u` on ${!OPTIND}
          if (( OPTIND > $# )); then
            echo "Option '--use-venv' requires a path argument." >&2
            print_usage
            exit 1
          fi
          venv_path=$(realpath "${!OPTIND}")
          OPTIND=$(( OPTIND + 1 ))
          ;;
        use-venv=*)  # Specify existing virtual environment
          venv_path=$(realpath "${OPTARG#*=}")
          ;;
        with-jdk)    # Enable installation of Java
          with_jdk=true
          ;;
        without-jdk) # Disable installation of Java
          with_jdk=false
          ;;
        with-omni)   # Enable installation of OMNI
          with_omni=true
          ;;
        without-omni) # Disable installation of OMNI
          with_omni=false
          ;;
        with-dace)    # Enable installation of DaCe
          with_dace=true
          ;;
        without-dace) # Disable installation of DaCe
          with_dace=false
          ;;
        with-tests)    # Enable installation of dependencies for running tests
          with_tests=true
          ;;
        without-tests) # Disable installation of dependencies for running tests
          with_tests=false
          ;;
        with-docs)    # Enable installation of dependencies for docs generation
          with_docs=true
          ;;
        without-docs) # Disable installation of dependencies for docs generation
          with_docs=false
          ;;
        with-examples)    # Enable installation of dependencies for notebook examples
          with_examples=true
          ;;
        without-examples) # Disable installation of dependencies for notebook examples
          with_examples=false
          ;;
        help)
          print_usage_with_options
          exit 2
          ;;
        *)
          echo "Unknown option '--${OPTARG}'." >&2
          print_usage
          echo "Try '$0 -h' for more options."
          exit 1
          ;;
      esac
      ;;
    h)
      print_usage_with_options
      exit 2
      ;;
    v)
      verbose=true
      ;;
    *)
      echo "Unknown option '-${OPTARG}'." >&2
      print_usage
      echo "Try '$0 -h' for more options."
      exit 1
      ;;
  esac
done


# Print configuration
# Print configuration: list only the options that deviate from the defaults
if [ "$verbose" == true ]; then
  print_step "Installation configuration:"

  [[ "$is_hpc2020" == true ]]  && echo "    --hpc2020"
  [[ "$venv_path" != "" ]]   && echo "    --use-venv='$venv_path'"
  [[ "$with_jdk" == true ]]  && echo "    --with-jdk"
  [[ "$with_omni" == true ]] && echo "    --with-omni"
  [[ "$with_dace" == true ]] && echo "    --with-dace"            # default: disabled
  [[ "$with_tests" == false ]] && echo "    --without-tests"
  [[ "$with_docs" == true ]] && echo "    --with-docs"            # default: disabled
  [[ "$with_examples" == false ]] && echo "    --without-examples"
fi

# Load modules
if [ "$is_hpc2020" == true ]; then
  print_step "Loading HPC2020 modules and settings"

  # Pinned build-tool modules, swapped in one after the other
  hpc2020_modules=(
    "cmake/${hpc2020_cmake_version}"
    "meson/${hpc2020_meson_version}"
    "ninja/${hpc2020_ninja_version}"
  )
  # Java comes from the module system only if we do not install our own JDK
  if [ "$with_jdk" == false ]; then
    hpc2020_modules+=("java/${hpc2020_java_version}")
  fi
  # Python comes from the module system only if no existing venv was supplied
  if [ "$venv_path" == "" ]; then
    hpc2020_modules+=("python3/${hpc2020_python_version}")
  fi

  # Replace any loaded version of each module with the pinned one
  for module_spec in "${hpc2020_modules[@]}"; do
    module unload "${module_spec%%/*}"
    module load "${module_spec}"
  done

fi

#
# Create Python virtualenv
#

# Without a user-supplied venv, (re-)create one inside the Loki source tree
if [ -z "$venv_path" ]; then
  print_step "Creating virtualenv"
  venv_path=${loki_path}/loki_env
  # Make activate scripts from an earlier install user-writable again
  # (presumably so that `python3 -m venv` can overwrite them)
  for activate_script in activate activate.csh activate.fish Activate.ps1; do
    activate_script_path="${venv_path}/bin/${activate_script}"
    if [ -f "${activate_script_path}" ]; then
      chmod u+w "${activate_script_path}"
    fi
  done
  python3 -m venv "${venv_path}"
fi

#
# Activate Python virtualenv
#

print_step "Activating virtualenv"
source "${venv_path}/bin/activate"

#
# Install Loki with Python dependencies
#

print_step "Installing Loki and Python dependencies"

cd "$loki_path"

# Optional-dependency groups ("extras" in pyproject.toml) to install with Loki
pip_extras=()
[[ "$with_tests" == true ]] && pip_extras+=(tests)
[[ "$with_dace" == true ]] && pip_extras+=(dace)
[[ "$with_docs" == true ]] && pip_extras+=(docs)
[[ "$with_examples" == true ]] && pip_extras+=(examples)

# Build the pip extras suffix, e.g. "[tests,examples]", or nothing if no extras
# are selected. The array is only expanded when non-empty, because expanding an
# empty array aborts under `set -u` in bash < 4.4.
pip_opts=
if (( ${#pip_extras[@]} > 0 )); then
  pip_opts="[$(IFS=,; echo "${pip_extras[*]}")]"
fi

# Supply pretend version if not a git worktree
if [ ! -e .git ]; then
  export "SETUPTOOLS_SCM_PRETEND_VERSION=$(cat VERSION)"
fi


pip install --upgrade pip
# Quoted so that the "[...]" extras suffix is never subject to pathname expansion
pip install -e ".${pip_opts}"  # Installs Loki dev copy in editable mode
pip install -e ./lint_rules

#
# Install Java
#

# Download and unpack a private OpenJDK 11.0.2 into the virtualenv (--with-jdk)
if [ "$with_jdk" == true ]; then
  print_step "Downloading and installing JDK"

  # NOTE(review): the archive is the linux-x64 build, so --with-jdk presumably
  # only works on Linux x86_64 hosts — confirm before using elsewhere
  JDK_ARCHIVE=openjdk-11.0.2_linux-x64_bin.tar.gz
  JDK_URL=https://download.java.net/java/GA/jdk11/9/GPL/${JDK_ARCHIVE}
  JAVA_INSTALL_DIR=${venv_path}/opt/java
  # Exported so that later steps (OMNI configure, loki-activate) use this JDK
  export JAVA_HOME=${JAVA_INSTALL_DIR}/jdk-11.0.2

  # Remove any previous extraction and archive, then download and unpack afresh
  mkdir -p "${JAVA_INSTALL_DIR}"
  rm -rf "${JAVA_HOME}" "${JAVA_INSTALL_DIR}/${JDK_ARCHIVE}"
  cd "${JAVA_INSTALL_DIR}"
  wget -O "${JDK_ARCHIVE}" "${JDK_URL}"
  tar -xzf "${JDK_ARCHIVE}"
fi

#
# Install OMNI
#

if [ "$with_omni" == true ]; then
  print_step "Downloading and installing OMNI Compiler"

  OMNI_INSTALL_DIR=${venv_path}/opt/omni
  mkdir -p "${OMNI_INSTALL_DIR}"
  cd "${OMNI_INSTALL_DIR}"
  rm -rf xcodeml-tools
  git clone --recursive --single-branch https://github.com/omni-compiler/xcodeml-tools.git xcodeml-tools

  cd xcodeml-tools

  # Forward JAVA_HOME to configure if it is set. `${JAVA_HOME:-}` keeps `set -u`
  # from aborting when neither --with-jdk was given nor JAVA_HOME is exported.
  omni_opts=()
  if [[ -n "${JAVA_HOME:-}" ]]; then
    omni_opts+=("JAVA_HOME=${JAVA_HOME}")
  fi

  # A CMake install would be cleaner but they inject without good reason a -Werror for
  # GNU and Clang without writing good enough code that actually avoids any warnings...
  # cmake -S . -B build -DCMAKE_INSTALL_PREFIX="${OMNI_INSTALL_DIR}" "${omni_opts[@]}"
  # cmake --build build
  # cmake --install build
  # The ${arr[@]+...} form avoids an "unbound variable" error for an empty array
  # under `set -u` in bash < 4.4.
  ./configure --prefix="${OMNI_INSTALL_DIR}" ${omni_opts[@]+"${omni_opts[@]}"}
  # Separate commands on purpose: in `make && make install` a failing `make` is
  # exempt from `set -e`, so a broken build would go unnoticed.
  make
  make install
fi

#
# Writing loki-activate script
#

print_step "Writing loki-activate script"

# Literal "${PATH}": it is expanded when loki-activate is sourced, not now
path_var=\${PATH}

# Start (truncate) loki-activate with a header and the venv activation. The venv
# path is quoted in the generated script so that paths with spaces work.
echo "
# This script activates Loki's virtual environment, loads additional modules and sets dependend paths.
#
# Run as 'source loki-activate'

# Load virtualenv
. \"${venv_path}/bin/activate\"
" > "${loki_path}/loki-activate"

# Load ECMWF modules, if required
if [ "${is_hpc2020}" == true ]; then
  # Modules that loki-activate swaps in, in this order
  activate_modules=()
  if [ "$with_jdk" == false ]; then
    activate_modules+=("java/${hpc2020_java_version}")
  fi
  activate_modules+=(
    "cmake/${hpc2020_cmake_version}"
    "meson/${hpc2020_meson_version}"
    "ninja/${hpc2020_ninja_version}"
  )

  # One unload/load pair per module, surrounded by blank lines
  for module_spec in "${activate_modules[@]}"; do
    printf '\nmodule unload %s\nmodule load %s\n\n' \
      "${module_spec%%/*}" "${module_spec}" >> "${loki_path}/loki-activate"
  done
fi

# Inject self-installed JDK into env
if [ "$with_jdk" == true ]; then
  # Expand JAVA_HOME now so that loki-activate points at the JDK installed above;
  # escaping it would write a no-op `export JAVA_HOME="${JAVA_HOME}"` instead.
  echo "
export JAVA_HOME=\"${JAVA_HOME}\"
" >> "${loki_path}/loki-activate"
  path_var=${JAVA_HOME}/bin:$path_var
fi

# Put the OMNI binaries in front of the search path of loki-activate
if [[ "$with_omni" == true ]]; then
  path_var="${OMNI_INSTALL_DIR}/bin:${path_var}"
fi

# Finish loki-activate: export the assembled PATH and print a confirmation
cat >> "${loki_path}/loki-activate" <<EOF

export PATH="${path_var}"

echo "Activated loki environment. Unload with 'deactivate'."

EOF

#
# Finish
#

print_step "Installation finished"
printf '\n%s\n\n%s\n\n' \
  "Activate the Loki environment with" \
  "    source loki-activate"

# Only point at the test suite if its dependencies were installed
if [[ "$with_tests" == true ]]; then
  printf '%s\n\n%s\n\n' \
    "You can test the Loki installation by running" \
    "    pytest --pyargs loki lint_rules"
fi
loki-ecmwf-0.3.6/CMakeLists.txt0000664000175000017500000001463215167130205016507 0ustar  alastairalastair# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

##############################################################################
#.rst:
#
# loki
# ====
#
# Install Loki with dependencies. ::
#
# Features
# --------
#
# :NO_INSTALL:  Do not install Loki itself but make the CMake configuration
#               available (Default: ``OFF``)
# :EDITABLE:    Install Loki as an editable package (Default: ``OFF``)
#
# Installation procedure
# ----------------------
#
# A virtual environment is created for Loki into which it is installed along
# with any dependencies. The CLI scripts ``loki-transform.py`` and ``loki-lint.py``
# are made available as executable targets, thus can be used from any subsequent
# ``add_custom_command`` statements.
#
##############################################################################

# Version 3.12 required to use FindPython
# Version 3.15 officially required to use Python3_FIND_VIRTUALENV (not working on 3.15.3,
# though, and use 3.17 for conda support anyway)
# Version 3.19 for support of find_package version range and file(CHMOD)
cmake_minimum_required( VERSION 3.19 )
find_package( ecbuild 3.7 REQUIRED )

# Declare the project; no compiled languages are needed (Loki is a Python package)
project( loki LANGUAGES NONE )

# Allow negating ENABLE_NO_INSTALL with a leading '~'
#
#   apply_negation( <var> )
#
# If <var> is defined and its value starts with '~', all '~' characters are
# removed and the remaining boolean is inverted: ``-DENABLE_NO_INSTALL=~ON``
# behaves like ``-DENABLE_NO_INSTALL=OFF`` (an empty remainder counts as false,
# giving ON). The result is set as a normal variable in the caller's scope.
# Values are quoted so that an empty value does not break the if() syntax and a
# value that happens to name another variable is not dereferenced.
macro( apply_negation VAR_NAME )
    if( DEFINED ${VAR_NAME} )
        if( "${${VAR_NAME}}" MATCHES "^~" )
            string( REPLACE "~" "" ${VAR_NAME} "${${VAR_NAME}}" )
            if( "${${VAR_NAME}}" )
                set( ${VAR_NAME} OFF )
            else()
                set( ${VAR_NAME} ON )
            endif()
        endif()
    endif()
endmacro()

apply_negation( ENABLE_NO_INSTALL )
apply_negation( LOKI_ENABLE_NO_INSTALL )

# Declare options
# Each FEATURE can be toggled via ENABLE_<FEATURE> (see apply_negation above for
# the '~' negation of NO_INSTALL) and is evaluated into HAVE_<FEATURE>, which is
# what the rest of this file tests (HAVE_NO_INSTALL, HAVE_EDITABLE, HAVE_OMNI).

# Skip installing Loki and its dependencies; only the CMake configuration is provided
ecbuild_add_option(
    FEATURE NO_INSTALL
    DEFAULT OFF
    DESCRIPTION "Disable Loki (and dependency) installation"
)
# Install the Loki Python package in editable (development) mode
ecbuild_add_option(
    FEATURE EDITABLE
    DEFAULT OFF
    DESCRIPTION "Install Loki as an editable Python package"
)
# Build the OMNI compiler frontend as part of the install step
ecbuild_add_option(
    FEATURE OMNI
    DEFAULT OFF
    DESCRIPTION "Build OMNI compiler as Loki frontend"
)

include( loki_transform )

# Make CMake script files available in build and install directory
add_subdirectory( cmake )
install(
    DIRECTORY cmake
    DESTINATION ${INSTALL_DATA_DIR}
    PATTERN "CMakeLists.txt" EXCLUDE
)

# The list of Loki frontend scripts: loki/cli/loki_<name>.py is exposed under
# the executable name loki-<name>.py
file( GLOB _LOKI_SCRIPTS "${CMAKE_CURRENT_SOURCE_DIR}/loki/cli/loki_*.py" )
list( TRANSFORM _LOKI_SCRIPTS REPLACE "loki/cli/loki_" "loki/cli/loki-" )

# Keep only the file names (strip everything up to the last '/')
set( LOKI_EXECUTABLES "${_LOKI_SCRIPTS}" )
list( TRANSFORM LOKI_EXECUTABLES REPLACE "^.*/" "" )

# Install Loki and dependencies (the whole step is skipped when NO_INSTALL is on)
if( NOT HAVE_NO_INSTALL )

    # Build the OMNI compiler frontend, if requested
    if( HAVE_OMNI )
        include( omni_compiler )
        install_omni_compiler( master )
    endif()

    # Create the Python virtual environment that Loki is installed into.
    # PYTHON_VERSION stays a directory-scope variable; it is also referenced
    # as @PYTHON_VERSION@ in loki-post-import.cmake.in.
    include( loki_python_macros )
    set( PYTHON_VERSION 3.9 )
    loki_setup_python_venv(
        VENV_NAME loki_env
        PYTHON_VERSION ${PYTHON_VERSION}
        INSTALL_VENV
    )

    # Register the Pytest suites as ecbuild/ctest tests
    if( HAVE_TESTS )

        # Pytest -k selector and PATH for the test runs: OFP tests are always
        # deselected; OMNI tests are deselected unless OMNI is built here, in
        # which case its bin directory is prepended to PATH instead. CMake tests
        # are deselected because nesting them into CTest does not correctly
        # resolve search paths.
        set( _TEST_SELECTOR "not ofp" )
        set( _TEST_PATH "$ENV{PATH}" )
        if( HAVE_OMNI )
            set( _TEST_PATH "${OMNI_DIR}/bin:${_TEST_PATH}" )
        else()
            string( APPEND _TEST_SELECTOR " and not omni" )
        endif()
        string( APPEND _TEST_SELECTOR " and not cmake" )

        # ecbuild_add_test still relies on the variables of the long-outdated
        # FindPythonInterp module; provide the bare minimum from FindPython3
        set( PYTHONINTERP_FOUND True )
        set( PYTHON_EXECUTABLE ${Python3_EXECUTABLE} )

        # Loki's own test suite, run from the source root
        ecbuild_add_test(
            TYPE PYTHON
            TARGET loki_tests
            ARGS -m pytest -k ${_TEST_SELECTOR} -v
            WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
            ENVIRONMENT PATH=${Python3_VENV_BIN}:${_TEST_PATH}
        )

        # Test suite of the lint_rules package
        ecbuild_add_test(
            TYPE PYTHON
            TARGET loki_lint_rules
            ARGS -m pytest -k ${_TEST_SELECTOR} -v
            WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/lint_rules
            ENVIRONMENT PATH=${Python3_VENV_BIN}:${_TEST_PATH}
        )

        # Running the tests needs the "tests" extra of the Loki package
        list( APPEND LOKI_INSTALL_OPTIONS "tests" )

    endif()

    # Outside a Git worktree setuptools_scm cannot derive the version itself,
    # so hand it the project version explicitly
    if( NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/.git )
        set( ENV{SETUPTOOLS_SCM_PRETEND_VERSION} ${loki_VERSION} )
    endif()

    # Extras suffix for the Loki requirement spec, e.g. "[tests]" (empty if none)
    set( _INSTALL_OPTIONS "" )
    if( LOKI_INSTALL_OPTIONS )
        list( JOIN LOKI_INSTALL_OPTIONS "," _INSTALL_OPT_STR )
        string( APPEND _INSTALL_OPTIONS "[${_INSTALL_OPT_STR}]" )
    endif()

    # Use ARTIFACTS_DIR as wheelhouse, if provided
    set( WHEELS_DIR_OPTION "" )
    if( DEFINED ARTIFACTS_DIR )
        set( WHEELS_DIR_OPTION WHEELS_DIR "${ARTIFACTS_DIR}" )
    endif()

    # Forward the EDITABLE feature to the install helper
    set( EDITABLE_OPTION "" )
    if( HAVE_EDITABLE )
        set( EDITABLE_OPTION "EDITABLE" )
    endif()

    # Loki is installed at configure time (for now): bulk-transformation planning
    # must run at configure time so that it can inject CMake targets.
    ecbuild_info( "Install Loki in virtual environment" )
    loki_install_python_package(
        REQUIREMENT_SPEC ${CMAKE_CURRENT_SOURCE_DIR}${_INSTALL_OPTIONS}
        ${EDITABLE_OPTION}
        ${WHEELS_DIR_OPTION}
    )
    loki_install_python_package(
        REQUIREMENT_SPEC ${CMAKE_CURRENT_SOURCE_DIR}/lint_rules
        ${EDITABLE_OPTION}
        ${WHEELS_DIR_OPTION}
    )
    ecbuild_info( "Install Loki in virtual environment - done" )

endif()

# Discover Loki executables and make available as CMake targets
# (this runs regardless of NO_INSTALL — presumably so that an externally
# provided Loki installation can be picked up; confirm)
include( loki_find_executables )
loki_find_executables()

# Install the project so it can be used within the bundle
ecbuild_install_project( NAME loki )

# print summary
ecbuild_print_summary()
loki-ecmwf-0.3.6/.github/0000775000175000017500000000000015167130205015301 5ustar  alastairalastairloki-ecmwf-0.3.6/.github/workflows/0000775000175000017500000000000015167130205017336 5ustar  alastairalastairloki-ecmwf-0.3.6/.github/workflows/tests.yml0000664000175000017500000000606615167130205021233 0ustar  alastairalastairname: tests

# Controls when the workflow will run
on:
  # Run for pushes to the main branch, but not for tag pushes
  push:
    branches:
      - 'main'
    tags-ignore:
      - '**'

  # Run for every pull request
  pull_request:

  # Allow manual runs from the Actions tab
  workflow_dispatch:

jobs:
  pytest:
    name: pytest

    strategy:
      fail-fast: false  # false: try to complete all jobs
      matrix:
        name:
          - linux gnu-14

        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']

        include:
          - name: linux gnu-14
            os: ubuntu-24.04
            # Enable --with-dace as soon as DaCe supports 3.13 and Numpy>2.0
            install-options: --with-omni --with-examples --with-tests --without-dace
            toolchain: {compiler: gcc, version: 14}
            pkg-dependencies: graphviz gfortran byacc flex cmake meson ninja-build

          - name: macos
            os: macos-14
            python-version: '3.13'
            install-options: --with-examples --with-tests --without-dace
            toolchain: {compiler: gcc, version: 14}
            pkg-dependencies: graphviz ninja meson

    runs-on: ${{ matrix.os }}

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Fortran compiler ${{ matrix.toolchain.compiler }} ${{ matrix.toolchain.version }}
        uses: fortran-lang/setup-fortran@v1
        with:
          compiler: ${{ matrix.toolchain.compiler }}
          version: ${{ matrix.toolchain.version }}

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'

      - name: Set up JDK
        uses: actions/setup-java@v4
        with:
          distribution: temurin
          java-version: 11

      - name: Install dependencies
        run: |
          if [[ "${{ matrix.os }}" =~ macos ]]; then
            brew install ${{ matrix.pkg-dependencies }}
          else
            # Refresh the package index first: the runner image's cached lists can
            # be stale, which makes `apt-get install` fail with 404 errors
            sudo apt-get -o Acquire::Retries=3 update || true
            sudo apt-get -o Acquire::Retries=3 install -y ${{ matrix.pkg-dependencies }}
          fi

      - name: Install Loki
        run: |
          ./install ${{ matrix.install-options }}

      - name: Run Loki tests
        run: |
          source loki-activate
          pytest -v -n 4 --cov=./loki --cov-report=xml --pyargs loki

      - name: Upload loki coverage report to Codecov
        uses: codecov/codecov-action@v4
        if: ${{ ! startsWith(matrix.os, 'macos') }}
        with:
          flags: loki
          files: ./coverage.xml
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

      - name: Run lint_rules tests
        run: |
          source loki-activate
          pytest -v --cov=./lint_rules/lint_rules --cov-report=xml lint_rules

      - name: Upload lint_rules coverage report to Codecov
        uses: codecov/codecov-action@v4
        if: ${{ ! startsWith(matrix.os, 'macos') }}
        with:
          flags: lint_rules
          files: ./coverage.xml
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
loki-ecmwf-0.3.6/.github/workflows/code_checks.yml0000664000175000017500000000246115167130205022316 0ustar  alastairalastairname: code-checks

# Controls when the workflow will run
on:
  # Run for pushes to the main branch, but not for tag pushes
  push:
    branches:
      - 'main'
    tags-ignore:
      - '**'

  # Run for every pull request
  pull_request:

  # Allow manual runs from the Actions tab
  workflow_dispatch:

jobs:
  code_checks:
    name: code checks

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false  # false: try to complete all jobs
      matrix:
        python-version:
          - "3.11"

    steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
        cache: 'pip'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        # Loki with its test and example extras, plus the lint_rules package
        pip install .[tests,examples] ./lint_rules
        pip list
    - name: Add pylint annotator
      uses: pr-annotators/pylint-pr-annotator@v0.0.1
    - name: Analysing the code with pylint
      run: |
        # The Loki package itself
        pylint --rcfile=.pylintrc loki
        # The lint_rules package and its tests, linted from inside lint_rules/
        pushd lint_rules && pylint --rcfile=../.pylintrc lint_rules tests; popd
        # The example notebooks, converted to plain scripts first
        jupyter nbconvert --to=script --output-dir=example_converted example/*.ipynb
        pylint --rcfile=.pylintrc_ipynb example_converted/*.py
loki-ecmwf-0.3.6/.github/workflows/regression_tests.yml0000664000175000017500000000541515167130205023470 0ustar  alastairalastairname: regression-tests

# Controls when the workflow will run
on:
  # Run for pushes to the main branch, but not for tag pushes
  push:
    branches:
      - 'main'
    tags-ignore:
      - '**'

  # Run for every pull request
  pull_request:

  # Allow manual runs from the Actions tab
  workflow_dispatch:

jobs:
  regression_tests:
    name: Python ${{ matrix.python-version }}

    strategy:
      fail-fast: false  # false: try to complete all jobs
      matrix:
        name:
          - linux gnu-13

        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']

        include:
          - name: linux gnu-13
            os: ubuntu-24.04
            install-options: --with-omni --with-tests --without-dace
            toolchain: {compiler: gcc, version: 13}
            pkg-dependencies: graphviz byacc flex cmake meson ninja-build libhdf5-dev libopenmpi-dev
            pip-dependencies: pyyaml fypp

    runs-on: ${{ matrix.os }}

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Clone CLOUDSC
        uses: actions/checkout@v4
        with:
          repository: ecmwf-ifs/dwarf-p-cloudsc
          path: cloudsc
          ref: develop

      - name: Clone CLOUDSC2 TL AD
        uses: actions/checkout@v4
        with:
          repository: ecmwf-ifs/dwarf-p-cloudsc2-tl-ad
          path: cloudsc2_tl_ad
          ref: develop

      - name: Clone ECWAM
        uses: actions/checkout@v4
        with:
          repository: ecmwf-ifs/ecwam
          path: ecwam
          ref: develop

      - name: Set up Fortran compiler ${{ matrix.toolchain.compiler }} ${{ matrix.toolchain.version }}
        uses: fortran-lang/setup-fortran@v1
        with:
          compiler: ${{ matrix.toolchain.compiler }}
          version: ${{ matrix.toolchain.version }}

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'

      - name: Set up JDK
        uses: actions/setup-java@v4
        with:
          distribution: temurin
          java-version: 11

      - name: Install dependencies
        run: |
            # Refresh the package index first: the runner image's cached lists can
            # be stale, which makes `apt-get install` fail with 404 errors
            sudo apt-get -o Acquire::Retries=3 update || true
            sudo apt-get -o Acquire::Retries=3 install -y ${{ matrix.pkg-dependencies }}
            pip install ${{ matrix.pip-dependencies }}

      - name: Install Loki
        run: |
          ./install ${{ matrix.install-options }}

      - name: Run CLOUDSC and ECWAM regression tests
        env:
          CLOUDSC_DIR: ${{ github.workspace }}/cloudsc
          CLOUDSC2_DIR: ${{ github.workspace }}/cloudsc2_tl_ad
          ECWAM_DIR: ${{ github.workspace }}/ecwam
          OMP_STACKSIZE: 4G
        run: |
          source loki-activate
          pytest -v -n 2 --pyargs loki.transformations -k 'cloudsc or ecwam'
loki-ecmwf-0.3.6/.github/workflows/documentation.yml0000664000175000017500000000622315167130205022735 0ustar  alastairalastairname: documentation

# Controls when the workflow will run
on:
  # Run for pushes to the main branch, but not for tag pushes
  push:
    branches:
      - 'main'
    tags-ignore:
      - '**'

  # Run for every pull request
  pull_request:

  # Allow manual runs from the Actions tab
  workflow_dispatch:

jobs:
  build:
    name: Build and upload documentation

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false  # false: try to complete all jobs
      matrix:
        python-version: ["3.10"]

    steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0

    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}

    - name: Install pandoc
      run: |
        sudo apt-get update || true
        sudo apt-get install -y pandoc

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install sites-toolkit -i https://get.ecmwf.int/repository/pypi-all/simple
        pip install .[docs]
        pip install ./lint_rules/

    - name: Build documentation
      working-directory: ./docs
      run: |
        make html

    # Per-PR preview, only for pull requests from the main repository
    - name: Upload pull request documentation to sites
      if: github.event.pull_request.head.repo.full_name == 'ecmwf-ifs/loki'
      env:
        SITES_TOKEN: ${{ secrets.SITES_TOKEN }}
      working-directory: ./docs
      run: |
        ./sites-manager.py --space=docs --name=loki --token "$SITES_TOKEN" upload build/html ${{ github.event.pull_request.number }} || true

    - name: Update documentation on sites
      if: github.event_name != 'pull_request'
      env:
        SITES_TOKEN: ${{ secrets.SITES_TOKEN }}
      working-directory: ./docs
      run: |
        ./sites-manager.py --space=docs --name=loki --token "$SITES_TOKEN" upload --clean build/html ${{ github.ref_name }} || true

    # The preview comment is only posted for pull requests from the main
    # repository, so only look it up in that case. Outside pull_request events
    # github.event.pull_request.number is empty, and find-comment would be
    # called without an issue number (e.g. manual runs from a branch).
    - name: Find Comment
      if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == 'ecmwf-ifs/loki'
      uses: peter-evans/find-comment@v2
      id: fc
      with:
        issue-number: ${{ github.event.pull_request.number }}
        comment-author: 'github-actions[bot]'
        body-includes: Documentation for this branch can be viewed at

    - name: Create or update comment
      if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == 'ecmwf-ifs/loki'
      uses: peter-evans/create-or-update-comment@v3
      with:
        comment-id: ${{ steps.fc.outputs.comment-id }}
        issue-number: ${{ github.event.pull_request.number }}
        body: |
          Documentation for this branch can be viewed at https://sites.ecmwf.int/docs/loki/${{ github.event.pull_request.number }}/index.html
        edit-mode: replace
loki-ecmwf-0.3.6/.github/workflows/documentation_clean-up.yml0000664000175000017500000000207015167130205024515 0ustar  alastairalastairname: documentation clean-up

# Controls when the workflow will run
on:
  # Run whenever a pull request is closed (merged or not)
  pull_request:
    types:
      - closed

jobs:
  clean:
    name: Clean-up branch documentation
    # Preview documentation is only uploaded for pull requests from the main
    # repository (same condition as the upload step in documentation.yml), so
    # there is nothing to delete for pull requests from forks, which also do
    # not receive SITES_TOKEN.
    if: github.event.pull_request.head.repo.full_name == 'ecmwf-ifs/loki'

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false  # false: try to complete all jobs
      matrix:
        python-version: ["3.10"]

    steps:
    - uses: actions/checkout@v4
      with:
        ref: main

    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install click
        pip install sites-toolkit -i https://get.ecmwf.int/repository/pypi-all/simple

    - name: Clean-up documentation on sites
      env:
        SITES_TOKEN: ${{ secrets.SITES_TOKEN }}
      working-directory: ./docs
      run: |
        ./sites-manager.py --space=docs --name=loki --token "$SITES_TOKEN" delete ${{ github.event.pull_request.number }} || true