==> pax_global_header <==
comment=645f46258515677620add3a60fc21a6bf6b27363
==> mir_eval-0.7/.coveragerc <==
[report]
show_missing = True
==> mir_eval-0.7/.gitignore <==
*.py[co]
# Packages
*.egg
*.egg-info
dist
build
eggs
parts
bin
var
sdist
develop-eggs
.installed.cfg
# Installer logs
pip-log.txt
# Unit test / coverage reports
.coverage
.tox
#Translations
*.mo
#Mr Developer
.mr.developer.cfg
# OS generated files #
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Vim
*.swp
# pycharm
.idea/*
# docs
docs/_build/*
# matplotlib tests
tests/result_images/*
==> mir_eval-0.7/.travis.yml <==
language: python
notifications:
email: false
python:
- "3.5"
before_install:
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
wget http://repo.continuum.io/miniconda/Miniconda-3.8.3-Linux-x86_64.sh -O miniconda.sh;
else
wget http://repo.continuum.io/miniconda/Miniconda3-3.8.3-Linux-x86_64.sh -O miniconda.sh;
fi
- bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
- hash -r
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
- deps='pip atlas numpy scipy sphinx nose six future pep8 matplotlib>=2.1.0,<3 decorator'
- conda create -q -n test-environment "python=$TRAVIS_PYTHON_VERSION" $deps
- source activate test-environment
- pip install python-coveralls
- pip install numpydoc
install:
- pip install -e .[display,testing]
script:
- nosetests -v --with-coverage --cover-package=mir_eval -w tests
- pep8 mir_eval tests
- python setup.py build_sphinx
- python setup.py egg_info -b.dev sdist --formats gztar
after_success:
- coveralls
==> mir_eval-0.7/LICENSE.txt <==
The MIT License (MIT)
Copyright (c) 2014 Colin Raffel
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
==> mir_eval-0.7/README.rst <==
.. image:: https://travis-ci.org/craffel/mir_eval.svg?branch=master
:target: https://travis-ci.org/craffel/mir_eval
.. image:: https://coveralls.io/repos/craffel/mir_eval/badge.svg?branch=master&service=github
:target: https://coveralls.io/github/craffel/mir_eval?branch=master
mir_eval
========
Python library for computing common heuristic accuracy scores for various music/audio information retrieval/signal processing tasks.
Documentation, including installation and usage information: http://craffel.github.io/mir_eval/
If you're looking for the mir_eval web service, which you can use to run mir_eval without installing anything or writing any code, it can be found here: http://labrosa.ee.columbia.edu/mir_eval/
Dependencies:
* Scipy/Numpy
* future
* six
If you use mir_eval in a research project, please cite the following paper:
Colin Raffel, Brian McFee, Eric J. Humphrey, Justin Salamon, Oriol Nieto, Dawen Liang, and Daniel P. W. Ellis, "mir_eval: A Transparent Implementation of Common MIR Metrics", Proceedings of the 15th International Conference on Music Information Retrieval, 2014.
==> mir_eval-0.7/docs/Makefile <==
# Makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
PAPER =
BUILDDIR = _build
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
endif
# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# the i18n builder cannot share the environment and doctrees with the others
I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
help:
@echo "Please use \`make ' where is one of"
@echo " html to make standalone HTML files"
@echo " dirhtml to make HTML files named index.html in directories"
@echo " singlehtml to make a single large HTML file"
@echo " pickle to make pickle files"
@echo " json to make JSON files"
@echo " htmlhelp to make HTML files and a HTML help project"
@echo " qthelp to make HTML files and a qthelp project"
@echo " devhelp to make HTML files and a Devhelp project"
@echo " epub to make an epub"
@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
@echo " latexpdf to make LaTeX files and run them through pdflatex"
@echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
@echo " text to make text files"
@echo " man to make manual pages"
@echo " texinfo to make Texinfo files"
@echo " info to make Texinfo files and run them through makeinfo"
@echo " gettext to make PO message catalogs"
@echo " changes to make an overview of all changed/added/deprecated items"
@echo " xml to make Docutils-native XML files"
@echo " pseudoxml to make pseudoxml-XML files for display purposes"
@echo " linkcheck to check all external links for integrity"
@echo " doctest to run all doctests embedded in the documentation (if enabled)"
clean:
rm -rf $(BUILDDIR)/*
html:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
singlehtml:
$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
pickle:
$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
@echo
@echo "Build finished; now you can process the pickle files."
json:
$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
@echo
@echo "Build finished; now you can process the JSON files."
htmlhelp:
$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
@echo
@echo "Build finished; now you can run HTML Help Workshop with the" \
".hhp project file in $(BUILDDIR)/htmlhelp."
qthelp:
$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
@echo
@echo "Build finished; now you can run "qcollectiongenerator" with the" \
".qhcp project file in $(BUILDDIR)/qthelp, like this:"
@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/mir_eval.qhcp"
@echo "To view the help file:"
@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/mir_eval.qhc"
devhelp:
$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
@echo
@echo "Build finished."
@echo "To view the help file:"
@echo "# mkdir -p $$HOME/.local/share/devhelp/mir_eval"
@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/mir_eval"
@echo "# devhelp"
epub:
$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
@echo
@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
latex:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo
@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
@echo "Run \`make' in that directory to run these through (pdf)latex" \
"(use \`make latexpdf' here to do that automatically)."
latexpdf:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through pdflatex..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
latexpdfja:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through platex and dvipdfmx..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
text:
$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
@echo
@echo "Build finished. The text files are in $(BUILDDIR)/text."
man:
$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
@echo
@echo "Build finished. The manual pages are in $(BUILDDIR)/man."
texinfo:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo
@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
@echo "Run \`make' in that directory to run these through makeinfo" \
"(use \`make info' here to do that automatically)."
info:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo "Running Texinfo files through makeinfo..."
make -C $(BUILDDIR)/texinfo info
@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
gettext:
$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
@echo
@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
changes:
$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
@echo
@echo "The overview file is in $(BUILDDIR)/changes."
linkcheck:
$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
@echo
@echo "Link check complete; look for any errors in the above output " \
"or in $(BUILDDIR)/linkcheck/output.txt."
doctest:
$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
xml:
$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
@echo
@echo "Build finished. The XML files are in $(BUILDDIR)/xml."
pseudoxml:
$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
@echo
@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
==> mir_eval-0.7/docs/changes.rst <==
Changes
=======
v0.7
----
- `#334`_: Support notation for unknown/ambiguous key or mode
- `#343`_: Add suite of alignment metrics
.. _#334: https://github.com/craffel/mir_eval/pull/334
.. _#343: https://github.com/craffel/mir_eval/pull/343
v0.6
----
- `#297`_: Return 0 when no overlap in transcription_velocity
- `#299`_: Allow one reference tempo and both estimate tempi to be zero
- `#301`_: Allow zero tolerance in tempo, but issue a warning
- `#302`_: Loosen separation test tolerance
- `#305`_: Use toarray instead of todense for sparse matrices
- `#307`_: Use tuple index in chord.rotate_bitmap_to_root
- `#309`_: Require matplotlib <3 for testing
- `#312`_: Fix raw chroma accuracy for unvoiced estimates
- `#320`_: Add comment support to io methods
- `#323`_: Fix interpolation in sonify.time_frequency
- `#324`_: Add generalized melody metrics
- `#327`_: Stop testing 2.7
- `#328`_: Cast n_voiced to int in display.multipitch
.. _#297: https://github.com/craffel/mir_eval/pull/297
.. _#299: https://github.com/craffel/mir_eval/pull/299
.. _#301: https://github.com/craffel/mir_eval/pull/301
.. _#302: https://github.com/craffel/mir_eval/pull/302
.. _#305: https://github.com/craffel/mir_eval/pull/305
.. _#307: https://github.com/craffel/mir_eval/pull/307
.. _#309: https://github.com/craffel/mir_eval/pull/309
.. _#312: https://github.com/craffel/mir_eval/pull/312
.. _#320: https://github.com/craffel/mir_eval/pull/320
.. _#323: https://github.com/craffel/mir_eval/pull/323
.. _#324: https://github.com/craffel/mir_eval/pull/324
.. _#327: https://github.com/craffel/mir_eval/pull/327
.. _#328: https://github.com/craffel/mir_eval/pull/328
v0.5
----
- `#222`_: added int cast for inferred length in sonify.clicks
- `#225`_: improved t-measures and l-measures
- `#227`_: added marginal flag to segment.nce
- `#234`_: update display to use matplotlib 2
- `#236`_: force integer division in beat.pscore
- `#240`_: fix unit tests for source separation
- `#242`_: use regexp in chord label validation
- `#245`_: add labeled interval formatter to display
- `#247`_: do not sonify negative amplitudes in time_frequency
- `#249`_: support gaps in util.interpolate_intervals
- `#252`_: add modulo and length arguments to chord.scale_degree_to_bitmap
- `#254`_: fix bss_eval_images single-frame fallback documentation
- `#255`_: fix crackle in sonify.time_frequency
- `#258`_: make util.match_events faster
- `#259`_: run pep8 check after nosetests
- `#263`_: add measures for chord over- and under-segmentation
- `#266`_: add amplitude parameter to sonify.pitch_contour
- `#268`_: update display tests to support mpl2.1
- `#277`_: update requirements and fix deprecations
- `#279`_: isolate matplotlib side effects
- `#282`_: remove evaluator scripts
- `#283`_: add transcription eval with velocity
.. _#222: https://github.com/craffel/mir_eval/pull/222
.. _#225: https://github.com/craffel/mir_eval/pull/225
.. _#227: https://github.com/craffel/mir_eval/pull/227
.. _#234: https://github.com/craffel/mir_eval/pull/234
.. _#236: https://github.com/craffel/mir_eval/pull/236
.. _#240: https://github.com/craffel/mir_eval/pull/240
.. _#242: https://github.com/craffel/mir_eval/pull/242
.. _#245: https://github.com/craffel/mir_eval/pull/245
.. _#247: https://github.com/craffel/mir_eval/pull/247
.. _#249: https://github.com/craffel/mir_eval/pull/249
.. _#252: https://github.com/craffel/mir_eval/pull/252
.. _#254: https://github.com/craffel/mir_eval/pull/254
.. _#255: https://github.com/craffel/mir_eval/pull/255
.. _#258: https://github.com/craffel/mir_eval/pull/258
.. _#259: https://github.com/craffel/mir_eval/pull/259
.. _#263: https://github.com/craffel/mir_eval/pull/263
.. _#266: https://github.com/craffel/mir_eval/pull/266
.. _#268: https://github.com/craffel/mir_eval/pull/268
.. _#277: https://github.com/craffel/mir_eval/pull/277
.. _#279: https://github.com/craffel/mir_eval/pull/279
.. _#282: https://github.com/craffel/mir_eval/pull/282
.. _#283: https://github.com/craffel/mir_eval/pull/283
v0.4
----
- `#189`_: expanded transcription metrics
- `#195`_: added pitch contour sonification
- `#196`_: added the `display` submodule
- `#203`_: support unsorted segment intervals
- `#205`_: correction in documentation for `sonify.time_frequency`
- `#208`_: refactored file/buffer loading
- `#210`_: added `io.load_tempo`
- `#212`_: added frame-wise blind-source separation evaluation
- `#218`_: speed up `melody.resample_melody_series` when times are equivalent
.. _#189: https://github.com/craffel/mir_eval/issues/189
.. _#195: https://github.com/craffel/mir_eval/issues/195
.. _#196: https://github.com/craffel/mir_eval/issues/196
.. _#203: https://github.com/craffel/mir_eval/issues/203
.. _#205: https://github.com/craffel/mir_eval/issues/205
.. _#208: https://github.com/craffel/mir_eval/issues/208
.. _#210: https://github.com/craffel/mir_eval/issues/210
.. _#212: https://github.com/craffel/mir_eval/issues/212
.. _#218: https://github.com/craffel/mir_eval/pull/218
v0.3
----
- `#170`_: implemented transcription metrics
- `#173`_: fixed a bug in chord sonification
- `#175`_: filter_kwargs passes through `**kwargs`
- `#181`_: added key detection metrics
.. _#170: https://github.com/craffel/mir_eval/issues/170
.. _#173: https://github.com/craffel/mir_eval/issues/173
.. _#175: https://github.com/craffel/mir_eval/issues/175
.. _#181: https://github.com/craffel/mir_eval/issues/181
v0.2
----
- `#103`_: incomplete files passed to `melody.evaluate` should warn
- `#109`_: `STRICT_BASS_INTERVALS` is now an argument to `chord.encode`
- `#122`_: improved handling of corner cases in beat tracking
- `#136`_: improved test coverage
- `#138`_: PEP8 compliance
- `#139`_: converted documentation to numpydoc style
- `#147`_: fixed a rounding error in segment intervals
- `#150`_: `sonify.chroma` and `sonify.chords` pass `kwargs` to `time_frequency`
- `#151`_: removed `labels` support from `util.boundaries_to_intervals`
- `#159`_: fixed documentation error in `chord.tetrads`
- `#160`_: fixed documentation error in `util.intervals_to_samples`
.. _#103: https://github.com/craffel/mir_eval/issues/103
.. _#109: https://github.com/craffel/mir_eval/issues/109
.. _#122: https://github.com/craffel/mir_eval/issues/122
.. _#136: https://github.com/craffel/mir_eval/issues/136
.. _#138: https://github.com/craffel/mir_eval/issues/138
.. _#139: https://github.com/craffel/mir_eval/issues/139
.. _#147: https://github.com/craffel/mir_eval/issues/147
.. _#150: https://github.com/craffel/mir_eval/issues/150
.. _#151: https://github.com/craffel/mir_eval/issues/151
.. _#159: https://github.com/craffel/mir_eval/issues/159
.. _#160: https://github.com/craffel/mir_eval/issues/160
v0.1
----
- Initial public release.
==> mir_eval-0.7/docs/conf.py <==
# -*- coding: utf-8 -*-
#
# mir_eval documentation build configuration file, created by
# sphinx-quickstart on Thu May 8 15:55:45 2014.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
import sys
import os
sys.path.insert(0, os.path.abspath('..'))
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.imgmath',
'numpydoc',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix of source filenames.
source_suffix = '.rst'
# The encoding of source files.
#source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = u'mir_eval'
copyright = u'2014, Colin Raffel et al.'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.7'
# The full version, including alpha/beta/rc tags.
release = '0.7'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
# The reST default role (used for this markup: `text`) to use for all
# documents.
#default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents.
#keep_warnings = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'default'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
#html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
#html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
#html_favicon = None
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation.
#html_extra_path = []
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
#html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to
# template names.
#html_additional_pages = {}
# If false, no module index is generated.
#html_domain_indices = True
# If false, no index is generated.
#html_use_index = True
# If true, the index is split into individual pages for each letter.
#html_split_index = False
# If true, links to the reST sources are added to the pages.
#html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
#html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None
# Output file base name for HTML help builder.
htmlhelp_basename = 'mir_evaldoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
('index', 'mir_eval.tex', u'mir\\_eval Documentation',
u'Colin Raffel et al.', 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
#latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
#latex_use_parts = False
# If true, show page references after internal links.
#latex_show_pagerefs = False
# If true, show URL addresses after external links.
#latex_show_urls = False
# Documents to append as an appendix to all manuals.
#latex_appendices = []
# If false, no module index is generated.
#latex_domain_indices = True
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
('index', 'mir_eval', u'mir_eval Documentation',
[u'Colin Raffel et al.'], 1)
]
# If true, show URL addresses after external links.
#man_show_urls = False
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
('index', 'mir_eval', u'mir_eval Documentation',
u'Colin Raffel et al.', 'mir_eval', 'One line description of project.',
'Miscellaneous'),
]
# Documents to append as an appendix to all manuals.
#texinfo_appendices = []
# If false, no module index is generated.
#texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False
autodoc_member_order = 'bysource'
==> mir_eval-0.7/docs/index.rst <==
**************************
``mir_eval`` Documentation
**************************
``mir_eval`` is a Python library which provides a transparent, standardized, and straightforward way to evaluate Music Information Retrieval systems.
If you use ``mir_eval`` in a research project, please cite the following paper:
C. Raffel, B. McFee, E. J. Humphrey, J. Salamon, O. Nieto, D. Liang, and D. P. W. Ellis, "mir_eval: A Transparent Implementation of Common MIR Metrics", Proceedings of the 15th International Conference on Music Information Retrieval, 2014.
.. _installation:
Installing ``mir_eval``
=======================
The simplest way to install ``mir_eval`` is by using ``pip``, which will also install the required dependencies if needed.
To install ``mir_eval`` using ``pip``, simply run
``pip install mir_eval``
Alternatively, you can install ``mir_eval`` from source by first installing the dependencies and then running
``python setup.py install``
from the source directory.
If you don't use Python and want to get started as quickly as possible, you might consider using Anaconda, which makes it easy to install a Python environment which can run ``mir_eval``.
Using ``mir_eval``
=============================================
Once you've installed ``mir_eval`` (see :ref:`installation`), you can import it in your Python code as follows:
``import mir_eval``
From here, you will typically either load in data and call the ``evaluate()`` function from the appropriate submodule like so::
reference_beats = mir_eval.io.load_events('reference_beats.txt')
estimated_beats = mir_eval.io.load_events('estimated_beats.txt')
# Scores will be a dict containing scores for all of the metrics
# implemented in mir_eval.beat. The keys are metric names
# and values are the scores achieved
scores = mir_eval.beat.evaluate(reference_beats, estimated_beats)
or you'll load in the data, do some preprocessing, and call specific metric functions from the appropriate submodule like so::
reference_beats = mir_eval.io.load_events('reference_beats.txt')
estimated_beats = mir_eval.io.load_events('estimated_beats.txt')
# Crop out beats before 5s, a common preprocessing step
reference_beats = mir_eval.beat.trim_beats(reference_beats)
estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
# Compute the F-measure metric and store it in f_measure
f_measure = mir_eval.beat.f_measure(reference_beats, estimated_beats)
The documentation for each metric function, found in the :ref:`mir_eval` section below, contains further usage information.
Alternatively, you can use the evaluator scripts which allow you to run evaluation from the command line, without writing any code.
These scripts are available here:
https://github.com/craffel/mir_evaluators
.. _mir_eval:
``mir_eval``
============
The structure of the ``mir_eval`` Python module is as follows:
Each MIR task for which evaluation metrics are included in ``mir_eval`` is given its own submodule, and each metric is defined as a separate function in each submodule.
Every metric function includes detailed documentation, example usage, input validation, and references to the original paper which defined the metric (see the subsections below).
The task submodules also all contain a function ``evaluate()``, which takes as input reference and estimated annotations and returns a dictionary of scores for all of the metrics implemented (for casual users, this is the place to start).
Finally, each task submodule also includes functions for common data pre-processing steps.
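For instance, every task submodule follows the same pattern, so evaluating onsets looks just like evaluating beats (a minimal sketch; the file names are hypothetical)::

    import mir_eval
    reference_onsets = mir_eval.io.load_events('reference_onsets.txt')
    estimated_onsets = mir_eval.io.load_events('estimated_onsets.txt')
    # evaluate() returns a dict mapping metric names to scores
    scores = mir_eval.onset.evaluate(reference_onsets, estimated_onsets)
    for metric_name, score in scores.items():
        print(metric_name, score)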
``mir_eval`` also includes the following additional submodules:
* :mod:`mir_eval.io` which contains convenience functions for loading in task-specific data from common file formats
* :mod:`mir_eval.util` which includes miscellaneous functionality shared across the submodules
* :mod:`mir_eval.sonify` which implements some simple methods for synthesizing annotations of various formats for "evaluation by ear".
* :mod:`mir_eval.display` which provides functions for plotting annotations for various tasks.
The following subsections document each submodule.
:mod:`mir_eval.beat`
--------------------
.. automodule:: mir_eval.beat
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.chord`
---------------------
.. automodule:: mir_eval.chord
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.melody`
----------------------
.. automodule:: mir_eval.melody
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.multipitch`
--------------------------
.. automodule:: mir_eval.multipitch
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.onset`
---------------------
.. automodule:: mir_eval.onset
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.pattern`
-----------------------
.. automodule:: mir_eval.pattern
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.segment`
-----------------------
.. automodule:: mir_eval.segment
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.hierarchy`
-------------------------
.. automodule:: mir_eval.hierarchy
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.separation`
--------------------------
.. automodule:: mir_eval.separation
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.tempo`
--------------------------
.. automodule:: mir_eval.tempo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.transcription`
-----------------------------
.. automodule:: mir_eval.transcription
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.transcription_velocity`
--------------------------------------
.. automodule:: mir_eval.transcription_velocity
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.key`
-----------------------------
.. automodule:: mir_eval.key
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.util`
--------------------
.. automodule:: mir_eval.util
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.io`
------------------
.. automodule:: mir_eval.io
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.sonify`
----------------------
.. automodule:: mir_eval.sonify
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:mod:`mir_eval.display`
-----------------------
.. automodule:: mir_eval.display
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Changes
=======
.. toctree::
:maxdepth: 1
changes
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
==> mir_eval-0.7/mir_eval/__init__.py <==
#!/usr/bin/env python
"""Top-level module for mir_eval"""
# Import all submodules (for each task)
from . import alignment
from . import beat
from . import chord
from . import io
from . import onset
from . import segment
from . import separation
from . import util
from . import sonify
from . import melody
from . import multipitch
from . import pattern
from . import tempo
from . import hierarchy
from . import transcription
from . import transcription_velocity
from . import key
__version__ = '0.7'
==> mir_eval-0.7/mir_eval/alignment.py <==
"""
Alignment models are given a sequence of events along with a piece of audio, and then return a
sequence of timestamps, with one timestamp for each event, indicating the position of this event
in the audio. The events are listed in order of occurrence in the audio, so that output
timestamps have to be monotonically increasing.
Evaluation usually involves taking the series of predicted and ground truth timestamps and
comparing their distance, usually on a pair-wise basis, e.g. taking the median absolute error in
seconds.
Conventions
-----------
Timestamps should be provided in the form of a 1-dimensional array of onset
times in seconds in increasing order.
Metrics
-------
* :func:`mir_eval.alignment.absolute_error`: Median absolute error and average absolute error
* :func:`mir_eval.alignment.percentage_correct`: Percentage of correct timestamps,
where a timestamp is counted
as correct if it lies within a certain tolerance window around the ground truth timestamp
* :func:`mir_eval.alignment.percentage_correct_segments`: Percentage of overlap between
predicted segments and ground truth segments, where segments are defined by (start time,
end time) pairs
* :func:`mir_eval.alignment.karaoke_perceptual_metric`: metric based on human synchronicity perception as
measured in the paper "User-centered evaluation of lyrics to audio alignment",
N. Lizé-Masclef, A. Vaglio, M. Moussallam, ISMIR 2021
References
----------
.. [#lizemasclef2021] N. Lizé-Masclef, A. Vaglio, M. Moussallam.
"User-centered evaluation of lyrics to audio alignment",
International Society for Music Information Retrieval (ISMIR) conference,
2021.
.. [#mauch2010] M. Mauch, H. Fujihara, M. Goto.
    "Lyrics-to-audio alignment and phrase-level segmentation using
    incomplete internet-style chord annotations",
    Proceedings of the Sound and Music Computing Conference (SMC), 2010.
.. [#dzhambazov2017] G. Dzhambazov.
"Knowledge-Based Probabilistic Modeling For Tracking Lyrics In Music Audio Signals",
PhD Thesis, 2017.
.. [#fujihara2011] H. Fujihara, M. Goto, J. Ogata, H. Okuno.
"LyricSynchronizer: Automatic synchronization system between musical audio signals and lyrics",
IEEE Journal of Selected Topics in Signal Processing, VOL. 5, NO. 6, 2011
"""
import collections
from typing import Optional
import numpy as np
from scipy.stats import skewnorm
from mir_eval.util import filter_kwargs
def validate(
reference_timestamps: np.ndarray, estimated_timestamps: np.ndarray
):
"""Checks that the input annotations to a metric look like valid onset time
arrays, and throws helpful errors if not.
Parameters
----------
reference_timestamps : np.ndarray
reference timestamp locations, in seconds
estimated_timestamps : np.ndarray
estimated timestamp locations, in seconds
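
    Examples
    --------
    A minimal sketch (hypothetical values): non-monotonic estimated
    timestamps are rejected.

    >>> import numpy as np
    >>> validate(np.array([1.0, 2.0]), np.array([2.0, 1.0]))
    Traceback (most recent call last):
        ...
    ValueError: Estimated timestamps are not monotonically increasing!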
"""
# We need to have 1D numpy arrays
if not isinstance(reference_timestamps, np.ndarray):
raise ValueError(
"Reference timestamps need to be a numpy array, but got"
f" {type(reference_timestamps)}"
)
if not isinstance(estimated_timestamps, np.ndarray):
raise ValueError(
"Estimated timestamps need to be a numpy array, but got"
f" {type(estimated_timestamps)}"
)
if reference_timestamps.ndim != 1:
raise ValueError(
"Reference timestamps need to be a one-dimensional vector, but got"
f" {reference_timestamps.ndim} dimensions"
)
if estimated_timestamps.ndim != 1:
raise ValueError(
"Estimated timestamps need to be a one-dimensional vector, but got"
f" {estimated_timestamps.ndim} dimensions"
)
# If reference or estimated timestamps are empty, cannot compute metric
if reference_timestamps.size == 0:
raise ValueError("Reference timestamps are empty.")
if estimated_timestamps.size != reference_timestamps.size:
raise ValueError(
"Number of timestamps must be the same in prediction and ground"
f" truth, but found {estimated_timestamps.size} in prediction and"
f" {reference_timestamps.size} in ground truth"
)
# Check monotonicity
if not np.all(reference_timestamps[1:] - reference_timestamps[:-1] >= 0):
raise ValueError(
"Reference timestamps are not monotonically increasing!"
)
if not np.all(estimated_timestamps[1:] - estimated_timestamps[:-1] >= 0):
raise ValueError(
"Estimated timestamps are not monotonically increasing!"
)
    # Check positivity (needed for correct PCS metric calculation)
    if not np.all(reference_timestamps >= 0):
        raise ValueError("Reference timestamps cannot be below 0!")
    if not np.all(estimated_timestamps >= 0):
        raise ValueError("Estimated timestamps cannot be below 0!")
def absolute_error(reference_timestamps, estimated_timestamps):
"""Compute the absolute deviations between estimated and reference timestamps,
and then returns the median and average over all events
Examples
--------
>>> reference_timestamps = mir_eval.io.load_events('reference.txt')
>>> estimated_timestamps = mir_eval.io.load_events('estimated.txt')
    >>> mae, aae = mir_eval.alignment.absolute_error(reference_timestamps, estimated_timestamps)
Parameters
----------
reference_timestamps : np.ndarray
reference timestamps, in seconds
estimated_timestamps : np.ndarray
estimated timestamps, in seconds
Returns
-------
mae : float
Median absolute error
    aae : float
Average absolute error
"""
validate(reference_timestamps, estimated_timestamps)
deviations = np.abs(reference_timestamps - estimated_timestamps)
return np.median(deviations), np.mean(deviations)
def percentage_correct(reference_timestamps, estimated_timestamps, window=0.3):
"""Compute the percentage of correctly predicted timestamps. A timestamp is predicted
correctly if its position doesn't deviate more than the window parameter from the ground
truth timestamp.
Examples
--------
>>> reference_timestamps = mir_eval.io.load_events('reference.txt')
>>> estimated_timestamps = mir_eval.io.load_events('estimated.txt')
    >>> pc = mir_eval.alignment.percentage_correct(reference_timestamps, estimated_timestamps, window=0.2)
Parameters
----------
reference_timestamps : np.ndarray
reference timestamps, in seconds
estimated_timestamps : np.ndarray
estimated timestamps, in seconds
window : float
Window size, in seconds
(Default value = .3)
Returns
-------
pc : float
Percentage of correct timestamps
"""
validate(reference_timestamps, estimated_timestamps)
deviations = np.abs(reference_timestamps - estimated_timestamps)
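    # The mean of the boolean mask is the fraction of timestamps whose
    # absolute deviation falls within the tolerance window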
return np.mean(deviations <= window)
def percentage_correct_segments(
reference_timestamps, estimated_timestamps, duration: Optional[float] = None
):
"""Calculates the percentage of correct segments (PCS) metric.
    It constructs segments separately from the reference and the estimated
    timestamp vectors and calculates the overlap between corresponding segments
    as a percentage of the total duration.
    WARNING: This metric behaves differently depending on whether "duration" is given!
If duration is not given (default case), the computation follows the MIREX lyrics alignment
challenge 2020. For a timestamp vector with entries (t1,t2, ... tN), segments with
the following (start, end) boundaries are created: (t1, t2), ... (tN-1, tN).
After the segments are created, the overlap between the reference and estimated segments is
determined and divided by the total duration, which is the distance between the
first and last timestamp in the reference.
If duration is given, the segment boundaries are instead (0, t1), (t1, t2), ... (tN, duration).
The overlap is computed in the same way, but then divided by the duration parameter given to
this function.
    This variant follows more closely the original paper [#fujihara2011]_, in which
    the metric was proposed.
    As a result, it penalizes cases where the first estimated timestamp is too early
    or the last estimated timestamp is too late, whereas the MIREX variant does not.
    On the other hand, the MIREX variant is invariant to how long the eventless
    beginning and end parts of the audio are, which might be a desirable property.
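    As an illustrative sketch (hypothetical numbers): with reference
    timestamps (1, 3), estimated timestamps (1, 4) and duration=5, the
    segments are (0,1),(1,3),(3,5) vs. (0,1),(1,4),(4,5), which overlap for
    1+2+1 = 4 seconds, so PCS = 4/5 = 0.8. Without a duration, the MIREX
    variant compares only the segments (1,3) and (1,4), which overlap for 2
    of the 2 reference seconds, so PCS = 1.0.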
Examples
--------
>>> reference_timestamps = mir_eval.io.load_events('reference.txt')
>>> estimated_timestamps = mir_eval.io.load_events('estimated.txt')
    >>> pcs = mir_eval.alignment.percentage_correct_segments(reference_timestamps, estimated_timestamps)
Parameters
----------
reference_timestamps : np.ndarray
reference timestamps, in seconds
estimated_timestamps : np.ndarray
estimated timestamps, in seconds
duration : float
Optional. Total duration of audio (seconds). WARNING: Metric is computed differently
depending on whether this is provided or not - see documentation above!
Returns
-------
pcs : float
Percentage of time where ground truth and predicted segments overlap
"""
validate(reference_timestamps, estimated_timestamps)
if duration is not None:
duration = float(duration)
if duration <= 0:
raise ValueError(
f"Positive duration needs to be provided, but got {duration}"
)
if np.max(reference_timestamps) > duration:
raise ValueError(
"Expected largest reference timestamp"
f"{np.max(reference_timestamps)} to not be "
f"larger than duration {duration}"
)
if np.max(estimated_timestamps) > duration:
raise ValueError(
"Expected largest estimated timestamp "
f"{np.max(estimated_timestamps)} to not be "
f"larger than duration {duration}"
)
ref_starts = np.concatenate([[0], reference_timestamps])
ref_ends = np.concatenate([reference_timestamps, [duration]])
est_starts = np.concatenate([[0], estimated_timestamps])
est_ends = np.concatenate([estimated_timestamps, [duration]])
else:
# MIREX lyrics alignment 2020 style:
# Ignore regions before start and after end reference timestamp
duration = reference_timestamps[-1] - reference_timestamps[0]
if duration <= 0:
raise ValueError(
f"Reference timestamps are all identical, can not compute PCS"
f" metric!"
)
ref_starts = reference_timestamps[:-1]
ref_ends = reference_timestamps[1:]
est_starts = estimated_timestamps[:-1]
est_ends = estimated_timestamps[1:]
overlap_starts = np.maximum(ref_starts, est_starts)
overlap_ends = np.minimum(ref_ends, est_ends)
overlap_duration = np.sum(np.maximum(overlap_ends - overlap_starts, 0))
return overlap_duration / duration
def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps):
"""Metric based on human synchronicity perception as measured in the paper
"User-centered evaluation of lyrics to audio alignment" [#lizemasclef2021]
The parameters of this function were tuned on data collected through a user Karaoke-like
experiment
It reflects human judgment of how "synchronous" lyrics and audio stimuli are perceived
in that setup.
Beware that this metric is non-symmetrical and by construction it is also not equal to 1 at 0.
Examples
--------
>>> reference_timestamps = mir_eval.io.load_events('reference.txt')
>>> estimated_timestamps = mir_eval.io.load_events('estimated.txt')
    >>> score = mir_eval.alignment.karaoke_perceptual_metric(reference_timestamps, estimated_timestamps)
Parameters
----------
reference_timestamps : np.ndarray
reference timestamps, in seconds
estimated_timestamps : np.ndarray
estimated timestamps, in seconds
Returns
-------
perceptual_score : float
Perceptual score, averaged over all timestamps
"""
validate(reference_timestamps, estimated_timestamps)
offsets = estimated_timestamps - reference_timestamps
# Score offsets using a certain skewed normal distribution
skewness = 1.12244251
localisation = -0.22270315
scale = 0.29779424
normalisation_factor = 1.6857
perceptual_scores = (1.0 / normalisation_factor) * skewnorm.pdf(
offsets, skewness, loc=localisation, scale=scale
)
return np.mean(perceptual_scores)
def evaluate(reference_timestamps, estimated_timestamps, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Examples
--------
>>> reference_timestamps = mir_eval.io.load_events('reference.txt')
>>> estimated_timestamps = mir_eval.io.load_events('estimated.txt')
>>> duration = max(np.max(reference_timestamps), np.max(estimated_timestamps)) + 10
    >>> scores = mir_eval.alignment.evaluate(reference_timestamps, estimated_timestamps, duration=duration)
Parameters
----------
reference_timestamps : np.ndarray
reference timestamp locations, in seconds
estimated_timestamps : np.ndarray
estimated timestamp locations, in seconds
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all metrics
scores = collections.OrderedDict()
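    # filter_kwargs forwards only the keyword arguments each metric accepts
    # (e.g. window for percentage_correct, duration for
    # percentage_correct_segments)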
scores["pc"] = filter_kwargs(
percentage_correct, reference_timestamps, estimated_timestamps, **kwargs
)
scores["mae"], scores["aae"] = absolute_error(
reference_timestamps, estimated_timestamps
)
scores["pcs"] = filter_kwargs(
percentage_correct_segments,
reference_timestamps,
estimated_timestamps,
**kwargs,
)
scores["perceptual"] = karaoke_perceptual_metric(
reference_timestamps, estimated_timestamps
)
return scores
==> mir_eval-0.7/mir_eval/beat.py <==
'''
The aim of a beat detection algorithm is to report the times at which a typical
human listener might tap their foot to a piece of music. As a result, most
metrics for evaluating the performance of beat tracking systems involve
computing the error between the estimated beat times and some reference list of
beat locations. Many metrics additionally compare the beat sequences at
different metric levels in order to deal with the ambiguity of tempo.
Based on the methods described in:
Matthew E. P. Davies, Norberto Degara, and Mark D. Plumbley.
"Evaluation Methods for Musical Audio Beat Tracking Algorithms",
Queen Mary University of London Technical Report C4DM-TR-09-06
London, United Kingdom, 8 October 2009.
See also the Beat Evaluation Toolbox:
https://code.soundsoftware.ac.uk/projects/beat-evaluation/
Conventions
-----------
Beat times should be provided in the form of a 1-dimensional array of beat
times in seconds in increasing order. Typically, any beats which occur before
5s are ignored; this can be accomplished using
:func:`mir_eval.beat.trim_beats()`.
Metrics
-------
* :func:`mir_eval.beat.f_measure`: The F-measure of the beat sequence, where an
estimated beat is considered correct if it is sufficiently close to a
reference beat
* :func:`mir_eval.beat.cemgil`: Cemgil's score, which computes the sum of
Gaussian errors for each beat
* :func:`mir_eval.beat.goto`: Goto's score, a binary score which is 1 when at
least 25\% of the estimated beat sequence closely matches the reference beat
sequence
* :func:`mir_eval.beat.p_score`: McKinney's P-score, which computes the
cross-correlation of the estimated and reference beat sequences represented
as impulse trains
* :func:`mir_eval.beat.continuity`: Continuity-based scores which compute the
proportion of the beat sequence which is continuously correct
* :func:`mir_eval.beat.information_gain`: The Information Gain of a normalized
beat error histogram over a uniform distribution
'''
import numpy as np
import collections
from . import util
import warnings
# The maximum allowable beat time
MAX_TIME = 30000.
def trim_beats(beats, min_beat_time=5.):
"""Removes beats before min_beat_time. A common preprocessing step.
Parameters
----------
beats : np.ndarray
Array of beat times in seconds.
min_beat_time : float
Minimum beat time to allow
(Default value = 5.)
Returns
-------
beats_trimmed : np.ndarray
Trimmed beat array.
"""
# Remove beats before min_beat_time
return beats[beats >= min_beat_time]
def validate(reference_beats, estimated_beats):
"""Checks that the input annotations to a metric look like valid beat time
arrays, and throws helpful errors if not.
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
estimated beat times, in seconds
"""
# If reference or estimated beats are empty,
# warn because metric will be 0
if reference_beats.size == 0:
warnings.warn("Reference beats are empty.")
if estimated_beats.size == 0:
warnings.warn("Estimated beats are empty.")
for beats in [reference_beats, estimated_beats]:
util.validate_events(beats, MAX_TIME)
def _get_reference_beat_variations(reference_beats):
"""Return metric variations of the reference beats
Parameters
----------
reference_beats : np.ndarray
beat locations in seconds
Returns
-------
reference_beats : np.ndarray
Original beat locations
off_beat : np.ndarray
180 degrees out of phase from the original beat locations
double : np.ndarray
Beats at 2x the original tempo
half_odd : np.ndarray
Half tempo, odd beats
half_even : np.ndarray
Half tempo, even beats
"""
# Create annotations at twice the metric level
interpolated_indices = np.arange(0, reference_beats.shape[0]-.5, .5)
original_indices = np.arange(0, reference_beats.shape[0])
double_reference_beats = np.interp(interpolated_indices,
original_indices,
reference_beats)
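    # e.g. reference_beats = [1, 2, 3] yields
    # double_reference_beats = [1, 1.5, 2, 2.5, 3]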
# Return metric variations:
# True, off-beat, double tempo, half tempo odd, and half tempo even
return (reference_beats,
double_reference_beats[1::2],
double_reference_beats,
reference_beats[::2],
reference_beats[1::2])
def f_measure(reference_beats,
estimated_beats,
f_measure_threshold=0.07):
"""Compute the F-measure of correct vs incorrectly predicted beats.
"Correctness" is determined over a small window.
Examples
--------
>>> reference_beats = mir_eval.io.load_events('reference.txt')
>>> reference_beats = mir_eval.beat.trim_beats(reference_beats)
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> f_measure = mir_eval.beat.f_measure(reference_beats,
estimated_beats)
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
estimated beat times, in seconds
f_measure_threshold : float
Window size, in seconds
(Default value = 0.07)
Returns
-------
f_score : float
The computed F-measure score
"""
validate(reference_beats, estimated_beats)
# When estimated beats are empty, no beats are correct; metric is 0
if estimated_beats.size == 0 or reference_beats.size == 0:
return 0.
# Compute the best-case matching between reference and estimated locations
matching = util.match_events(reference_beats,
estimated_beats,
f_measure_threshold)
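    # Precision: fraction of estimated beats matching a reference beat;
    # recall: fraction of reference beats that were matched. The F-measure
    # is their harmonic mean.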
precision = float(len(matching))/len(estimated_beats)
recall = float(len(matching))/len(reference_beats)
return util.f_measure(precision, recall)
def cemgil(reference_beats,
estimated_beats,
cemgil_sigma=0.04):
"""Cemgil's score, computes a gaussian error of each estimated beat.
Compares against the original beat times and all metrical variations.
Examples
--------
>>> reference_beats = mir_eval.io.load_events('reference.txt')
>>> reference_beats = mir_eval.beat.trim_beats(reference_beats)
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> cemgil_score, cemgil_max = mir_eval.beat.cemgil(reference_beats,
estimated_beats)
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
query beat times, in seconds
cemgil_sigma : float
Sigma parameter of gaussian error windows
(Default value = 0.04)
Returns
-------
cemgil_score : float
Cemgil's score for the original reference beats
cemgil_max : float
The best Cemgil score for all metrical variations
"""
validate(reference_beats, estimated_beats)
# When estimated beats are empty, no beats are correct; metric is 0
if estimated_beats.size == 0 or reference_beats.size == 0:
return 0., 0.
# We'll compute Cemgil's accuracy for each variation
accuracies = []
for reference_beats in _get_reference_beat_variations(reference_beats):
accuracy = 0
# Cycle through beats
for beat in reference_beats:
# Find the error for the closest beat to the reference beat
beat_diff = np.min(np.abs(beat - estimated_beats))
# Add gaussian error into the accuracy
accuracy += np.exp(-(beat_diff**2)/(2.0*cemgil_sigma**2))
# Normalize the accuracy
accuracy /= .5*(estimated_beats.shape[0] + reference_beats.shape[0])
# Add it to our list of accuracy scores
accuracies.append(accuracy)
# Return raw accuracy with non-varied annotations
# and maximal accuracy across all variations
return accuracies[0], np.max(accuracies)
def goto(reference_beats,
estimated_beats,
goto_threshold=0.35,
goto_mu=0.2,
goto_sigma=0.2):
"""Calculate Goto's score, a binary 1 or 0 depending on some specific
heuristic criteria
Examples
--------
>>> reference_beats = mir_eval.io.load_events('reference.txt')
>>> reference_beats = mir_eval.beat.trim_beats(reference_beats)
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> goto_score = mir_eval.beat.goto(reference_beats, estimated_beats)
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
query beat times, in seconds
goto_threshold : float
Threshold of beat error for a beat to be "correct"
(Default value = 0.35)
goto_mu : float
The mean of the beat errors in the continuously correct
track must be less than this
(Default value = 0.2)
goto_sigma : float
The std of the beat errors in the continuously correct track must
be less than this
(Default value = 0.2)
Returns
-------
goto_score : float
Either 1.0 or 0.0 if some specific criteria are met
"""
validate(reference_beats, estimated_beats)
# When estimated beats are empty, no beats are correct; metric is 0
if estimated_beats.size == 0 or reference_beats.size == 0:
return 0.
# Error for each beat
beat_error = np.ones(reference_beats.shape[0])
# Flag for whether the reference and estimated beats are paired
paired = np.zeros(reference_beats.shape[0])
# Keep track of Goto's three criteria
goto_criteria = 0
for n in range(1, reference_beats.shape[0]-1):
        # Get previous inter-reference-beat-interval
previous_interval = 0.5*(reference_beats[n] - reference_beats[n-1])
# Window start - in the middle of the current beat and the previous
window_min = reference_beats[n] - previous_interval
# Next inter-reference-beat-interval
next_interval = 0.5*(reference_beats[n+1] - reference_beats[n])
# Window end - in the middle of the current beat and the next
window_max = reference_beats[n] + next_interval
# Get estimated beats in the window
beats_in_window = np.logical_and((estimated_beats >= window_min),
(estimated_beats < window_max))
# False negative/positive
if beats_in_window.sum() == 0 or beats_in_window.sum() > 1:
paired[n] = 0
beat_error[n] = 1
else:
# Single beat is paired!
paired[n] = 1
# Get offset of the estimated beat and the reference beat
offset = estimated_beats[beats_in_window] - reference_beats[n]
# Scale by previous or next interval
if offset < 0:
beat_error[n] = offset/previous_interval
else:
beat_error[n] = offset/next_interval
# Get indices of incorrect beats
incorrect_beats = np.flatnonzero(np.abs(beat_error) > goto_threshold)
    # Fewer than three incorrect beats means all interior beats are correct
    # (beat_error[0] and beat_error[-1] stay at 1, so the first and last
    # beats are always flagged as incorrect)
    if incorrect_beats.shape[0] < 3:
# Get the track of correct beats
track = beat_error[incorrect_beats[0] + 1:incorrect_beats[-1] - 1]
goto_criteria = 1
else:
# Get the track of maximal length
track_len = np.max(np.diff(incorrect_beats))
track_start = np.flatnonzero(np.diff(incorrect_beats) == track_len)[0]
# Is the track length at least 25% of the song?
if track_len - 1 > .25*(reference_beats.shape[0] - 2):
goto_criteria = 1
start_beat = incorrect_beats[track_start]
end_beat = incorrect_beats[track_start + 1]
track = beat_error[start_beat:end_beat + 1]
# If we have a track
if goto_criteria:
# Are mean and std of the track less than the required thresholds?
if np.mean(np.abs(track)) < goto_mu \
and np.std(track, ddof=1) < goto_sigma:
goto_criteria = 3
# If all criteria are met, score is 100%!
return 1.0*(goto_criteria == 3)
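# Illustrative behaviour (a sketch, not a library doctest): the score is
# binary, so an estimate that tracks the reference exactly passes all three
# criteria, e.g. with six beats spaced one second apart:
#   goto(np.arange(1., 7.), np.arange(1., 7.))  ->  1.0
# while an estimate whose longest continuously-correct stretch covers less
# than about a quarter of the track scores 0.0.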
def p_score(reference_beats,
estimated_beats,
p_score_threshold=0.2):
"""Get McKinney's P-score.
    Based on the cross-correlation of the reference and estimated beats
Examples
--------
>>> reference_beats = mir_eval.io.load_events('reference.txt')
>>> reference_beats = mir_eval.beat.trim_beats(reference_beats)
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> p_score = mir_eval.beat.p_score(reference_beats, estimated_beats)
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
query beat times, in seconds
p_score_threshold : float
Window size will be
``p_score_threshold*np.median(inter_annotation_intervals)``,
(Default value = 0.2)
Returns
-------
correlation : float
McKinney's P-score
"""
validate(reference_beats, estimated_beats)
    # Warn when only one beat is provided for either the estimated or
    # reference sequence, since beat intervals cannot be computed
if reference_beats.size == 1:
warnings.warn("Only one reference beat was provided, so beat intervals"
" cannot be computed.")
if estimated_beats.size == 1:
warnings.warn("Only one estimated beat was provided, so beat intervals"
" cannot be computed.")
# When estimated or reference beats have <= 1 beats, can't compute the
# metric, so return 0
if estimated_beats.size <= 1 or reference_beats.size <= 1:
return 0.
# Quantize beats to 10ms
sampling_rate = int(1.0/0.010)
# Shift beats so that the minimum in either sequence is zero
offset = min(estimated_beats.min(), reference_beats.min())
estimated_beats = np.array(estimated_beats - offset)
reference_beats = np.array(reference_beats - offset)
# Get the largest time index
    end_point = int(np.ceil(np.max([np.max(estimated_beats),
                                    np.max(reference_beats)])))
# Make impulse trains with impulses at beat locations
reference_train = np.zeros(end_point*sampling_rate + 1)
    beat_indices = np.ceil(reference_beats*sampling_rate).astype(int)
    reference_train[beat_indices] = 1.0
    estimated_train = np.zeros(end_point*sampling_rate + 1)
    beat_indices = np.ceil(estimated_beats*sampling_rate).astype(int)
estimated_train[beat_indices] = 1.0
# Window size to take the correlation over
# defined as .2*median(inter-annotation-intervals)
annotation_intervals = np.diff(np.flatnonzero(reference_train))
win_size = int(np.round(p_score_threshold*np.median(annotation_intervals)))
# Get full correlation
train_correlation = np.correlate(reference_train, estimated_train, 'full')
# Get the middle element - note we are rounding down on purpose here
middle_lag = train_correlation.shape[0]//2
# Truncate to only valid lags (those corresponding to the window)
start = middle_lag - win_size
end = middle_lag + win_size + 1
train_correlation = train_correlation[start:end]
# Compute and return the P-score
n_beats = np.max([estimated_beats.shape[0], reference_beats.shape[0]])
return np.sum(train_correlation)/n_beats
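# Rough intuition (illustrative only): for identical sequences the only
# nonzero correlation lag inside the +/- win_size window is lag 0, where
# the impulse trains overlap completely, so the windowed sum is n_beats
# and the score is 1:
#   p_score(np.arange(1., 6.), np.arange(1., 6.))  ->  1.0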
def continuity(reference_beats,
estimated_beats,
continuity_phase_threshold=0.175,
continuity_period_threshold=0.175):
"""Get metrics based on how much of the estimated beat sequence is
continually correct.
Examples
--------
>>> reference_beats = mir_eval.io.load_events('reference.txt')
>>> reference_beats = mir_eval.beat.trim_beats(reference_beats)
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(reference_beats,
estimated_beats)
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
query beat times, in seconds
continuity_phase_threshold : float
        Allowable ratio of how far the estimated beat can be
        from the reference beat
(Default value = 0.175)
continuity_period_threshold : float
Allowable distance between the inter-beat-interval
and the inter-annotation-interval
(Default value = 0.175)
Returns
-------
CMLc : float
Correct metric level, continuous accuracy
CMLt : float
Correct metric level, total accuracy (continuity not required)
AMLc : float
Any metric level, continuous accuracy
AMLt : float
Any metric level, total accuracy (continuity not required)
"""
validate(reference_beats, estimated_beats)
    # Warn when only one beat is provided for either the estimated or
    # reference sequence, since beat intervals cannot be computed
if reference_beats.size == 1:
warnings.warn("Only one reference beat was provided, so beat intervals"
" cannot be computed.")
if estimated_beats.size == 1:
warnings.warn("Only one estimated beat was provided, so beat intervals"
" cannot be computed.")
# When estimated or reference beats have <= 1 beats, can't compute the
# metric, so return 0
if estimated_beats.size <= 1 or reference_beats.size <= 1:
return 0., 0., 0., 0.
# Accuracies for each variation
continuous_accuracies = []
total_accuracies = []
# Get accuracy for each variation
for reference_beats in _get_reference_beat_variations(reference_beats):
# Annotations that have been used
n_annotations = np.max([reference_beats.shape[0],
estimated_beats.shape[0]])
used_annotations = np.zeros(n_annotations)
# Whether or not we are continuous at any given point
beat_successes = np.zeros(n_annotations)
for m in range(estimated_beats.shape[0]):
# Is this beat correct?
beat_success = 0
# Get differences for this beat
beat_differences = np.abs(estimated_beats[m] - reference_beats)
# Get nearest annotation index
nearest = np.argmin(beat_differences)
min_difference = beat_differences[nearest]
# Have we already used this annotation?
if used_annotations[nearest] == 0:
# Is this the first beat or first annotation?
# If so, look forward.
if m == 0 or nearest == 0:
# How far is the estimated beat from the reference beat,
# relative to the inter-annotation-interval?
if nearest + 1 < reference_beats.shape[0]:
reference_interval = (reference_beats[nearest + 1] -
reference_beats[nearest])
else:
# Special case when nearest + 1 is too large - use the
# previous interval instead
reference_interval = (reference_beats[nearest] -
reference_beats[nearest - 1])
# Handle this special case when beats are not unique
if reference_interval == 0:
if min_difference == 0:
phase = 1
else:
phase = np.inf
else:
phase = np.abs(min_difference/reference_interval)
# How close is the inter-beat-interval
# to the inter-annotation-interval?
if m + 1 < estimated_beats.shape[0]:
estimated_interval = (estimated_beats[m + 1] -
estimated_beats[m])
else:
# Special case when m + 1 is too large - use the
# previous interval
estimated_interval = (estimated_beats[m] -
estimated_beats[m - 1])
# Handle this special case when beats are not unique
if reference_interval == 0:
if estimated_interval == 0:
period = 0
else:
period = np.inf
else:
period = \
np.abs(1 - estimated_interval/reference_interval)
if phase < continuity_phase_threshold and \
period < continuity_period_threshold:
# Set this annotation as used
used_annotations[nearest] = 1
# This beat is matched
beat_success = 1
# This beat/annotation is not the first
else:
# How far is the estimated beat from the reference beat,
# relative to the inter-annotation-interval?
reference_interval = (reference_beats[nearest] -
reference_beats[nearest - 1])
phase = np.abs(min_difference/reference_interval)
# How close is the inter-beat-interval
# to the inter-annotation-interval?
estimated_interval = (estimated_beats[m] -
estimated_beats[m - 1])
reference_interval = (reference_beats[nearest] -
reference_beats[nearest - 1])
period = np.abs(1 - estimated_interval/reference_interval)
if phase < continuity_phase_threshold and \
period < continuity_period_threshold:
# Set this annotation as used
used_annotations[nearest] = 1
# This beat is matched
beat_success = 1
# Set whether this beat is matched or not
beat_successes[m] = beat_success
        # Add 0s at the beginning and end
# so that we at least find the beginning/end of the estimated beats
beat_successes = np.append(np.append(0, beat_successes), 0)
# Where is the beat not a match?
beat_failures = np.nonzero(beat_successes == 0)[0]
# Take out those zeros we added
beat_successes = beat_successes[1:-1]
# Get the continuous accuracy as the longest track of successful beats
longest_track = np.max(np.diff(beat_failures)) - 1
continuous_accuracy = longest_track/(1.0*beat_successes.shape[0])
continuous_accuracies.append(continuous_accuracy)
# Get the total accuracy - all sequences
total_accuracy = np.sum(beat_successes)/(1.0*beat_successes.shape[0])
total_accuracies.append(total_accuracy)
# Grab accuracy scores
return (continuous_accuracies[0],
total_accuracies[0],
np.max(continuous_accuracies),
np.max(total_accuracies))
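# A worked example of the metric levels (a sketch, assuming as above that
# the half-tempo sequences are among the reference variations): estimating
# every other reference beat fails the period test at the correct metric
# level, because the inter-beat-interval is twice the
# inter-annotation-interval, but it matches the half-tempo variation
# exactly:
#   reference_beats = np.arange(1., 11.)
#   continuity(reference_beats, reference_beats[::2])
#   -> (0.0, 0.0, 1.0, 1.0)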
def information_gain(reference_beats,
estimated_beats,
bins=41):
"""Get the information gain - K-L divergence of the beat error histogram
to a uniform histogram
Examples
--------
>>> reference_beats = mir_eval.io.load_events('reference.txt')
>>> reference_beats = mir_eval.beat.trim_beats(reference_beats)
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> information_gain = mir_eval.beat.information_gain(reference_beats,
estimated_beats)
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
query beat times, in seconds
bins : int
Number of bins in the beat error histogram
(Default value = 41)
Returns
-------
information_gain_score : float
Entropy of beat error histogram
"""
validate(reference_beats, estimated_beats)
# If an even number of bins is provided,
# there will be no bin centered at zero, so warn the user.
if not bins % 2:
warnings.warn("bins parameter is even, "
"so there will not be a bin centered at zero.")
    # Warn when only one beat is provided for either the estimated or
    # reference sequence, since beat intervals cannot be computed
if reference_beats.size == 1:
warnings.warn("Only one reference beat was provided, so beat intervals"
" cannot be computed.")
if estimated_beats.size == 1:
warnings.warn("Only one estimated beat was provided, so beat intervals"
" cannot be computed.")
# When estimated or reference beats have <= 1 beats, can't compute the
# metric, so return 0
if estimated_beats.size <= 1 or reference_beats.size <= 1:
return 0.
# Get entropy for reference beats->estimated beats
# and estimated beats->reference beats
forward_entropy = _get_entropy(reference_beats, estimated_beats, bins)
backward_entropy = _get_entropy(estimated_beats, reference_beats, bins)
# Pick the larger of the entropies
norm = np.log2(bins)
if forward_entropy > backward_entropy:
# Note that the beat evaluation toolbox does not normalize
information_gain_score = (norm - forward_entropy)/norm
else:
information_gain_score = (norm - backward_entropy)/norm
return information_gain_score
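# Sanity check (illustrative): when every estimated beat lands exactly on
# a reference beat, all errors fall into the single bin centered at zero,
# the error histogram has zero entropy, and the normalized score is
# maximal:
#   information_gain(np.arange(1., 6.), np.arange(1., 6.))  ->  1.0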
def _get_entropy(reference_beats, estimated_beats, bins):
"""Helper function for information gain
(needs to be run twice - once backwards, once forwards)
Parameters
----------
reference_beats : np.ndarray
reference beat times, in seconds
estimated_beats : np.ndarray
query beat times, in seconds
bins : int
Number of bins in the beat error histogram
Returns
-------
entropy : float
Entropy of beat error histogram
"""
beat_error = np.zeros(estimated_beats.shape[0])
for n in range(estimated_beats.shape[0]):
# Get index of closest annotation to this beat
beat_distances = estimated_beats[n] - reference_beats
closest_beat = np.argmin(np.abs(beat_distances))
absolute_error = beat_distances[closest_beat]
# If the first annotation is closest...
if closest_beat == 0:
# Inter-annotation interval - space between first two beats
interval = .5*(reference_beats[1] - reference_beats[0])
        # If the last annotation is closest...
        # (elif, so the else branch below cannot overwrite the
        # first-annotation case with a wrapped [-1] index)
        elif closest_beat == (reference_beats.shape[0] - 1):
interval = .5*(reference_beats[-1] - reference_beats[-2])
else:
if absolute_error < 0:
# Closest annotation is the one before the current beat
                # so look at the previous inter-annotation-interval
start = reference_beats[closest_beat]
end = reference_beats[closest_beat - 1]
interval = .5*(start - end)
else:
# Closest annotation is the one after the current beat
                # so look at the next inter-annotation-interval
start = reference_beats[closest_beat + 1]
end = reference_beats[closest_beat]
interval = .5*(start - end)
# The actual error of this beat
beat_error[n] = .5*absolute_error/interval
# Put beat errors in range (-.5, .5)
beat_error = np.mod(beat_error + .5, -1) + .5
    # Note these are slightly different from the beat evaluation toolbox
    # (here they are uniform)
histogram_bin_edges = np.linspace(-.5, .5, bins + 1)
# Get the histogram
raw_bin_values = np.histogram(beat_error, histogram_bin_edges)[0]
# Turn into a proper probability distribution
raw_bin_values = raw_bin_values/(1.0*np.sum(raw_bin_values))
# Set zero-valued bins to 1 to make the entropy calculation well-behaved
raw_bin_values[raw_bin_values == 0] = 1
# Calculate entropy
return -np.sum(raw_bin_values * np.log2(raw_bin_values))
def evaluate(reference_beats, estimated_beats, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Examples
--------
>>> reference_beats = mir_eval.io.load_events('reference.txt')
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> scores = mir_eval.beat.evaluate(reference_beats, estimated_beats)
Parameters
----------
reference_beats : np.ndarray
Reference beat times, in seconds
estimated_beats : np.ndarray
Query beat times, in seconds
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Trim beat times at the beginning of the annotations
reference_beats = util.filter_kwargs(trim_beats, reference_beats, **kwargs)
estimated_beats = util.filter_kwargs(trim_beats, estimated_beats, **kwargs)
# Now compute all the metrics
scores = collections.OrderedDict()
# F-Measure
scores['F-measure'] = util.filter_kwargs(f_measure, reference_beats,
estimated_beats, **kwargs)
# Cemgil
scores['Cemgil'], scores['Cemgil Best Metric Level'] = \
util.filter_kwargs(cemgil, reference_beats, estimated_beats, **kwargs)
# Goto
scores['Goto'] = util.filter_kwargs(goto, reference_beats,
estimated_beats, **kwargs)
# P-Score
scores['P-score'] = util.filter_kwargs(p_score, reference_beats,
estimated_beats, **kwargs)
# Continuity metrics
(scores['Correct Metric Level Continuous'],
scores['Correct Metric Level Total'],
scores['Any Metric Level Continuous'],
scores['Any Metric Level Total']) = util.filter_kwargs(continuity,
reference_beats,
estimated_beats,
**kwargs)
# Information gain
scores['Information gain'] = util.filter_kwargs(information_gain,
reference_beats,
estimated_beats,
**kwargs)
return scores
mir_eval-0.7/mir_eval/chord.py 0000664 0000000 0000000 00000164255 14203260312 0016433 0 ustar 00root root 0000000 0000000 r'''
Chord estimation algorithms produce a list of intervals and labels which denote
the chord being played over each timespan. They are evaluated by comparing the
estimated chord labels to some reference, usually using a mapping to a chord
subalphabet (e.g. minor and major chords only, all triads, etc.). There is no
single 'right' way to compare two sequences of chord labels. Embracing this
reality, every conventional comparison rule is provided. Comparisons are made
over the different components of each chord (e.g. G:maj(6)/5): the root (G),
the root-invariant active semitones as determined by the quality
shorthand (maj) and scale degrees (6), and the bass interval (5).
This submodule provides functions both for comparing sequences of chord
labels according to some chord subalphabet mapping and for using these
comparisons to score a sequence of estimated chords against a reference.
Conventions
-----------
A sequence of chord labels is represented as a list of strings, where each
label is the chord name based on the syntax of [#harte2010towards]_. Reference
and estimated chord label sequences should be of the same length for comparison
functions. When converting the chord string into its constituent parts,
* Pitch class counting starts at C, e.g. C:0, D:2, E:4, F:5, etc.
* Scale degree is represented as a string of the diatonic interval, relative to
the root note, e.g. 'b6', '#5', or '7'
* Bass intervals are represented as strings
* Chord bitmaps are positional binary vectors indicating active pitch classes
and may be absolute or relative depending on context in the code.
If no chord is present at a given point in time, it should have the label 'N',
which is defined in the variable ``mir_eval.chord.NO_CHORD``.
Metrics
-------
* :func:`mir_eval.chord.root`: Only compares the root of the chords.
* :func:`mir_eval.chord.majmin`: Only compares major, minor, and "no chord"
labels.
* :func:`mir_eval.chord.majmin_inv`: Compares major/minor chords, with
inversions. The bass note must exist in the triad.
* :func:`mir_eval.chord.mirex`: An estimated chord is considered correct if it
  shares *at least* three pitch classes with the reference chord.
* :func:`mir_eval.chord.thirds`: Chords are compared at the level of major or
  minor thirds (root and third). For example, both ('A:7', 'A:maj') and
  ('A:min', 'A:dim') are equivalent, as the thirds are major and minor in
  quality, respectively.
* :func:`mir_eval.chord.thirds_inv`: Same as above, with inversions (bass
relationships).
* :func:`mir_eval.chord.triads`: Chords are considered at the level of triads
(major, minor, augmented, diminished, suspended), meaning that, in addition
to the root, the quality is only considered through #5th scale degree (for
augmented chords). For example, ('A:7', 'A:maj') are equivalent, while
('A:min', 'A:dim') and ('A:aug', 'A:maj') are not.
* :func:`mir_eval.chord.triads_inv`: Same as above, with inversions (bass
relationships).
* :func:`mir_eval.chord.tetrads`: Chords are considered at the level of the
entire quality in closed voicing, i.e. spanning only a single octave;
extended chords (9's, 11's and 13's) are rolled into a single octave with any
  upper voices included as extensions. For example, ('A:7', 'A:9') are
  equivalent but ('A:7', 'A:maj7') are not.
* :func:`mir_eval.chord.tetrads_inv`: Same as above, with inversions (bass
relationships).
* :func:`mir_eval.chord.sevenths`: Compares according to MIREX "sevenths"
rules; that is, only major, major seventh, seventh, minor, minor seventh and
no chord labels are compared.
* :func:`mir_eval.chord.sevenths_inv`: Same as above, with inversions (bass
relationships).
* :func:`mir_eval.chord.overseg`: Computes the level of over-segmentation
between estimated and reference intervals.
* :func:`mir_eval.chord.underseg`: Computes the level of under-segmentation
between estimated and reference intervals.
* :func:`mir_eval.chord.seg`: Computes the minimum of over- and
under-segmentation between estimated and reference intervals.
References
----------
.. [#harte2010towards] C. Harte. Towards Automatic Extraction of Harmony
Information from Music Signals. PhD thesis, Queen Mary University of
London, August 2010.
'''
import numpy as np
import warnings
import collections
import re
from mir_eval import util
BITMAP_LENGTH = 12
NO_CHORD = "N"
NO_CHORD_ENCODED = -1, np.array([0]*BITMAP_LENGTH), -1
X_CHORD = "X"
X_CHORD_ENCODED = -1, np.array([-1]*BITMAP_LENGTH), -1
class InvalidChordException(Exception):
r'''Exception class for suspect / invalid chord labels'''
def __init__(self, message='', chord_label=None):
self.message = message
self.chord_label = chord_label
self.name = self.__class__.__name__
super(InvalidChordException, self).__init__(message)
# --- Chord Primitives ---
def _pitch_classes():
r'''Map from pitch class (str) to semitone (int).'''
pitch_classes = ['C', 'D', 'E', 'F', 'G', 'A', 'B']
semitones = [0, 2, 4, 5, 7, 9, 11]
return dict([(c, s) for c, s in zip(pitch_classes, semitones)])
def _scale_degrees():
r'''Mapping from scale degrees (str) to semitones (int).'''
degrees = ['1', '2', '3', '4', '5', '6', '7',
'8', '9', '10', '11', '12', '13']
semitones = [0, 2, 4, 5, 7, 9, 11, 12, 14, 16, 17, 19, 21]
return dict([(d, s) for d, s in zip(degrees, semitones)])
# Maps pitch classes (strings) to semitone indexes (ints).
PITCH_CLASSES = _pitch_classes()
def pitch_class_to_semitone(pitch_class):
r'''Convert a pitch class to semitone.
Parameters
----------
pitch_class : str
Spelling of a given pitch class, e.g. 'C#', 'Gbb'
Returns
-------
semitone : int
Semitone value of the pitch class.
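    Examples
    --------
    Illustrative conversions; each '#' raises and each 'b' lowers the
    natural pitch class by one semitone:
    >>> pitch_class_to_semitone('C#')
    1
    >>> pitch_class_to_semitone('Gbb')
    5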
'''
semitone = 0
for idx, char in enumerate(pitch_class):
if char == '#' and idx > 0:
semitone += 1
elif char == 'b' and idx > 0:
semitone -= 1
elif idx == 0:
semitone = PITCH_CLASSES.get(char)
else:
raise InvalidChordException(
"Pitch class improperly formed: %s" % pitch_class)
return semitone % 12
# Maps scale degrees (strings) to semitone indexes (ints).
SCALE_DEGREES = _scale_degrees()
def scale_degree_to_semitone(scale_degree):
r"""Convert a scale degree to semitone.
Parameters
----------
scale degree : str
Spelling of a relative scale degree, e.g. 'b3', '7', '#5'
Returns
-------
    semitone : int
        Relative semitone of the scale degree; compound degrees such as
        '9' are returned unwrapped (an octave above '2')
Raises
------
InvalidChordException if `scale_degree` is invalid.
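    Examples
    --------
    Illustrative conversions; compound degrees such as '9' pass through
    unwrapped (wrapping happens later, at encoding time):
    >>> scale_degree_to_semitone('b3')
    3
    >>> scale_degree_to_semitone('#5')
    8
    >>> scale_degree_to_semitone('9')
    14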
"""
semitone = 0
offset = 0
if scale_degree.startswith("#"):
offset = scale_degree.count("#")
scale_degree = scale_degree.strip("#")
elif scale_degree.startswith('b'):
offset = -1 * scale_degree.count("b")
scale_degree = scale_degree.strip("b")
semitone = SCALE_DEGREES.get(scale_degree, None)
if semitone is None:
raise InvalidChordException(
"Scale degree improperly formed: {}, expected one of {}."
.format(scale_degree, list(SCALE_DEGREES.keys())))
return semitone + offset
def scale_degree_to_bitmap(scale_degree, modulo=False, length=BITMAP_LENGTH):
"""Create a bitmap representation of a scale degree.
Note that values in the bitmap may be negative, indicating that the
semitone is to be removed.
Parameters
----------
scale_degree : str
Spelling of a relative scale degree, e.g. 'b3', '7', '#5'
    modulo : bool, default=False
If a scale degree exceeds the length of the bit-vector, modulo the
scale degree back into the bit-vector; otherwise it is discarded.
length : int, default=12
Length of the bit-vector to produce
Returns
-------
bitmap : np.ndarray, in [-1, 0, 1], len=`length`
Bitmap representation of this scale degree.
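    Examples
    --------
    Illustrative bitmaps; an out-of-range degree is discarded unless
    `modulo` is set:
    >>> scale_degree_to_bitmap('3')
    array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
    >>> scale_degree_to_bitmap('9', modulo=True)
    array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])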
"""
sign = 1
if scale_degree.startswith("*"):
sign = -1
scale_degree = scale_degree.strip("*")
edit_map = [0] * length
sd_idx = scale_degree_to_semitone(scale_degree)
if sd_idx < length or modulo:
edit_map[sd_idx % length] = sign
return np.array(edit_map)
# Maps quality strings to bitmaps, corresponding to relative pitch class
# semitones, i.e. vector[0] is the tonic.
QUALITIES = {
# 1 2 3 4 5 6 7
'maj': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
'min': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
'aug': [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
'dim': [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
'sus4': [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0],
'sus2': [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
'7': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'maj7': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
'min7': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0],
'minmaj7': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
'maj6': [1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0],
'min6': [1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
'dim7': [1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0],
'hdim7': [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
'maj9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
'min9': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0],
'9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'b9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'#9': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'min11': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0],
'11': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'#11': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'maj13': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
'min13': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0],
'13': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'b13': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
'1': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'5': [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
'': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
def quality_to_bitmap(quality):
"""Return the bitmap for a given quality.
Parameters
----------
quality : str
Chord quality name.
Returns
-------
bitmap : np.ndarray
Bitmap representation of this quality (12-dim).
"""
if quality not in QUALITIES:
raise InvalidChordException(
"Unsupported chord quality shorthand: '%s' "
"Did you mean to reduce extended chords?" % quality)
return np.array(QUALITIES[quality])
# Maps extended chord qualities to the subset above, translating additional
# voicings to extensions as a set of scale degrees (strings).
# TODO(ejhumphrey): Revisit how minmaj7's are mapped. This is how TMC did it,
# but MMV handles it like a separate quality (rather than an add7).
EXTENDED_QUALITY_REDUX = {
'minmaj7': ('min', set(['7'])),
'maj9': ('maj7', set(['9'])),
'min9': ('min7', set(['9'])),
'9': ('7', set(['9'])),
'b9': ('7', set(['b9'])),
'#9': ('7', set(['#9'])),
'11': ('7', set(['9', '11'])),
'#11': ('7', set(['9', '#11'])),
'13': ('7', set(['9', '11', '13'])),
'b13': ('7', set(['9', '11', 'b13'])),
'min11': ('min7', set(['9', '11'])),
'maj13': ('maj7', set(['9', '11', '13'])),
'min13': ('min7', set(['9', '11', '13']))}
def reduce_extended_quality(quality):
"""Map an extended chord quality to a simpler one, moving upper voices to
a set of scale degree extensions.
Parameters
----------
quality : str
Extended chord quality to reduce.
Returns
-------
base_quality : str
New chord quality.
extensions : set
        Scale degree extensions for the quality.
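    Examples
    --------
    An illustrative reduction; unknown qualities pass through unchanged:
    >>> reduce_extended_quality('9')
    ('7', {'9'})
    >>> reduce_extended_quality('maj')
    ('maj', set())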
"""
return EXTENDED_QUALITY_REDUX.get(quality, (quality, set()))
# --- Chord Label Parsing ---
def validate_chord_label(chord_label):
"""Test for well-formedness of a chord label.
Parameters
----------
chord : str
Chord label to validate.
"""
# This monster regexp is pulled from the JAMS chord namespace,
# which is in turn derived from the context-free grammar of
# Harte et al., 2005.
pattern = re.compile(r'''^((N|X)|(([A-G](b*|#*))((:(maj|min|dim|aug|1|5|sus2|sus4|maj6|min6|7|maj7|min7|dim7|hdim7|minmaj7|aug7|9|maj9|min9|11|maj11|min11|13|maj13|min13)(\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\))?)|(:\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\)))?((/((b*|#*)([1-9]|1[0-3]?)))?)?))$''') # nopep8
if not pattern.match(chord_label):
raise InvalidChordException('Invalid chord label: '
'{}'.format(chord_label))
def split(chord_label, reduce_extended_chords=False):
"""Parse a chord label into its four constituent parts:
- root
- quality shorthand
- scale degrees
- bass
Note: Chords lacking quality AND interval information are major.
- If a quality is specified, it is returned.
- If an interval is specified WITHOUT a quality, the quality field is
empty.
Some examples::
'C' -> ['C', 'maj', {}, '1']
'G#:min(*b3,*5)/5' -> ['G#', 'min', {'*b3', '*5'}, '5']
'A:(3)/6' -> ['A', '', {'3'}, '6']
Parameters
----------
chord_label : str
A chord label.
reduce_extended_chords : bool
Whether to map the upper voicings of extended chords (9's, 11's, 13's)
to semitone extensions. (Default value = False)
Returns
-------
chord_parts : list
Split version of the chord label.
"""
chord_label = str(chord_label)
validate_chord_label(chord_label)
if chord_label == NO_CHORD:
return [chord_label, '', set(), '']
bass = '1'
if "/" in chord_label:
chord_label, bass = chord_label.split("/")
scale_degrees = set()
omission = False
if "(" in chord_label:
chord_label, scale_degrees = chord_label.split("(")
omission = "*" in scale_degrees
scale_degrees = scale_degrees.strip(")")
scale_degrees = set([i.strip() for i in scale_degrees.split(",")])
# Note: Chords lacking quality AND added interval information are major.
# If a quality shorthand is specified, it is returned.
# If an interval is specified WITHOUT a quality, the quality field is
# empty.
# Intervals specifying omissions MUST have a quality.
if omission and ":" not in chord_label:
raise InvalidChordException(
"Intervals specifying omissions MUST have a quality.")
quality = '' if scale_degrees else 'maj'
if ":" in chord_label:
chord_root, quality_name = chord_label.split(":")
# Extended chords (with ":"s) may not explicitly have Major qualities,
# so only overwrite the default if the string is not empty.
if quality_name:
quality = quality_name.lower()
else:
chord_root = chord_label
if reduce_extended_chords:
quality, addl_scale_degrees = reduce_extended_quality(quality)
scale_degrees.update(addl_scale_degrees)
return [chord_root, quality, scale_degrees, bass]
def join(chord_root, quality='', extensions=None, bass=''):
r"""Join the parts of a chord into a complete chord label.
Parameters
----------
chord_root : str
Root pitch class of the chord, e.g. 'C', 'Eb'
quality : str
Quality of the chord, e.g. 'maj', 'hdim7'
(Default value = '')
extensions : list
Any added or absent scaled degrees for this chord, e.g. ['4', '\*3']
(Default value = None)
bass : str
Scale degree of the bass note, e.g. '5'.
(Default value = '')
Returns
-------
chord_label : str
A complete chord label.
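    Examples
    --------
    An illustrative join; the root-position bass '1' is omitted from the
    label:
    >>> join('F#', quality='hdim7', bass='b7')
    'F#:hdim7/b7'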
"""
chord_label = chord_root
if quality or extensions:
chord_label += ":%s" % quality
if extensions:
chord_label += "(%s)" % ",".join(extensions)
if bass and bass != '1':
chord_label += "/%s" % bass
validate_chord_label(chord_label)
return chord_label
# --- Chords to Numerical Representations ---
def encode(chord_label, reduce_extended_chords=False,
strict_bass_intervals=False):
"""Translate a chord label to numerical representations for evaluation.
Parameters
----------
chord_label : str
Chord label to encode.
reduce_extended_chords : bool
Whether to map the upper voicings of extended chords (9's, 11's, 13's)
to semitone extensions.
(Default value = False)
strict_bass_intervals : bool
Whether to require that the bass scale degree is present in the chord.
(Default value = False)
Returns
-------
root_number : int
Absolute semitone of the chord's root.
semitone_bitmap : np.ndarray, dtype=int
12-dim vector of relative semitones in the chord spelling.
bass_number : int
Relative semitone of the chord's bass note, e.g. 0=root, 7=fifth, etc.
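    Examples
    --------
    An illustrative encoding of a major triad:
    >>> encode('G:maj')
    (7, array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]), 0)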
"""
if chord_label == NO_CHORD:
return NO_CHORD_ENCODED
if chord_label == X_CHORD:
return X_CHORD_ENCODED
chord_root, quality, scale_degrees, bass = split(
chord_label, reduce_extended_chords=reduce_extended_chords)
root_number = pitch_class_to_semitone(chord_root)
bass_number = scale_degree_to_semitone(bass) % 12
semitone_bitmap = quality_to_bitmap(quality)
semitone_bitmap[0] = 1
for scale_degree in scale_degrees:
semitone_bitmap += scale_degree_to_bitmap(scale_degree,
reduce_extended_chords)
    semitone_bitmap = (semitone_bitmap > 0).astype(int)
if not semitone_bitmap[bass_number] and strict_bass_intervals:
raise InvalidChordException(
"Given bass scale degree is absent from this chord: "
"%s" % chord_label, chord_label)
else:
semitone_bitmap[bass_number] = 1
return root_number, semitone_bitmap, bass_number
def encode_many(chord_labels, reduce_extended_chords=False):
"""Translate a set of chord labels to numerical representations for sane
evaluation.
Parameters
----------
chord_labels : list
Set of chord labels to encode.
reduce_extended_chords : bool
Whether to map the upper voicings of extended chords (9's, 11's, 13's)
to semitone extensions.
(Default value = False)
Returns
-------
root_number : np.ndarray, dtype=int
Absolute semitone of the chord's root.
interval_bitmap : np.ndarray, dtype=int
12-dim vector of relative semitones in the given chord quality.
bass_number : np.ndarray, dtype=int
Relative semitones of the chord's bass notes.
"""
num_items = len(chord_labels)
    roots, basses = np.zeros([2, num_items], dtype=int)
    semitones = np.zeros([num_items, 12], dtype=int)
local_cache = dict()
for i, label in enumerate(chord_labels):
result = local_cache.get(label, None)
if result is None:
result = encode(label, reduce_extended_chords)
local_cache[label] = result
roots[i], semitones[i], basses[i] = result
return roots, semitones, basses
def rotate_bitmap_to_root(bitmap, chord_root):
"""Circularly shift a relative bitmap to its asbolute pitch classes.
For clarity, the best explanation is an example. Given 'G:Maj', the root
and quality map are as follows::
        root=7
quality=[1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0] # Relative chord shape
After rotating to the root, the resulting bitmap becomes::
abs_quality = [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1] # G, B, and D
Parameters
----------
bitmap : np.ndarray, shape=(12,)
Bitmap of active notes, relative to the given root.
chord_root : int
Absolute pitch class number.
Returns
-------
bitmap : np.ndarray, shape=(12,)
Absolute bitmap of active pitch classes.
"""
bitmap = np.asarray(bitmap)
assert bitmap.ndim == 1, "Currently only 1D bitmaps are supported."
idxs = list(np.nonzero(bitmap))
idxs[-1] = (idxs[-1] + chord_root) % 12
abs_bitmap = np.zeros_like(bitmap)
abs_bitmap[tuple(idxs)] = 1
return abs_bitmap
def rotate_bitmaps_to_roots(bitmaps, roots):
"""Circularly shift a relative bitmaps to asbolute pitch classes.
See :func:`rotate_bitmap_to_root` for more information.
Parameters
----------
    bitmaps : np.ndarray, shape=(N, 12)
        Bitmaps of active notes, relative to the given roots.
    roots : np.ndarray, shape=(N,)
        Absolute pitch class numbers.
Returns
-------
    bitmaps : np.ndarray, shape=(N, 12)
        Absolute bitmaps of active pitch classes.
"""
abs_bitmaps = []
for bitmap, chord_root in zip(bitmaps, roots):
abs_bitmaps.append(rotate_bitmap_to_root(bitmap, chord_root))
return np.asarray(abs_bitmaps)
# --- Comparison Routines ---
def validate(reference_labels, estimated_labels):
"""Checks that the input annotations to a comparison function look like
valid chord labels.
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
"""
N = len(reference_labels)
M = len(estimated_labels)
if N != M:
raise ValueError(
"Chord comparison received different length lists: "
"len(reference)=%d\tlen(estimates)=%d" % (N, M))
for labels in [reference_labels, estimated_labels]:
for chord_label in labels:
validate_chord_label(chord_label)
# When either label list is empty, warn the user
if len(reference_labels) == 0:
warnings.warn('Reference labels are empty')
if len(estimated_labels) == 0:
warnings.warn('Estimated labels are empty')
def weighted_accuracy(comparisons, weights):
"""Compute the weighted accuracy of a list of chord comparisons.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> # Here, we're using the "thirds" function to compare labels
>>> # but any of the comparison functions would work.
>>> comparisons = mir_eval.chord.thirds(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
comparisons : np.ndarray
List of chord comparison scores, in [0, 1] or -1
weights : np.ndarray
Weights (not necessarily normalized) for each comparison.
This can be a list of interval durations
Returns
-------
score : float
Weighted accuracy
"""
N = len(comparisons)
# There should be as many weights as comparisons
if weights.shape[0] != N:
raise ValueError('weights and comparisons should be of the same'
' length. len(weights) = {} but len(comparisons)'
' = {}'.format(weights.shape[0], N))
if (weights < 0).any():
        raise ValueError('Weights should all be non-negative.')
if np.sum(weights) == 0:
warnings.warn('No nonzero weights, returning 0')
return 0
# Find all comparison scores which are valid
valid_idx = (comparisons >= 0)
# If no comparable chords were provided, warn and return 0
if valid_idx.sum() == 0:
warnings.warn("No reference chords were comparable "
"to estimated chords, returning 0.")
return 0
# Remove any uncomparable labels
comparisons = comparisons[valid_idx]
weights = weights[valid_idx]
# Normalize the weights
total_weight = float(np.sum(weights))
normalized_weights = np.asarray(weights, dtype=float)/total_weight
# Score is the sum of all weighted comparisons
return np.sum(comparisons*normalized_weights)
def thirds(reference_labels, estimated_labels):
"""Compare chords along root & third relationships.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.thirds(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0]
"""
validate(reference_labels, estimated_labels)
ref_roots, ref_semitones = encode_many(reference_labels, False)[:2]
est_roots, est_semitones = encode_many(estimated_labels, False)[:2]
eq_roots = ref_roots == est_roots
eq_thirds = ref_semitones[:, 3] == est_semitones[:, 3]
    comparison_scores = (eq_roots * eq_thirds).astype(float)
# Ignore 'X' chords
comparison_scores[np.any(ref_semitones < 0, axis=1)] = -1.0
return comparison_scores
def thirds_inv(reference_labels, estimated_labels):
"""Score chords along root, third, & bass relationships.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.thirds_inv(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0]
"""
validate(reference_labels, estimated_labels)
ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False)
est_roots, est_semitones, est_bass = encode_many(estimated_labels, False)
eq_root = ref_roots == est_roots
eq_bass = ref_bass == est_bass
eq_third = ref_semitones[:, 3] == est_semitones[:, 3]
    comparison_scores = (eq_root * eq_third * eq_bass).astype(float)
# Ignore 'X' chords
comparison_scores[np.any(ref_semitones < 0, axis=1)] = -1.0
return comparison_scores
def triads(reference_labels, estimated_labels):
"""Compare chords along triad (root & quality to #5) relationships.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.triads(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0]
"""
validate(reference_labels, estimated_labels)
ref_roots, ref_semitones = encode_many(reference_labels, False)[:2]
est_roots, est_semitones = encode_many(estimated_labels, False)[:2]
eq_roots = ref_roots == est_roots
eq_semitones = np.all(
np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1)
    comparison_scores = (eq_roots * eq_semitones).astype(float)
# Ignore 'X' chords
comparison_scores[np.any(ref_semitones < 0, axis=1)] = -1.0
return comparison_scores
def triads_inv(reference_labels, estimated_labels):
"""Score chords along triad (root, quality to #5, & bass) relationships.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.triads_inv(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0]
"""
validate(reference_labels, estimated_labels)
ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False)
est_roots, est_semitones, est_bass = encode_many(estimated_labels, False)
eq_roots = ref_roots == est_roots
eq_basses = ref_bass == est_bass
eq_semitones = np.all(
np.equal(ref_semitones[:, :8], est_semitones[:, :8]), axis=1)
    comparison_scores = (eq_roots * eq_semitones * eq_basses).astype(float)
# Ignore 'X' chords
comparison_scores[np.any(ref_semitones < 0, axis=1)] = -1.0
return comparison_scores
def tetrads(reference_labels, estimated_labels):
"""Compare chords along tetrad (root & full quality) relationships.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.tetrads(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0]
"""
validate(reference_labels, estimated_labels)
ref_roots, ref_semitones = encode_many(reference_labels, False)[:2]
est_roots, est_semitones = encode_many(estimated_labels, False)[:2]
eq_roots = ref_roots == est_roots
eq_semitones = np.all(np.equal(ref_semitones, est_semitones), axis=1)
    comparison_scores = (eq_roots * eq_semitones).astype(float)
# Ignore 'X' chords
comparison_scores[np.any(ref_semitones < 0, axis=1)] = -1.0
return comparison_scores
def tetrads_inv(reference_labels, estimated_labels):
"""Compare chords along tetrad (root, full quality, & bass) relationships.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.tetrads_inv(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0]
"""
validate(reference_labels, estimated_labels)
ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False)
est_roots, est_semitones, est_bass = encode_many(estimated_labels, False)
eq_roots = ref_roots == est_roots
eq_basses = ref_bass == est_bass
eq_semitones = np.all(np.equal(ref_semitones, est_semitones), axis=1)
    comparison_scores = (eq_roots * eq_semitones * eq_basses).astype(float)
# Ignore 'X' chords
comparison_scores[np.any(ref_semitones < 0, axis=1)] = -1.0
return comparison_scores
def root(reference_labels, estimated_labels):
"""Compare chords according to roots.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.root(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0], or -1 if the comparison is out of
gamut.
"""
validate(reference_labels, estimated_labels)
ref_roots, ref_semitones = encode_many(reference_labels, False)[:2]
est_roots = encode_many(estimated_labels, False)[0]
    comparison_scores = (ref_roots == est_roots).astype(float)
# Ignore 'X' chords
comparison_scores[np.any(ref_semitones < 0, axis=1)] = -1.0
return comparison_scores
def mirex(reference_labels, estimated_labels):
"""Compare chords along MIREX rules.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.mirex(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0]
"""
validate(reference_labels, estimated_labels)
# TODO(?): Should this be an argument?
min_intersection = 3
ref_data = encode_many(reference_labels, False)
ref_chroma = rotate_bitmaps_to_roots(ref_data[1], ref_data[0])
est_data = encode_many(estimated_labels, False)
est_chroma = rotate_bitmaps_to_roots(est_data[1], est_data[0])
eq_chroma = (ref_chroma * est_chroma).sum(axis=-1)
# Chroma matching for set bits
    comparison_scores = (eq_chroma >= min_intersection).astype(float)
# No-chord matching; match -1 roots, SKIP_CHORDS dropped next
no_root = np.logical_and(ref_data[0] == -1, est_data[0] == -1)
comparison_scores[no_root] = 1.0
# Skip chords where the number of active semitones `n` is
# 0 < n < `min_intersection`.
ref_semitone_count = (ref_data[1] > 0).sum(axis=1)
skip_idx = np.logical_and(ref_semitone_count > 0,
ref_semitone_count < min_intersection)
# Also ignore 'X' chords.
np.logical_or(skip_idx, np.any(ref_data[1] < 0, axis=1), skip_idx)
comparison_scores[skip_idx] = -1.0
return comparison_scores
def majmin(reference_labels, estimated_labels):
"""Compare chords along major-minor rules. Chords with qualities outside
Major/minor/no-chord are ignored.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.majmin(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0], or -1 if the comparison is out of
gamut.
"""
validate(reference_labels, estimated_labels)
maj_semitones = np.array(QUALITIES['maj'][:8])
min_semitones = np.array(QUALITIES['min'][:8])
ref_roots, ref_semitones, _ = encode_many(reference_labels, False)
est_roots, est_semitones, _ = encode_many(estimated_labels, False)
eq_root = ref_roots == est_roots
eq_quality = np.all(np.equal(ref_semitones[:, :8],
est_semitones[:, :8]), axis=1)
    comparison_scores = (eq_root * eq_quality).astype(float)
# Test for Major / Minor / No-chord
is_maj = np.all(np.equal(ref_semitones[:, :8], maj_semitones), axis=1)
is_min = np.all(np.equal(ref_semitones[:, :8], min_semitones), axis=1)
is_none = np.logical_and(ref_roots < 0, np.all(ref_semitones == 0, axis=1))
# Only keep majors, minors, and Nones (NOR)
comparison_scores[(is_maj + is_min + is_none) == 0] = -1
# Disable chords that disrupt this quality (apparently)
# ref_voicing = np.all(np.equal(ref_qualities[:, :8],
# ref_notes[:, :8]), axis=1)
# comparison_scores[ref_voicing == 0] = -1
# est_voicing = np.all(np.equal(est_qualities[:, :8],
# est_notes[:, :8]), axis=1)
# comparison_scores[est_voicing == 0] = -1
return comparison_scores
def majmin_inv(reference_labels, estimated_labels):
"""Compare chords along major-minor rules, with inversions. Chords with
qualities outside Major/minor/no-chord are ignored, and the bass note must
exist in the triad (bass in [1, 3, 5]).
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.majmin_inv(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0], or -1 if the comparison is out of
gamut.
"""
validate(reference_labels, estimated_labels)
maj_semitones = np.array(QUALITIES['maj'][:8])
min_semitones = np.array(QUALITIES['min'][:8])
ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False)
est_roots, est_semitones, est_bass = encode_many(estimated_labels, False)
eq_root_bass = (ref_roots == est_roots) * (ref_bass == est_bass)
eq_semitones = np.all(np.equal(ref_semitones[:, :8],
est_semitones[:, :8]), axis=1)
    comparison_scores = (eq_root_bass * eq_semitones).astype(float)
# Test for Major / Minor / No-chord
is_maj = np.all(np.equal(ref_semitones[:, :8], maj_semitones), axis=1)
is_min = np.all(np.equal(ref_semitones[:, :8], min_semitones), axis=1)
is_none = np.logical_and(ref_roots < 0, np.all(ref_semitones == 0, axis=1))
# Only keep majors, minors, and Nones (NOR)
comparison_scores[(is_maj + is_min + is_none) == 0] = -1
# Disable inversions that are not part of the quality
valid_inversion = np.ones(ref_bass.shape, dtype=bool)
bass_idx = ref_bass >= 0
valid_inversion[bass_idx] = ref_semitones[bass_idx, ref_bass[bass_idx]]
comparison_scores[valid_inversion == 0] = -1
return comparison_scores
def sevenths(reference_labels, estimated_labels):
"""Compare chords along MIREX 'sevenths' rules. Chords with qualities
outside [maj, maj7, 7, min, min7, N] are ignored.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.sevenths(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0], or -1 if the comparison is out of
gamut.
"""
validate(reference_labels, estimated_labels)
seventh_qualities = ['maj', 'min', 'maj7', '7', 'min7', '']
valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities])
ref_roots, ref_semitones = encode_many(reference_labels, False)[:2]
est_roots, est_semitones = encode_many(estimated_labels, False)[:2]
eq_root = ref_roots == est_roots
eq_semitones = np.all(np.equal(ref_semitones, est_semitones), axis=1)
    comparison_scores = (eq_root * eq_semitones).astype(float)
# Test for reference chord inclusion
is_valid = np.array([np.all(np.equal(ref_semitones, semitones), axis=1)
for semitones in valid_semitones])
    # Drop chords whose reference matches none of the valid qualities
comparison_scores[np.sum(is_valid, axis=0) == 0] = -1
return comparison_scores
def sevenths_inv(reference_labels, estimated_labels):
"""Compare chords along MIREX 'sevenths' rules. Chords with qualities
outside [maj, maj7, 7, min, min7, N] are ignored.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> est_intervals, est_labels = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, ref_intervals.min(),
... ref_intervals.max(), mir_eval.chord.NO_CHORD,
... mir_eval.chord.NO_CHORD)
>>> (intervals,
... ref_labels,
... est_labels) = mir_eval.util.merge_labeled_intervals(
... ref_intervals, ref_labels, est_intervals, est_labels)
>>> durations = mir_eval.util.intervals_to_durations(intervals)
>>> comparisons = mir_eval.chord.sevenths_inv(ref_labels, est_labels)
>>> score = mir_eval.chord.weighted_accuracy(comparisons, durations)
Parameters
----------
reference_labels : list, len=n
Reference chord labels to score against.
estimated_labels : list, len=n
Estimated chord labels to score against.
Returns
-------
comparison_scores : np.ndarray, shape=(n,), dtype=float
Comparison scores, in [0.0, 1.0], or -1 if the comparison is out of
gamut.
"""
validate(reference_labels, estimated_labels)
seventh_qualities = ['maj', 'min', 'maj7', '7', 'min7', '']
valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities])
ref_roots, ref_semitones, ref_basses = encode_many(reference_labels, False)
est_roots, est_semitones, est_basses = encode_many(estimated_labels, False)
eq_roots_basses = (ref_roots == est_roots) * (ref_basses == est_basses)
eq_semitones = np.all(np.equal(ref_semitones, est_semitones), axis=1)
    comparison_scores = (eq_roots_basses * eq_semitones).astype(float)
    # Test for reference chord inclusion among the seventh qualities
is_valid = np.array([np.all(np.equal(ref_semitones, semitones), axis=1)
for semitones in valid_semitones])
comparison_scores[np.sum(is_valid, axis=0) == 0] = -1
# Disable inversions that are not part of the quality
valid_inversion = np.ones(ref_basses.shape, dtype=bool)
bass_idx = ref_basses >= 0
valid_inversion[bass_idx] = ref_semitones[bass_idx, ref_basses[bass_idx]]
comparison_scores[valid_inversion == 0] = -1
return comparison_scores
def directional_hamming_distance(reference_intervals, estimated_intervals):
"""Compute the directional hamming distance between reference and
estimated intervals as defined by [#harte2010towards]_ and used for MIREX
'OverSeg', 'UnderSeg' and 'MeanSeg' measures.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> overseg = 1 - mir_eval.chord.directional_hamming_distance(
... ref_intervals, est_intervals)
>>> underseg = 1 - mir_eval.chord.directional_hamming_distance(
... est_intervals, ref_intervals)
>>> seg = min(overseg, underseg)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2), dtype=float
Reference chord intervals to score against.
estimated_intervals : np.ndarray, shape=(m, 2), dtype=float
Estimated chord intervals to score against.
Returns
-------
directional hamming distance : float
        directional Hamming distance between the reference intervals and
        the estimated intervals.
"""
util.validate_intervals(estimated_intervals)
util.validate_intervals(reference_intervals)
# make sure chord intervals do not overlap
if len(reference_intervals) > 1 and (reference_intervals[:-1, 1] >
reference_intervals[1:, 0]).any():
        raise ValueError('Chord intervals must not overlap')
est_ts = np.unique(estimated_intervals.flatten())
seg = 0.
for start, end in reference_intervals:
dur = end - start
between_start_end = est_ts[(est_ts >= start) & (est_ts < end)]
seg_ts = np.hstack([start, between_start_end, end])
seg += dur - np.diff(seg_ts).max()
return seg / (reference_intervals[-1, 1] - reference_intervals[0, 0])
def overseg(reference_intervals, estimated_intervals):
"""Compute the MIREX 'OverSeg' score.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> score = mir_eval.chord.overseg(ref_intervals, est_intervals)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2), dtype=float
Reference chord intervals to score against.
estimated_intervals : np.ndarray, shape=(m, 2), dtype=float
Estimated chord intervals to score against.
Returns
-------
oversegmentation score : float
Comparison score, in [0.0, 1.0], where 1.0 means no oversegmentation.
"""
return 1 - directional_hamming_distance(reference_intervals,
estimated_intervals)
def underseg(reference_intervals, estimated_intervals):
"""Compute the MIREX 'UnderSeg' score.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> score = mir_eval.chord.underseg(ref_intervals, est_intervals)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2), dtype=float
Reference chord intervals to score against.
estimated_intervals : np.ndarray, shape=(m, 2), dtype=float
Estimated chord intervals to score against.
Returns
-------
undersegmentation score : float
Comparison score, in [0.0, 1.0], where 1.0 means no undersegmentation.
"""
return 1 - directional_hamming_distance(estimated_intervals,
reference_intervals)
def seg(reference_intervals, estimated_intervals):
"""Compute the MIREX 'MeanSeg' score.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> score = mir_eval.chord.seg(ref_intervals, est_intervals)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2), dtype=float
Reference chord intervals to score against.
estimated_intervals : np.ndarray, shape=(m, 2), dtype=float
Estimated chord intervals to score against.
Returns
-------
segmentation score : float
Comparison score, in [0.0, 1.0], where 1.0 means perfect segmentation.
"""
return min(underseg(reference_intervals, estimated_intervals),
overseg(reference_intervals, estimated_intervals))
def merge_chord_intervals(intervals, labels):
"""
Merge consecutive chord intervals if they represent the same chord.
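    Examples
    --------
    A minimal sketch, assuming 'ref.lab' contains labeled chord intervals:
    >>> intervals, labels = mir_eval.io.load_labeled_intervals('ref.lab')
    >>> merged = mir_eval.chord.merge_chord_intervals(intervals, labels)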
Parameters
----------
intervals : np.ndarray, shape=(n, 2), dtype=float
Chord intervals to be merged, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
labels : list, shape=(n,)
Chord labels to be merged, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
Returns
-------
merged_ivs : np.ndarray, shape=(k, 2), dtype=float
Merged chord intervals, k <= n
"""
roots, semitones, basses = encode_many(labels, True)
merged_ivs = []
prev_rt = None
prev_st = None
prev_ba = None
for s, e, rt, st, ba in zip(intervals[:, 0], intervals[:, 1],
roots, semitones, basses):
if rt != prev_rt or (st != prev_st).any() or ba != prev_ba:
prev_rt, prev_st, prev_ba = rt, st, ba
merged_ivs.append([s, e])
else:
merged_ivs[-1][-1] = e
return np.array(merged_ivs)
def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs):
"""Computes weighted accuracy for all comparison functions for the given
reference and estimated annotations.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> scores = mir_eval.chord.evaluate(ref_intervals, ref_labels,
... est_intervals, est_labels)
Parameters
----------
ref_intervals : np.ndarray, shape=(n, 2)
Reference chord intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
ref_labels : list, shape=(n,)
reference chord labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
est_intervals : np.ndarray, shape=(m, 2)
estimated chord intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
est_labels : list, shape=(m,)
estimated chord labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Append or crop estimated intervals so their span is the same as reference
est_intervals, est_labels = util.adjust_intervals(
est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(),
NO_CHORD, NO_CHORD)
# use merged intervals for segmentation evaluation
merged_ref_intervals = merge_chord_intervals(ref_intervals, ref_labels)
merged_est_intervals = merge_chord_intervals(est_intervals, est_labels)
# Adjust the labels so that they span the same intervals
intervals, ref_labels, est_labels = util.merge_labeled_intervals(
ref_intervals, ref_labels, est_intervals, est_labels)
# Convert intervals to durations (used as weights)
durations = util.intervals_to_durations(intervals)
# Store scores for each comparison function
scores = collections.OrderedDict()
scores['thirds'] = weighted_accuracy(thirds(ref_labels, est_labels),
durations)
scores['thirds_inv'] = weighted_accuracy(thirds_inv(ref_labels,
est_labels), durations)
scores['triads'] = weighted_accuracy(triads(ref_labels, est_labels),
durations)
scores['triads_inv'] = weighted_accuracy(triads_inv(ref_labels,
est_labels), durations)
scores['tetrads'] = weighted_accuracy(tetrads(ref_labels, est_labels),
durations)
scores['tetrads_inv'] = weighted_accuracy(tetrads_inv(ref_labels,
est_labels),
durations)
scores['root'] = weighted_accuracy(root(ref_labels, est_labels), durations)
scores['mirex'] = weighted_accuracy(mirex(ref_labels, est_labels),
durations)
scores['majmin'] = weighted_accuracy(majmin(ref_labels, est_labels),
durations)
scores['majmin_inv'] = weighted_accuracy(majmin_inv(ref_labels,
est_labels), durations)
scores['sevenths'] = weighted_accuracy(sevenths(ref_labels, est_labels),
durations)
scores['sevenths_inv'] = weighted_accuracy(sevenths_inv(ref_labels,
est_labels),
durations)
scores['underseg'] = underseg(merged_ref_intervals, merged_est_intervals)
scores['overseg'] = overseg(merged_ref_intervals, merged_est_intervals)
scores['seg'] = min(scores['overseg'], scores['underseg'])
return scores
mir_eval-0.7/mir_eval/display.py
# -*- encoding: utf-8 -*-
'''Display functions'''
from collections import defaultdict
import numpy as np
from scipy.signal import spectrogram
from matplotlib.patches import Rectangle
from matplotlib.ticker import FuncFormatter, MultipleLocator
from matplotlib.ticker import Formatter
from matplotlib.colors import LinearSegmentedColormap, LogNorm, ColorConverter
from matplotlib.collections import BrokenBarHCollection
from .melody import freq_to_voicing
from .util import midi_to_hz, hz_to_midi
def __expand_limits(ax, limits, which='x'):
'''Helper function to expand axis limits'''
if which == 'x':
getter, setter = ax.get_xlim, ax.set_xlim
elif which == 'y':
getter, setter = ax.get_ylim, ax.set_ylim
else:
raise ValueError('invalid axis: {}'.format(which))
old_lims = getter()
new_lims = list(limits)
# infinite limits occur on new axis objects with no data
if np.isfinite(old_lims[0]):
new_lims[0] = min(old_lims[0], limits[0])
if np.isfinite(old_lims[1]):
new_lims[1] = max(old_lims[1], limits[1])
setter(new_lims)
def __get_axes(ax=None, fig=None):
'''Get or construct the target axes object for a new plot.
Parameters
----------
ax : matplotlib.pyplot.axes, optional
If provided, return this axes object directly.
fig : matplotlib.figure.Figure, optional
The figure to query for axes.
By default, uses the current figure `plt.gcf()`.
Returns
-------
ax : matplotlib.pyplot.axes
        An axis handle on which to draw the plot.
If none is provided, a new set of axes is created.
new_axes : bool
If `True`, the axis object was newly constructed.
If `False`, the axis object already existed.
'''
new_axes = False
if ax is not None:
return ax, new_axes
if fig is None:
import matplotlib.pyplot as plt
fig = plt.gcf()
if not fig.get_axes():
new_axes = True
return fig.gca(), new_axes
def segments(intervals, labels, base=None, height=None, text=False,
text_kw=None, ax=None, **kwargs):
'''Plot a segmentation as a set of disjoint rectangles.
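    Examples
    --------
    A minimal sketch, assuming 'seg.lab' contains labeled segment intervals:
    >>> intervals, labels = mir_eval.io.load_labeled_intervals('seg.lab')
    >>> mir_eval.display.segments(intervals, labels, text=True)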
Parameters
----------
intervals : np.ndarray, shape=(n, 2)
segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
base : number
The vertical position of the base of the rectangles.
By default, this will be the bottom of the plot.
height : number
The height of the rectangles.
By default, this will be the top of the plot (minus ``base``).
text : bool
If true, each segment's label is displayed in its
upper-left corner
text_kw : dict
If ``text == True``, the properties of the text
object can be specified here.
See ``matplotlib.pyplot.Text`` for valid parameters
ax : matplotlib.pyplot.axes
An axis handle on which to draw the segmentation.
If none is provided, a new set of axes is created.
kwargs
Additional keyword arguments to pass to
``matplotlib.patches.Rectangle``.
Returns
-------
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
'''
if text_kw is None:
text_kw = dict()
text_kw.setdefault('va', 'top')
text_kw.setdefault('clip_on', True)
text_kw.setdefault('bbox', dict(boxstyle='round', facecolor='white'))
# Make sure we have a numpy array
intervals = np.atleast_2d(intervals)
seg_def_style = dict(linewidth=1)
ax, new_axes = __get_axes(ax=ax)
if new_axes:
ax.set_ylim([0, 1])
# Infer height
if base is None:
base = ax.get_ylim()[0]
if height is None:
height = ax.get_ylim()[1]
cycler = ax._get_patches_for_fill.prop_cycler
seg_map = dict()
for lab in labels:
if lab in seg_map:
continue
style = next(cycler)
seg_map[lab] = seg_def_style.copy()
seg_map[lab].update(style)
# Swap color -> facecolor here so we preserve edgecolor on rects
seg_map[lab]['facecolor'] = seg_map[lab].pop('color')
seg_map[lab].update(kwargs)
seg_map[lab]['label'] = lab
for ival, lab in zip(intervals, labels):
rect = Rectangle((ival[0], base), ival[1] - ival[0], height,
**seg_map[lab])
ax.add_patch(rect)
seg_map[lab].pop('label', None)
if text:
ann = ax.annotate(lab,
xy=(ival[0], height), xycoords='data',
xytext=(8, -10), textcoords='offset points',
**text_kw)
ann.set_clip_path(rect)
if new_axes:
ax.set_yticks([])
# Only expand if we have data
if intervals.size:
__expand_limits(ax, [intervals.min(), intervals.max()], which='x')
return ax
def labeled_intervals(intervals, labels, label_set=None,
base=None, height=None, extend_labels=True,
ax=None, tick=True, **kwargs):
'''Plot labeled intervals with each label on its own row.
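    Examples
    --------
    A minimal sketch, assuming 'est.lab' contains labeled intervals:
    >>> intervals, labels = mir_eval.io.load_labeled_intervals('est.lab')
    >>> mir_eval.display.labeled_intervals(intervals, labels)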
Parameters
----------
intervals : np.ndarray, shape=(n, 2)
segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
label_set : list
An (ordered) list of labels to determine the plotting order.
If not provided, the labels will be inferred from
``ax.get_yticklabels()``.
If no ``yticklabels`` exist, then the sorted set of unique values
in ``labels`` is taken as the label set.
base : np.ndarray, shape=(n,), optional
Vertical positions of each label.
By default, labels are positioned at integers
``np.arange(len(labels))``.
height : scalar or np.ndarray, shape=(n,), optional
Height for each label.
If scalar, the same value is applied to all labels.
By default, each label has ``height=1``.
extend_labels : bool
If ``False``, only values of ``labels`` that also exist in
``label_set`` will be shown.
If ``True``, all labels are shown, with those in `labels` but
not in `label_set` appended to the top of the plot.
A horizontal line is drawn to indicate the separation between
values in or out of ``label_set``.
ax : matplotlib.pyplot.axes
An axis handle on which to draw the intervals.
If none is provided, a new set of axes is created.
tick : bool
If ``True``, sets tick positions and labels on the y-axis.
kwargs
Additional keyword arguments to pass to
`matplotlib.collection.BrokenBarHCollection`.
Returns
-------
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
'''
# Get the axes handle
ax, _ = __get_axes(ax=ax)
# Make sure we have a numpy array
intervals = np.atleast_2d(intervals)
if label_set is None:
# If we have non-empty pre-existing tick labels, use them
label_set = [_.get_text() for _ in ax.get_yticklabels()]
# If none of the label strings have content, treat it as empty
if not any(label_set):
label_set = []
else:
label_set = list(label_set)
# Put additional labels at the end, in order
if extend_labels:
ticks = label_set + sorted(set(labels) - set(label_set))
elif label_set:
ticks = label_set
else:
ticks = sorted(set(labels))
style = dict(linewidth=1)
style.update(next(ax._get_patches_for_fill.prop_cycler))
# Swap color -> facecolor here so we preserve edgecolor on rects
style['facecolor'] = style.pop('color')
style.update(kwargs)
if base is None:
base = np.arange(len(ticks))
if height is None:
height = 1
if np.isscalar(height):
height = height * np.ones_like(base)
seg_y = dict()
for ybase, yheight, lab in zip(base, height, ticks):
seg_y[lab] = (ybase, yheight)
xvals = defaultdict(list)
for ival, lab in zip(intervals, labels):
if lab not in seg_y:
continue
xvals[lab].append((ival[0], ival[1] - ival[0]))
for lab in seg_y:
ax.add_collection(BrokenBarHCollection(xvals[lab], seg_y[lab],
**style))
# Pop the label after the first time we see it, so we only get
# one legend entry
style.pop('label', None)
# Draw a line separating the new labels from pre-existing labels
if label_set != ticks:
ax.axhline(len(label_set), color='k', alpha=0.5)
if tick:
ax.grid(True, axis='y')
ax.set_yticks([])
ax.set_yticks(base)
ax.set_yticklabels(ticks, va='bottom')
ax.yaxis.set_major_formatter(IntervalFormatter(base, ticks))
if base.size:
__expand_limits(ax, [base.min(), (base + height).max()], which='y')
if intervals.size:
__expand_limits(ax, [intervals.min(), intervals.max()], which='x')
return ax
class IntervalFormatter(Formatter):
'''Ticker formatter for labeled interval plots.
Parameters
----------
base : array-like of int
The base positions of each label
ticks : array-like of string
The labels for the ticks
'''
def __init__(self, base, ticks):
self._map = {int(k): v for k, v in zip(base, ticks)}
def __call__(self, x, pos=None):
return self._map.get(int(x), '')
def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs):
'''Plot a hierarchical segmentation
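    Examples
    --------
    A toy two-layer hierarchy; the level names here are arbitrary:
    >>> int_hier = [[[0, 30], [30, 60]],
    ...             [[0, 15], [15, 30], [30, 45], [45, 60]]]
    >>> lab_hier = [['A', 'B'], ['a', 'b', 'a', 'c']]
    >>> mir_eval.display.hierarchy(int_hier, lab_hier,
    ...                            levels=['sections', 'phrases'])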
Parameters
----------
intervals_hier : list of np.ndarray
A list of segmentation intervals. Each element should be
an n-by-2 array of segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
Segmentations should be ordered by increasing specificity.
labels_hier : list of list-like
A list of segmentation labels. Each element should
be a list of labels for the corresponding element in
`intervals_hier`.
levels : list of string
        Each element ``levels[i]`` is a label for the ``i`` th segmentation.
This is used in the legend to denote the levels in a segment hierarchy.
    ax : matplotlib.pyplot.axes
        An axis handle on which to draw the hierarchy.
        If none is provided, a new set of axes is created.
    kwargs
Additional keyword arguments to `labeled_intervals`.
Returns
-------
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
'''
# This will break if a segment label exists in multiple levels
if levels is None:
levels = list(range(len(intervals_hier)))
# Get the axes handle
ax, _ = __get_axes(ax=ax)
# Count the pre-existing patches
n_patches = len(ax.patches)
for ints, labs, key in zip(intervals_hier[::-1],
labels_hier[::-1],
levels[::-1]):
labeled_intervals(ints, labs, label=key, ax=ax, **kwargs)
# Reverse the patch ordering for anything we've added.
# This way, intervals are listed in the legend from top to bottom
ax.patches[n_patches:] = ax.patches[n_patches:][::-1]
return ax
def events(times, labels=None, base=None, height=None, ax=None, text_kw=None,
**kwargs):
'''Plot event times as a set of vertical lines
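    Examples
    --------
    A minimal sketch, assuming 'beats.txt' contains one event time per line:
    >>> beats = mir_eval.io.load_events('beats.txt')
    >>> mir_eval.display.events(beats)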
Parameters
----------
times : np.ndarray, shape=(n,)
event times, in the format returned by
:func:`mir_eval.io.load_events` or
:func:`mir_eval.io.load_labeled_events`.
labels : list, shape=(n,), optional
event labels, in the format returned by
:func:`mir_eval.io.load_labeled_events`.
base : number
The vertical position of the base of the line.
By default, this will be the bottom of the plot.
height : number
The height of the lines.
By default, this will be the top of the plot (minus `base`).
ax : matplotlib.pyplot.axes
An axis handle on which to draw the segmentation.
If none is provided, a new set of axes is created.
text_kw : dict
If `labels` is provided, the properties of the text
objects can be specified here.
See `matplotlib.pyplot.Text` for valid parameters
kwargs
Additional keyword arguments to pass to
`matplotlib.pyplot.vlines`.
Returns
-------
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
'''
if text_kw is None:
text_kw = dict()
text_kw.setdefault('va', 'top')
text_kw.setdefault('clip_on', True)
text_kw.setdefault('bbox', dict(boxstyle='round', facecolor='white'))
# make sure we have an array for times
times = np.asarray(times)
# Get the axes handle
ax, new_axes = __get_axes(ax=ax)
# If we have fresh axes, set the limits
if new_axes:
# Infer base and height
if base is None:
base = 0
if height is None:
height = 1
ax.set_ylim([base, height])
else:
if base is None:
base = ax.get_ylim()[0]
if height is None:
height = ax.get_ylim()[1]
cycler = ax._get_patches_for_fill.prop_cycler
style = next(cycler).copy()
style.update(kwargs)
# If the user provided 'colors', don't override it with 'color'
if 'colors' in style:
style.pop('color', None)
lines = ax.vlines(times, base, base + height, **style)
if labels:
for path, lab in zip(lines.get_paths(), labels):
ax.annotate(lab,
xy=(path.vertices[0][0], height),
xycoords='data',
xytext=(8, -10), textcoords='offset points',
**text_kw)
if new_axes:
ax.set_yticks([])
__expand_limits(ax, [base, base + height], which='y')
if times.size:
__expand_limits(ax, [times.min(), times.max()], which='x')
return ax
def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs):
'''Visualize pitch contours
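    Examples
    --------
    A minimal sketch, assuming 'est.txt' contains a time/frequency series:
    >>> times, freqs = mir_eval.io.load_time_series('est.txt')
    >>> mir_eval.display.pitch(times, freqs, unvoiced=True)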
Parameters
----------
times : np.ndarray, shape=(n,)
Sample times of frequencies
frequencies : np.ndarray, shape=(n,)
frequencies (in Hz) of the pitch contours.
Voicing is indicated by sign (positive for voiced,
non-positive for non-voiced).
midi : bool
If `True`, plot on a MIDI-numbered vertical axis.
Otherwise, plot on a linear frequency axis.
unvoiced : bool
If `True`, unvoiced pitch contours are plotted and indicated
by transparency.
Otherwise, unvoiced pitch contours are omitted from the display.
ax : matplotlib.pyplot.axes
An axis handle on which to draw the pitch contours.
If none is provided, a new set of axes is created.
kwargs
Additional keyword arguments to `matplotlib.pyplot.plot`.
Returns
-------
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
'''
ax, _ = __get_axes(ax=ax)
times = np.asarray(times)
# First, segment into contiguously voiced contours
    frequencies, voicings = freq_to_voicing(np.asarray(frequencies,
                                                       dtype=float))
voicings = voicings.astype(bool)
# Here are all the change-points
v_changes = 1 + np.flatnonzero(voicings[1:] != voicings[:-1])
v_changes = np.unique(np.concatenate([[0], v_changes, [len(voicings)]]))
# Set up arrays of slices for voiced and unvoiced regions
v_slices, u_slices = [], []
for start, end in zip(v_changes, v_changes[1:]):
idx = slice(start, end)
        # A region is voiced if its starting sample is voiced.
        # An unvoiced region is only plotted if all of its frequencies are
        # non-zero (zeros indicate the absence of a pitch estimate).
if voicings[start]:
v_slices.append(idx)
elif frequencies[idx].all():
u_slices.append(idx)
# Now we just need to plot the contour
style = dict()
style.update(next(ax._get_lines.prop_cycler))
style.update(kwargs)
if midi:
idx = frequencies > 0
frequencies[idx] = hz_to_midi(frequencies[idx])
# Tick at integer midi notes
ax.yaxis.set_minor_locator(MultipleLocator(1))
for idx in v_slices:
ax.plot(times[idx], frequencies[idx], **style)
style.pop('label', None)
# Plot the unvoiced portions
if unvoiced:
style['alpha'] = style.get('alpha', 1.0) * 0.5
for idx in u_slices:
ax.plot(times[idx], frequencies[idx], **style)
return ax
def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None,
**kwargs):
'''Visualize multiple f0 measurements
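    Examples
    --------
    A minimal sketch, assuming 'est.txt' is in ragged time-series format:
    >>> times, freqs = mir_eval.io.load_ragged_time_series('est.txt')
    >>> mir_eval.display.multipitch(times, freqs, unvoiced=True)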
Parameters
----------
times : np.ndarray, shape=(n,)
Sample times of frequencies
frequencies : list of np.ndarray
frequencies (in Hz) of the pitch measurements.
Voicing is indicated by sign (positive for voiced,
non-positive for non-voiced).
`times` and `frequencies` should be in the format produced by
:func:`mir_eval.io.load_ragged_time_series`
midi : bool
If `True`, plot on a MIDI-numbered vertical axis.
Otherwise, plot on a linear frequency axis.
unvoiced : bool
If `True`, unvoiced pitches are plotted and indicated
by transparency.
Otherwise, unvoiced pitches are omitted from the display.
ax : matplotlib.pyplot.axes
An axis handle on which to draw the pitch contours.
If none is provided, a new set of axes is created.
kwargs
Additional keyword arguments to `plt.scatter`.
Returns
-------
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
'''
# Get the axes handle
ax, _ = __get_axes(ax=ax)
# Set up a style for the plot
style_voiced = dict()
style_voiced.update(next(ax._get_lines.prop_cycler))
style_voiced.update(kwargs)
style_unvoiced = style_voiced.copy()
style_unvoiced.pop('label', None)
style_unvoiced['alpha'] = style_unvoiced.get('alpha', 1.0) * 0.5
# We'll collect all times and frequencies first, then plot them
voiced_times = []
voiced_freqs = []
unvoiced_times = []
unvoiced_freqs = []
for t, freqs in zip(times, frequencies):
if not len(freqs):
continue
        freqs, voicings = freq_to_voicing(np.asarray(freqs, dtype=float))
# Discard all 0-frequency measurements
idx = freqs > 0
freqs = freqs[idx]
voicings = voicings[idx].astype(bool)
if midi:
freqs = hz_to_midi(freqs)
n_voiced = sum(voicings)
voiced_times.extend([t] * int(n_voiced))
voiced_freqs.extend(freqs[voicings])
unvoiced_times.extend([t] * (len(freqs) - n_voiced))
unvoiced_freqs.extend(freqs[~voicings])
# Plot the voiced frequencies
ax.scatter(voiced_times, voiced_freqs, **style_voiced)
# Plot the unvoiced frequencies
if unvoiced:
ax.scatter(unvoiced_times, unvoiced_freqs, **style_unvoiced)
# Tick at integer midi notes
if midi:
ax.yaxis.set_minor_locator(MultipleLocator(1))
return ax
def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs):
'''Plot a quantized piano roll as intervals
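    Examples
    --------
    A minimal sketch with two hypothetical notes specified directly:
    >>> intervals = np.array([[0.0, 1.0], [0.5, 1.5]])
    >>> mir_eval.display.piano_roll(intervals, midi=[60, 64])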
Parameters
----------
intervals : np.ndarray, shape=(n, 2)
timing intervals for notes
pitches : np.ndarray, shape=(n,), optional
pitches of notes (in Hz).
midi : np.ndarray, shape=(n,), optional
pitches of notes (in MIDI numbers).
At least one of ``pitches`` or ``midi`` must be provided.
ax : matplotlib.pyplot.axes
An axis handle on which to draw the intervals.
If none is provided, a new set of axes is created.
kwargs
Additional keyword arguments to :func:`labeled_intervals`.
Returns
-------
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
'''
if midi is None:
if pitches is None:
raise ValueError('At least one of `midi` or `pitches` '
'must be provided.')
midi = hz_to_midi(pitches)
scale = np.arange(128)
ax = labeled_intervals(intervals, np.round(midi).astype(int),
label_set=scale,
tick=False,
ax=ax,
**kwargs)
# Minor tick at each semitone
ax.yaxis.set_minor_locator(MultipleLocator(1))
ax.axis('auto')
return ax
def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs):
'''Source-separation visualization
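    Examples
    --------
    A minimal sketch, using random noise as stand-in sources:
    >>> sources = np.random.randn(2, 22050)
    >>> mir_eval.display.separation(sources, fs=22050,
    ...                             labels=['vocals', 'accompaniment'])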
Parameters
----------
sources : np.ndarray, shape=(nsrc, nsampl)
A list of waveform buffers corresponding to each source
fs : number > 0
The sampling rate
labels : list of strings
An optional list of descriptors corresponding to each source
alpha : float in [0, 1]
Maximum alpha (opacity) of spectrogram values.
ax : matplotlib.pyplot.axes
An axis handle on which to draw the spectrograms.
If none is provided, a new set of axes is created.
kwargs
Additional keyword arguments to ``scipy.signal.spectrogram``
Returns
-------
ax
The axis handle for this plot
'''
# Get the axes handle
ax, new_axes = __get_axes(ax=ax)
# Make sure we have at least two dimensions
sources = np.atleast_2d(sources)
if labels is None:
labels = ['Source {:d}'.format(_) for _ in range(len(sources))]
kwargs.setdefault('scaling', 'spectrum')
# The cumulative spectrogram across sources
# is used to establish the reference power
# for each individual source
cumspec = None
specs = []
for i, src in enumerate(sources):
freqs, times, spec = spectrogram(src, fs=fs, **kwargs)
specs.append(spec)
if cumspec is None:
cumspec = spec.copy()
else:
cumspec += spec
ref_max = cumspec.max()
ref_min = ref_max * 1e-6
color_conv = ColorConverter()
for i, spec in enumerate(specs):
# For each source, grab a new color from the cycler
# Then construct a colormap that interpolates from
# [transparent white -> new color]
color = next(ax._get_lines.prop_cycler)['color']
color = color_conv.to_rgba(color, alpha=alpha)
cmap = LinearSegmentedColormap.from_list(labels[i],
[(1.0, 1.0, 1.0, 0.0),
color])
ax.pcolormesh(times, freqs, spec,
cmap=cmap,
norm=LogNorm(vmin=ref_min, vmax=ref_max),
shading='gouraud',
label=labels[i])
# Attach a 0x0 rect to the axis with the corresponding label
# This way, it will show up in the legend
ax.add_patch(Rectangle((0, 0), 0, 0, color=color, label=labels[i]))
if new_axes:
ax.axis('tight')
return ax
def __ticker_midi_note(x, pos):
'''A ticker function for midi notes.
Inputs x are interpreted as midi numbers, and converted
to [NOTE][OCTAVE]+[cents].
'''
NOTES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
cents = float(np.mod(x, 1.0))
if cents >= 0.5:
cents = cents - 1.0
x = x + 0.5
idx = int(x % 12)
octave = int(x / 12) - 1
if cents == 0:
return '{:s}{:2d}'.format(NOTES[idx], octave)
return '{:s}{:2d}{:+02d}'.format(NOTES[idx], octave, int(cents * 100))
def __ticker_midi_hz(x, pos):
'''A ticker function for midi pitches.
Inputs x are interpreted as midi numbers, and converted
to Hz.
'''
return '{:g}'.format(midi_to_hz(x))
def ticker_notes(ax=None):
'''Set the y-axis of the given axes to MIDI notes
Parameters
----------
ax : matplotlib.pyplot.axes
The axes handle to apply the ticker.
By default, uses the current axes handle.
'''
ax, _ = __get_axes(ax=ax)
ax.yaxis.set_major_formatter(FMT_MIDI_NOTE)
# Get the tick labels and reset the vertical alignment
for tick in ax.yaxis.get_ticklabels():
tick.set_verticalalignment('baseline')
def ticker_pitch(ax=None):
'''Set the y-axis of the given axes to MIDI frequencies
Parameters
----------
ax : matplotlib.pyplot.axes
The axes handle to apply the ticker.
By default, uses the current axes handle.
'''
ax, _ = __get_axes(ax=ax)
ax.yaxis.set_major_formatter(FMT_MIDI_HZ)
# Instantiate ticker objects; we don't need more than one of each
FMT_MIDI_NOTE = FuncFormatter(__ticker_midi_note)
FMT_MIDI_HZ = FuncFormatter(__ticker_midi_hz)
mir_eval-0.7/mir_eval/hierarchy.py
# CREATED:2015-09-16 14:46:47 by Brian McFee
# -*- encoding: utf-8 -*-
'''Evaluation criteria for hierarchical structure analysis.
Hierarchical structure analysis seeks to annotate a track with a nested
decomposition of the temporal elements of the piece, effectively providing
a kind of "parse tree" of the composition. Unlike the flat segmentation
metrics defined in :mod:`mir_eval.segment`, which can only encode one level of
analysis, hierarchical annotations expose the relationships between short
segments and the larger compositional elements to which they belong.
Conventions
-----------
Annotations are assumed to take the form of an ordered list of segmentations.
As in the :mod:`mir_eval.segment` metrics, each segmentation itself consists of
an n-by-2 array of interval times, so that the ``i`` th segment spans time
``intervals[i, 0]`` to ``intervals[i, 1]``.
Hierarchical annotations are ordered by increasing specificity, so that the
first segmentation should contain the fewest segments, and the last
segmentation contains the most.
Metrics
-------
* :func:`mir_eval.hierarchy.tmeasure`: Precision, recall, and F-measure of
triplet-based frame accuracy for boundary detection.
* :func:`mir_eval.hierarchy.lmeasure`: Precision, recall, and F-measure of
triplet-based frame accuracy for segment labeling.
References
----------
.. [#mcfee2015] Brian McFee, Oriol Nieto, and Juan P. Bello.
"Hierarchical evaluation of segment boundary detection",
International Society for Music Information Retrieval (ISMIR) conference,
2015.
.. [#mcfee2017] Brian McFee, Oriol Nieto, Morwaread Farbood, and
Juan P. Bello.
"Evaluating hierarchical structure in music annotations",
Frontiers in Psychology, 2017.
'''
import collections
import itertools
import warnings
import numpy as np
import scipy.sparse
from . import util
from .segment import validate_structure
def _round(t, frame_size):
'''Round a time-stamp to a specified resolution.
Equivalent to ``t - np.mod(t, frame_size)``.
Examples
--------
>>> _round(53.279, 0.1)
53.2
>>> _round(53.279, 0.25)
53.25
Parameters
----------
t : number or ndarray
The time-stamp to round
frame_size : number > 0
The resolution to round to
Returns
-------
t_round : number
The rounded time-stamp
'''
return t - np.mod(t, float(frame_size))
def _hierarchy_bounds(intervals_hier):
'''Compute the covered time range of a hierarchical segmentation.
Parameters
----------
intervals_hier : list of ndarray
A hierarchical segmentation, encoded as a list of arrays of segment
intervals.
Returns
-------
t_min : float
t_max : float
The minimum and maximum times spanned by the annotation
'''
boundaries = list(itertools.chain(*list(itertools.chain(*intervals_hier))))
return min(boundaries), max(boundaries)
def _align_intervals(int_hier, lab_hier, t_min=0.0, t_max=None):
'''Align a hierarchical annotation to span a fixed start and end time.
Parameters
----------
int_hier : list of list of intervals
lab_hier : list of list of str
Hierarchical segment annotations, encoded as a
list of list of intervals (int_hier) and list of
list of strings (lab_hier)
t_min : None or number >= 0
The minimum time value for the segmentation
t_max : None or number >= t_min
The maximum time value for the segmentation
Returns
-------
intervals_hier : list of list of intervals
labels_hier : list of list of str
        `int_hier` and `lab_hier`, aligned to span `[t_min, t_max]`.
'''
return [list(_) for _ in zip(*[util.adjust_intervals(np.asarray(ival),
labels=lab,
t_min=t_min,
t_max=t_max)
for ival, lab in zip(int_hier, lab_hier)])]
def _lca(intervals_hier, frame_size):
'''Compute the (sparse) least-common-ancestor (LCA) matrix for a
hierarchical segmentation.
For any pair of frames ``(s, t)``, the LCA is the deepest level in
the hierarchy such that ``(s, t)`` are contained within a single
segment at that level.
Parameters
----------
intervals_hier : list of ndarray
An ordered list of segment interval arrays.
The list is assumed to be ordered by increasing specificity (depth).
frame_size : number
The length of the sample frames (in seconds)
Returns
-------
lca_matrix : scipy.sparse.csr_matrix
A sparse matrix such that ``lca_matrix[i, j]`` contains the depth
of the deepest segment containing frames ``i`` and ``j``.
'''
frame_size = float(frame_size)
# Figure out how many frames we need
n_start, n_end = _hierarchy_bounds(intervals_hier)
n = int((_round(n_end, frame_size) -
_round(n_start, frame_size)) / frame_size)
# Initialize the LCA matrix
lca_matrix = scipy.sparse.lil_matrix((n, n), dtype=np.uint8)
for level, intervals in enumerate(intervals_hier, 1):
for ival in (_round(np.asarray(intervals),
frame_size) / frame_size).astype(int):
idx = slice(ival[0], ival[1])
lca_matrix[idx, idx] = level
return lca_matrix.tocsr()
def _meet(intervals_hier, labels_hier, frame_size):
    '''Compute the (sparse) meet matrix for a hierarchical segmentation.
    For any pair of frames ``(s, t)``, the meet is the deepest level in
    the hierarchy such that ``(s, t)`` are contained within a single
    segment having the same label at that level.
Parameters
----------
intervals_hier : list of ndarray
An ordered list of segment interval arrays.
The list is assumed to be ordered by increasing specificity (depth).
labels_hier : list of list of str
``labels_hier[i]`` contains the segment labels for the
``i``th layer of the annotations
frame_size : number
The length of the sample frames (in seconds)
Returns
-------
meet_matrix : scipy.sparse.csr_matrix
A sparse matrix such that ``meet_matrix[i, j]`` contains the depth
of the deepest segment label containing both ``i`` and ``j``.
'''
frame_size = float(frame_size)
# Figure out how many frames we need
n_start, n_end = _hierarchy_bounds(intervals_hier)
n = int((_round(n_end, frame_size) -
_round(n_start, frame_size)) / frame_size)
# Initialize the meet matrix
meet_matrix = scipy.sparse.lil_matrix((n, n), dtype=np.uint8)
for level, (intervals, labels) in enumerate(zip(intervals_hier,
labels_hier), 1):
# Encode the labels at this level
lab_enc = util.index_labels(labels)[0]
# Find unique agreements
int_agree = np.triu(np.equal.outer(lab_enc, lab_enc))
# Map intervals to frame indices
int_frames = (_round(intervals, frame_size) / frame_size).astype(int)
# For each intervals i, j where labels agree, update the meet matrix
for (seg_i, seg_j) in zip(*np.where(int_agree)):
idx_i = slice(*list(int_frames[seg_i]))
idx_j = slice(*list(int_frames[seg_j]))
meet_matrix[idx_i, idx_j] = level
if seg_i != seg_j:
meet_matrix[idx_j, idx_i] = level
return scipy.sparse.csr_matrix(meet_matrix)
def _gauc(ref_lca, est_lca, transitive, window):
'''Generalized area under the curve (GAUC)
This function computes the normalized recall score for correctly
ordering triples ``(q, i, j)`` where frames ``(q, i)`` are closer than
``(q, j)`` in the reference annotation.
Parameters
----------
ref_lca : scipy.sparse
est_lca : scipy.sparse
The least common ancestor matrices for the reference and
estimated annotations
transitive : bool
If True, then transitive comparisons are counted, meaning that
``(q, i)`` and ``(q, j)`` can differ by any number of levels.
If False, then ``(q, i)`` and ``(q, j)`` can differ by exactly one
level.
window : number or None
The maximum number of frames to consider for each query.
If `None`, then all frames are considered.
Returns
-------
score : number [0, 1]
        The fraction of reference triples correctly ordered by
        the estimation.
Raises
------
ValueError
If ``ref_lca`` and ``est_lca`` have different shapes
'''
# Make sure we have the right number of frames
if ref_lca.shape != est_lca.shape:
raise ValueError('Estimated and reference hierarchies '
'must have the same shape.')
# How many frames?
n = ref_lca.shape[0]
# By default, the window covers the entire track
if window is None:
window = n
# Initialize the score
score = 0.0
# Iterate over query frames
num_frames = 0
for query in range(n):
# Find all pairs i,j such that ref_lca[q, i] > ref_lca[q, j]
results = slice(max(0, query - window), min(n, query + window))
ref_score = ref_lca[query, results]
est_score = est_lca[query, results]
# Densify the results
ref_score = ref_score.toarray().squeeze()
est_score = est_score.toarray().squeeze()
# Don't count the query as a result
# when query < window, query itself is the index within the slice
# otherwise, query is located at the center of the slice, window
# (this also holds when the slice goes off the end of the array.)
idx = min(query, window)
ref_score = np.concatenate((ref_score[:idx], ref_score[idx+1:]))
est_score = np.concatenate((est_score[:idx], est_score[idx+1:]))
inversions, normalizer = _compare_frame_rankings(ref_score, est_score,
transitive=transitive)
if normalizer:
score += 1.0 - inversions / float(normalizer)
num_frames += 1
# Normalize by the number of frames counted.
# If no frames are counted, take the convention 0/0 -> 0
if num_frames:
score /= float(num_frames)
else:
score = 0.0
return score
def _count_inversions(a, b):
    '''Count the number of inversions between two numpy arrays:
    the number of pairs ``(i, j)`` where ``a[i] >= b[j]``.
Parameters
----------
a, b : np.ndarray, shape=(n,) (m,)
The arrays to be compared.
This implementation is optimized for arrays with many
repeated values.
Returns
-------
inversions : int
The number of detected inversions
'''
a, a_counts = np.unique(a, return_counts=True)
b, b_counts = np.unique(b, return_counts=True)
inversions = 0
i = 0
j = 0
    while i < len(a) and j < len(b):
        if a[i] < b[j]:
            # a[i] is smaller than every remaining value in b,
            # so it contributes no inversions; advance a
            i += 1
        elif a[i] >= b[j]:
            # Every remaining value in a is >= b[j]: count them all at
            # once, weighted by the multiplicity of b[j]
            inversions += np.sum(a_counts[i:]) * b_counts[j]
            j += 1
return inversions
def _compare_frame_rankings(ref, est, transitive=False):
'''Compute the number of ranking disagreements in two lists.
Parameters
----------
ref : np.ndarray, shape=(n,)
est : np.ndarray, shape=(n,)
Reference and estimate ranked lists.
`ref[i]` is the relevance score for point `i`.
transitive : bool
If true, all pairs of reference levels are compared.
If false, only adjacent pairs of reference levels are compared.
Returns
-------
inversions : int
The number of pairs of indices `i, j` where
`ref[i] < ref[j]` but `est[i] >= est[j]`.
normalizer : float
The total number of pairs (i, j) under consideration.
If transitive=True, then this is |{(i,j) : ref[i] < ref[j]}|
        If transitive=False, then this is |{(i, j) : ref[i] + 1 = ref[j]}|
'''
idx = np.argsort(ref)
ref_sorted = ref[idx]
est_sorted = est[idx]
# Find the break-points in ref_sorted
levels, positions, counts = np.unique(ref_sorted,
return_index=True,
return_counts=True)
positions = list(positions)
positions.append(len(ref_sorted))
index = collections.defaultdict(lambda: slice(0))
ref_map = collections.defaultdict(lambda: 0)
for level, cnt, start, end in zip(levels, counts,
positions[:-1], positions[1:]):
index[level] = slice(start, end)
ref_map[level] = cnt
# Now that we have values sorted, apply the inversion-counter to
# pairs of reference values
if transitive:
level_pairs = itertools.combinations(levels, 2)
else:
level_pairs = [(i, i+1) for i in levels]
level_pairs, lcounter = itertools.tee(level_pairs)
normalizer = float(sum([ref_map[i] * ref_map[j] for (i, j) in lcounter]))
if normalizer == 0:
return 0, 0.0
inversions = 0
for level_1, level_2 in level_pairs:
inversions += _count_inversions(est_sorted[index[level_1]],
est_sorted[index[level_2]])
return inversions, float(normalizer)
def validate_hier_intervals(intervals_hier):
'''Validate a hierarchical segment annotation.
Parameters
----------
intervals_hier : ordered list of segmentations
Raises
------
ValueError
If any segmentation does not span the full duration of the top-level
segmentation.
If any segmentation does not start at 0.
'''
# Synthesize a label array for the top layer.
label_top = util.generate_labels(intervals_hier[0])
boundaries = set(util.intervals_to_boundaries(intervals_hier[0]))
for level, intervals in enumerate(intervals_hier[1:], 1):
# Make sure this level is consistent with the root
label_current = util.generate_labels(intervals)
validate_structure(intervals_hier[0], label_top,
intervals, label_current)
# Make sure all previous boundaries are accounted for
new_bounds = set(util.intervals_to_boundaries(intervals))
if boundaries - new_bounds:
warnings.warn('Segment hierarchy is inconsistent '
'at level {:d}'.format(level))
boundaries |= new_bounds
def tmeasure(reference_intervals_hier, estimated_intervals_hier,
transitive=False, window=15.0, frame_size=0.1, beta=1.0):
'''Computes the tree measures for hierarchical segment annotations.
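    Examples
    --------
    A toy example with two two-layer annotations, as in :func:`evaluate`,
    but with intervals given as explicit numpy arrays:
    >>> ref_i = [np.array([[0, 30], [30, 60]]),
    ...          np.array([[0, 15], [15, 30], [30, 45], [45, 60]])]
    >>> est_i = [np.array([[0, 45], [45, 60]]),
    ...          np.array([[0, 15], [15, 30], [30, 45], [45, 60]])]
    >>> t_prec, t_rec, t_meas = mir_eval.hierarchy.tmeasure(ref_i, est_i)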
Parameters
----------
reference_intervals_hier : list of ndarray
``reference_intervals_hier[i]`` contains the segment intervals
(in seconds) for the ``i`` th layer of the annotations. Layers are
ordered from top to bottom, so that the last list of intervals should
be the most specific.
estimated_intervals_hier : list of ndarray
Like ``reference_intervals_hier`` but for the estimated annotation
transitive : bool
whether to compute the t-measures using transitivity or not.
window : float > 0
size of the window (in seconds). For each query frame q,
result frames are only counted within q +- window.
frame_size : float > 0
length (in seconds) of frames. The frame size cannot be longer than
the window.
beta : float > 0
beta parameter for the F-measure.
Returns
-------
t_precision : number [0, 1]
T-measure Precision
t_recall : number [0, 1]
T-measure Recall
t_measure : number [0, 1]
F-beta measure for ``(t_precision, t_recall)``
Raises
------
ValueError
If either of the input hierarchies are inconsistent
If the input hierarchies have different time durations
If ``frame_size > window`` or ``frame_size <= 0``
'''
# Compute the number of frames in the window
if frame_size <= 0:
raise ValueError('frame_size ({:.2f}) must be a positive '
'number.'.format(frame_size))
if window is None:
window_frames = None
else:
if frame_size > window:
raise ValueError('frame_size ({:.2f}) cannot exceed '
'window ({:.2f})'.format(frame_size, window))
window_frames = int(_round(window, frame_size) / frame_size)
# Validate the hierarchical segmentations
validate_hier_intervals(reference_intervals_hier)
validate_hier_intervals(estimated_intervals_hier)
# Build the least common ancestor matrices
ref_lca = _lca(reference_intervals_hier, frame_size)
est_lca = _lca(estimated_intervals_hier, frame_size)
# Compute precision and recall
t_recall = _gauc(ref_lca, est_lca, transitive, window_frames)
t_precision = _gauc(est_lca, ref_lca, transitive, window_frames)
t_measure = util.f_measure(t_precision, t_recall, beta=beta)
return t_precision, t_recall, t_measure
def lmeasure(reference_intervals_hier, reference_labels_hier,
estimated_intervals_hier, estimated_labels_hier,
frame_size=0.1, beta=1.0):
    '''Computes the label measures (L-measures) for hierarchical segment
    annotations.
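    Examples
    --------
    A toy example with two two-layer annotations, as in :func:`evaluate`,
    but with intervals given as explicit numpy arrays:
    >>> ref_i = [np.array([[0, 30], [30, 60]]),
    ...          np.array([[0, 15], [15, 30], [30, 45], [45, 60]])]
    >>> est_i = [np.array([[0, 45], [45, 60]]),
    ...          np.array([[0, 15], [15, 30], [30, 45], [45, 60]])]
    >>> ref_l = [['A', 'B'], ['a', 'b', 'a', 'c']]
    >>> est_l = [['A', 'B'], ['a', 'a', 'b', 'b']]
    >>> l_prec, l_rec, l_meas = mir_eval.hierarchy.lmeasure(ref_i, ref_l,
    ...                                                     est_i, est_l)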
Parameters
----------
reference_intervals_hier : list of ndarray
``reference_intervals_hier[i]`` contains the segment intervals
(in seconds) for the ``i`` th layer of the annotations. Layers are
ordered from top to bottom, so that the last list of intervals should
be the most specific.
reference_labels_hier : list of list of str
``reference_labels_hier[i]`` contains the segment labels for the
``i``th layer of the annotations
estimated_intervals_hier : list of ndarray
    estimated_labels_hier : list of list of str
Like ``reference_intervals_hier`` and ``reference_labels_hier``
but for the estimated annotation
frame_size : float > 0
        length (in seconds) of frames.
beta : float > 0
beta parameter for the F-measure.
Returns
-------
l_precision : number [0, 1]
L-measure Precision
l_recall : number [0, 1]
L-measure Recall
l_measure : number [0, 1]
F-beta measure for ``(l_precision, l_recall)``
Raises
------
ValueError
If either of the input hierarchies are inconsistent
If the input hierarchies have different time durations
        If ``frame_size <= 0``
'''
# Compute the number of frames in the window
if frame_size <= 0:
raise ValueError('frame_size ({:.2f}) must be a positive '
'number.'.format(frame_size))
# Validate the hierarchical segmentations
validate_hier_intervals(reference_intervals_hier)
validate_hier_intervals(estimated_intervals_hier)
# Build the least common ancestor matrices
ref_meet = _meet(reference_intervals_hier, reference_labels_hier,
frame_size)
est_meet = _meet(estimated_intervals_hier, estimated_labels_hier,
frame_size)
# Compute precision and recall
l_recall = _gauc(ref_meet, est_meet, True, None)
l_precision = _gauc(est_meet, ref_meet, True, None)
l_measure = util.f_measure(l_precision, l_recall, beta=beta)
return l_precision, l_recall, l_measure
def evaluate(ref_intervals_hier, ref_labels_hier,
est_intervals_hier, est_labels_hier, **kwargs):
'''Compute all hierarchical structure metrics for the given reference and
estimated annotations.
Examples
--------
A toy example with two two-layer annotations
>>> ref_i = [[[0, 30], [30, 60]], [[0, 15], [15, 30], [30, 45], [45, 60]]]
>>> est_i = [[[0, 45], [45, 60]], [[0, 15], [15, 30], [30, 45], [45, 60]]]
>>> ref_l = [ ['A', 'B'], ['a', 'b', 'a', 'c'] ]
>>> est_l = [ ['A', 'B'], ['a', 'a', 'b', 'b'] ]
>>> scores = mir_eval.hierarchy.evaluate(ref_i, ref_l, est_i, est_l)
>>> dict(scores)
{'T-Measure full': 0.94822745804853459,
'T-Measure reduced': 0.8732458222764804,
'T-Precision full': 0.96569179094693058,
'T-Precision reduced': 0.89939075137018787,
'T-Recall full': 0.93138358189386117,
'T-Recall reduced': 0.84857799953694923}
A more realistic example, using SALAMI pre-parsed annotations
>>> def load_salami(filename):
... "load SALAMI event format as labeled intervals"
... events, labels = mir_eval.io.load_labeled_events(filename)
... intervals = mir_eval.util.boundaries_to_intervals(events)[0]
... return intervals, labels[:len(intervals)]
>>> ref_files = ['data/10/parsed/textfile1_uppercase.txt',
... 'data/10/parsed/textfile1_lowercase.txt']
>>> est_files = ['data/10/parsed/textfile2_uppercase.txt',
... 'data/10/parsed/textfile2_lowercase.txt']
>>> ref = [load_salami(fname) for fname in ref_files]
>>> ref_int = [seg[0] for seg in ref]
>>> ref_lab = [seg[1] for seg in ref]
>>> est = [load_salami(fname) for fname in est_files]
>>> est_int = [seg[0] for seg in est]
>>> est_lab = [seg[1] for seg in est]
>>> scores = mir_eval.hierarchy.evaluate(ref_int, ref_lab,
    ...                                      est_int, est_lab)
>>> dict(scores)
{'T-Measure full': 0.66029225561405358,
'T-Measure reduced': 0.62001868041578034,
'T-Precision full': 0.66844764668949885,
'T-Precision reduced': 0.63252297209957919,
'T-Recall full': 0.6523334654992341,
'T-Recall reduced': 0.60799919710921635}
Parameters
----------
ref_intervals_hier : list of list-like
ref_labels_hier : list of list of str
est_intervals_hier : list of list-like
est_labels_hier : list of list of str
Hierarchical annotations are encoded as an ordered list
of segmentations. Each segmentation itself is a list (or list-like)
of intervals (\*_intervals_hier) and a list of lists of labels
(\*_labels_hier).
kwargs
additional keyword arguments to the evaluation metrics.
Returns
-------
scores : OrderedDict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
T-measures are computed in both the "full" (``transitive=True``) and
"reduced" (``transitive=False``) modes.
Raises
------
ValueError
Thrown when the provided annotations are not valid.
'''
# First, find the maximum length of the reference
_, t_end = _hierarchy_bounds(ref_intervals_hier)
# Pre-process the intervals to match the range of the reference,
# and start at 0
ref_intervals_hier, ref_labels_hier = _align_intervals(ref_intervals_hier,
ref_labels_hier,
t_min=0.0,
t_max=None)
est_intervals_hier, est_labels_hier = _align_intervals(est_intervals_hier,
est_labels_hier,
t_min=0.0,
t_max=t_end)
scores = collections.OrderedDict()
# Force the transitivity setting
kwargs['transitive'] = False
(scores['T-Precision reduced'],
scores['T-Recall reduced'],
scores['T-Measure reduced']) = util.filter_kwargs(tmeasure,
ref_intervals_hier,
est_intervals_hier,
**kwargs)
kwargs['transitive'] = True
(scores['T-Precision full'],
scores['T-Recall full'],
scores['T-Measure full']) = util.filter_kwargs(tmeasure,
ref_intervals_hier,
est_intervals_hier,
**kwargs)
(scores['L-Precision'],
scores['L-Recall'],
scores['L-Measure']) = util.filter_kwargs(lmeasure,
ref_intervals_hier,
ref_labels_hier,
est_intervals_hier,
est_labels_hier,
**kwargs)
return scores
mir_eval-0.7/mir_eval/io.py
"""
Functions for loading in annotations from files in different formats.
"""
import contextlib
import numpy as np
import re
import warnings
import scipy.io.wavfile
import six
from . import util
from . import key
from . import tempo
@contextlib.contextmanager
def _open(file_or_str, **kwargs):
'''Either open a file handle, or use an existing file-like object.
This will behave as the `open` function if `file_or_str` is a string.
If `file_or_str` has the `read` attribute, it will return `file_or_str`.
Otherwise, an `IOError` is raised.
'''
if hasattr(file_or_str, 'read'):
yield file_or_str
elif isinstance(file_or_str, six.string_types):
with open(file_or_str, **kwargs) as file_desc:
yield file_desc
else:
raise IOError('Invalid file-or-str object: {}'.format(file_or_str))
def load_delimited(filename, converters, delimiter=r'\s+', comment='#'):
r"""Utility function for loading in data from an annotation file where columns
are delimited. The number of columns is inferred from the length of
the provided converters list.
Examples
--------
>>> # Load in a one-column list of event times (floats)
>>> load_delimited('events.txt', [float])
>>> # Load in a list of labeled events, separated by commas
>>> load_delimited('labeled_events.csv', [float, str], ',')
Parameters
----------
filename : str
Path to the annotation file
converters : list of functions
Each entry in column ``n`` of the file will be cast by the function
``converters[n]``.
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
columns : tuple of lists
Each list in this tuple corresponds to values in one of the columns
in the file.
"""
# Initialize list of empty lists
n_columns = len(converters)
columns = tuple(list() for _ in range(n_columns))
# Create re object for splitting lines
splitter = re.compile(delimiter)
# And one for comments
if comment is None:
commenter = None
else:
commenter = re.compile('^{}'.format(comment))
# Note: we do io manually here for two reasons.
# 1. The csv module has difficulties with unicode, which may lead
# to failures on certain annotation strings
#
# 2. numpy's text loader does not handle non-numeric data
#
with _open(filename, mode='r') as input_file:
for row, line in enumerate(input_file, 1):
# Skip commented lines
if comment is not None and commenter.match(line):
continue
# Split each line using the supplied delimiter
data = splitter.split(line.strip(), n_columns - 1)
# Throw a helpful error if we got an unexpected # of columns
if n_columns != len(data):
raise ValueError('Expected {} columns, got {} at '
'{}:{:d}:\n\t{}'.format(n_columns, len(data),
filename, row, line))
for value, column, converter in zip(data, columns, converters):
# Try converting the value, throw a helpful error on failure
try:
converted_value = converter(value)
                except Exception:
raise ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
value, converter.__name__, filename,
row, line))
column.append(converted_value)
    # Sane output: a single column is returned directly, not as a 1-tuple
if n_columns == 1:
return columns[0]
else:
return columns
def load_events(filename, delimiter=r'\s+', comment='#'):
r"""Import time-stamp events from an annotation file. The file should
consist of a single column of numeric values corresponding to the event
times. This is primarily useful for processing events which lack duration,
such as beats or onsets.
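    Examples
    --------
    >>> # Load in a one-column list of beat times
    >>> # (assuming 'beats.txt' exists)
    >>> beats = mir_eval.io.load_events('beats.txt')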
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
event_times : np.ndarray
array of event times (float)
"""
# Use our universal function to load in the events
events = load_delimited(filename, [float],
delimiter=delimiter, comment=comment)
events = np.array(events)
# Validate them, but throw a warning in place of an error
try:
util.validate_events(events)
except ValueError as error:
warnings.warn(error.args[0])
return events
def load_labeled_events(filename, delimiter=r'\s+', comment='#'):
r"""Import labeled time-stamp events from an annotation file. The file should
consist of two columns; the first having numeric values corresponding to
the event times and the second having string labels for each event. This
is primarily useful for processing labeled events which lack duration, such
as beats with metric beat number or onsets with an instrument label.
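    Examples
    --------
    >>> # Illustrative sketch; assumes rows like "0.5 downbeat"
    >>> beat_times, beat_labels = load_labeled_events('beats.txt')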
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
event_times : np.ndarray
array of event times (float)
labels : list of str
list of labels
"""
# Use our universal function to load in the events
events, labels = load_delimited(filename, [float, str],
delimiter=delimiter,
comment=comment)
events = np.array(events)
# Validate them, but throw a warning in place of an error
try:
util.validate_events(events)
except ValueError as error:
warnings.warn(error.args[0])
return events, labels
def load_intervals(filename, delimiter=r'\s+', comment='#'):
r"""Import intervals from an annotation file. The file should consist of two
columns of numeric values corresponding to start and end time of each
interval. This is primarily useful for processing events which span a
duration, such as segmentation, chords, or instrument activation.
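    Examples
    --------
    >>> # Illustrative sketch; assumes rows like "0.0 5.0"
    >>> intervals = load_intervals('segments.txt')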
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
intervals : np.ndarray, shape=(n_events, 2)
array of event start and end times
"""
# Use our universal function to load in the events
starts, ends = load_delimited(filename, [float, float],
delimiter=delimiter,
comment=comment)
# Stack into an interval matrix
intervals = np.array([starts, ends]).T
# Validate them, but throw a warning in place of an error
try:
util.validate_intervals(intervals)
except ValueError as error:
warnings.warn(error.args[0])
return intervals
def load_labeled_intervals(filename, delimiter=r'\s+', comment='#'):
r"""Import labeled intervals from an annotation file. The file should consist
of three columns: Two consisting of numeric values corresponding to start
and end time of each interval and a third corresponding to the label of
each interval. This is primarily useful for processing events which span a
duration, such as segmentation, chords, or instrument activation.
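    Examples
    --------
    >>> # Illustrative sketch; assumes rows like "0.0 5.0 verse"
    >>> intervals, labels = load_labeled_intervals('segments.txt')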
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
intervals : np.ndarray, shape=(n_events, 2)
array of event start and end time
labels : list of str
list of labels
"""
# Use our universal function to load in the events
starts, ends, labels = load_delimited(filename, [float, float, str],
delimiter=delimiter,
comment=comment)
# Stack into an interval matrix
intervals = np.array([starts, ends]).T
# Validate them, but throw a warning in place of an error
try:
util.validate_intervals(intervals)
except ValueError as error:
warnings.warn(error.args[0])
return intervals, labels
def load_time_series(filename, delimiter=r'\s+', comment='#'):
r"""Import a time series from an annotation file. The file should consist of
two columns of numeric values corresponding to the time and value of each
sample of the time series.
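    Examples
    --------
    >>> # Illustrative sketch; assumes rows like "0.01 440.0"
    >>> times, values = load_time_series('f0.txt')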
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
times : np.ndarray
array of timestamps (float)
values : np.ndarray
array of corresponding numeric values (float)
"""
# Use our universal function to load in the events
times, values = load_delimited(filename, [float, float],
delimiter=delimiter,
comment=comment)
times = np.array(times)
values = np.array(values)
return times, values
def load_patterns(filename):
"""Loads the patters contained in the filename and puts them into a list
of patterns, each pattern being a list of occurrence, and each
occurrence being a list of (onset, midi) pairs.
The input file must be formatted as described in MIREX 2013:
http://www.music-ir.org/mirex/wiki/2013:Discovery_of_Repeated_Themes_%26_Sections
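    Examples
    --------
    >>> # Illustrative sketch; assumes 'patterns.txt' follows the MIREX format
    >>> pattern_list = load_patterns('patterns.txt')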
Parameters
----------
filename : str
The input file path containing the patterns of a given piece using the
MIREX 2013 format.
Returns
-------
pattern_list : list
The list of patterns, containing all their occurrences,
using the following format::
onset_midi = (onset_time, midi_number)
occurrence = [onset_midi1, ..., onset_midiO]
pattern = [occurrence1, ..., occurrenceM]
pattern_list = [pattern1, ..., patternN]
where ``N`` is the number of patterns, ``M[i]`` is the number of
occurrences of the ``i`` th pattern, and ``O[j]`` is the number of
onsets in the ``j``'th occurrence. E.g.::
occ1 = [(0.5, 67.0), (1.0, 67.0), (1.5, 67.0), (2.0, 64.0)]
occ2 = [(4.5, 65.0), (5.0, 65.0), (5.5, 65.0), (6.0, 62.0)]
pattern1 = [occ1, occ2]
occ1 = [(10.5, 67.0), (11.0, 67.0), (11.5, 67.0), (12.0, 64.0),
(12.5, 69.0), (13.0, 69.0), (13.5, 69.0), (14.0, 67.0),
(14.5, 76.0), (15.0, 76.0), (15.5, 76.0), (16.0, 72.0)]
occ2 = [(18.5, 67.0), (19.0, 67.0), (19.5, 67.0), (20.0, 62.0),
(20.5, 69.0), (21.0, 69.0), (21.5, 69.0), (22.0, 67.0),
(22.5, 77.0), (23.0, 77.0), (23.5, 77.0), (24.0, 74.0)]
pattern2 = [occ1, occ2]
pattern_list = [pattern1, pattern2]
"""
# List with all the patterns
pattern_list = []
# Current pattern, which will contain all occs
pattern = []
# Current occurrence, containing (onset, midi)
occurrence = []
with _open(filename, mode='r') as input_file:
for line in input_file.readlines():
if "pattern" in line:
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)
occurrence = []
pattern = []
continue
if "occurrence" in line:
if occurrence != []:
pattern.append(occurrence)
occurrence = []
continue
string_values = line.split(",")
onset_midi = (float(string_values[0]), float(string_values[1]))
occurrence.append(onset_midi)
# Add last occurrence and pattern to pattern_list
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)
return pattern_list
def load_wav(path, mono=True):
"""Loads a .wav file as a numpy array using ``scipy.io.wavfile``.
Parameters
----------
path : str
Path to a .wav file
mono : bool
If the provided .wav has more than one channel, it will be
converted to mono if ``mono=True``. (Default value = True)
Returns
-------
audio_data : np.ndarray
Array of audio samples, normalized to the range [-1., 1.]
fs : int
Sampling rate of the audio data
"""
fs, audio_data = scipy.io.wavfile.read(path)
    # Make float in range [-1, 1] by scaling by the sample type's full range
    if audio_data.dtype == 'int8':
        audio_data = audio_data/float(2**7)
    elif audio_data.dtype == 'int16':
        audio_data = audio_data/float(2**15)
    elif audio_data.dtype == 'int32':
        audio_data = audio_data/float(2**31)
else:
raise ValueError('Got unexpected .wav data type '
'{}'.format(audio_data.dtype))
# Optionally convert to mono
if mono and audio_data.ndim != 1:
audio_data = audio_data.mean(axis=1)
return audio_data, fs
def load_valued_intervals(filename, delimiter=r'\s+', comment='#'):
r"""Import valued intervals from an annotation file. The file should
consist of three columns: Two consisting of numeric values corresponding to
start and end time of each interval and a third, also of numeric values,
corresponding to the value of each interval. This is primarily useful for
processing events which span a duration and have a numeric value, such as
piano-roll notes which have an onset, offset, and a pitch value.
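    Examples
    --------
    >>> # Illustrative sketch; assumes note rows like "0.0 0.5 60"
    >>> intervals, pitches = load_valued_intervals('notes.txt')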
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
intervals : np.ndarray, shape=(n_events, 2)
Array of event start and end times
values : np.ndarray, shape=(n_events,)
Array of values
"""
# Use our universal function to load in the events
starts, ends, values = load_delimited(filename, [float, float, float],
delimiter=delimiter,
comment=comment)
# Stack into an interval matrix
intervals = np.array([starts, ends]).T
# Validate them, but throw a warning in place of an error
try:
util.validate_intervals(intervals)
except ValueError as error:
warnings.warn(error.args[0])
# return values as np.ndarray
values = np.array(values)
return intervals, values
def load_key(filename, delimiter=r'\s+', comment='#'):
r"""Load key labels from an annotation file. The file should
consist of two string columns: One denoting the key scale degree
(semitone), and the other denoting the mode (major or minor). The file
should contain only one row.
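    Examples
    --------
    >>> # Illustrative sketch; assumes 'key.txt' contains a line like "C# major"
    >>> key_string = load_key('key.txt')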
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
key : str
Key label, in the form ``'(key) (mode)'``
"""
# Use our universal function to load the key and mode strings
scale, mode = load_delimited(filename, [str, str],
delimiter=delimiter,
comment=comment)
if len(scale) != 1:
raise ValueError('Key file should contain only one line.')
scale, mode = scale[0], mode[0]
# Join with a space
key_string = '{} {}'.format(scale, mode)
# Validate them, but throw a warning in place of an error
try:
key.validate_key(key_string)
except ValueError as error:
warnings.warn(error.args[0])
return key_string
def load_tempo(filename, delimiter=r'\s+', comment='#'):
r"""Load tempo estimates from an annotation file in MIREX format.
The file should consist of three numeric columns: the first two
correspond to tempo estimates (in beats-per-minute), and the third
denotes the relative confidence of the first value compared to the
second (in the range [0, 1]). The file should contain only one row.
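    Examples
    --------
    >>> # Illustrative sketch; assumes 'tempo.txt' contains e.g. "60.0 120.0 0.8"
    >>> tempi, weight = load_tempo('tempo.txt')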
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
tempi : np.ndarray, non-negative
The two tempo estimates
weight : float [0, 1]
The relative importance of ``tempi[0]`` compared to ``tempi[1]``
"""
    # Use our universal function to load the tempi and weight
t1, t2, weight = load_delimited(filename, [float, float, float],
delimiter=delimiter,
comment=comment)
    if len(t1) != 1:
        raise ValueError('Tempo file should contain only one line.')
    weight = weight[0]
    tempi = np.concatenate([t1, t2])
# Validate them, but throw a warning in place of an error
try:
tempo.validate_tempi(tempi)
except ValueError as error:
warnings.warn(error.args[0])
if not 0 <= weight <= 1:
raise ValueError('Invalid weight: {}'.format(weight))
return tempi, weight
def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+',
header=False, comment='#'):
r"""Utility function for loading in data from a delimited time series
annotation file with a variable number of columns.
Assumes that column 0 contains time stamps and columns 1 through n contain
values. n may be variable from time stamp to time stamp.
Examples
--------
>>> # Load a ragged list of tab-delimited multi-f0 midi notes
    >>> times, vals = load_ragged_time_series('multif0.txt', dtype=int,
    ...                                       delimiter='\t')
    >>> # Load a ragged list of space-delimited multi-f0 values with a header
    >>> times, vals = load_ragged_time_series('labeled_events.csv',
    ...                                       header=True)
Parameters
----------
filename : str
Path to the annotation file
dtype : function
Data type to apply to values columns.
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
header : bool
Indicates whether a header row is present or not.
By default, assumes no header is present.
comment : str or None
Comment regular expression.
Any lines beginning with this string or pattern will be ignored.
Setting to `None` disables comments.
Returns
-------
times : np.ndarray
array of timestamps (float)
values : list of np.ndarray
list of arrays of corresponding values
"""
# Initialize empty lists
times = []
values = []
# Create re object for splitting lines
splitter = re.compile(delimiter)
# And one for comments
if comment is None:
commenter = None
else:
commenter = re.compile('^{}'.format(comment))
    with _open(filename, mode='r') as input_file:
        for row, line in enumerate(input_file, 1):
            # Skip the header row, if one was indicated
            if header and row == 1:
                continue
# If this is a comment line, skip it
if comment is not None and commenter.match(line):
continue
# Split each line using the supplied delimiter
data = splitter.split(line.strip())
try:
converted_time = float(data[0])
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[0], float.__name__,
filename, row, line)), exe)
times.append(converted_time)
# cast values to a numpy array. time stamps with no values are cast
# to an empty array.
try:
converted_value = np.array(data[1:], dtype=dtype)
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[1:], dtype.__name__,
filename, row, line)), exe)
values.append(converted_value)
return np.array(times), values
mir_eval-0.7/mir_eval/key.py 0000664 0000000 0000000 00000015311 14203260312 0016110 0 ustar 00root root 0000000 0000000 '''
Key Detection involves determining the underlying key (distribution of notes
and note transitions) in a piece of music. Key detection algorithms are
evaluated by comparing their estimated key to a ground-truth reference key and
reporting a score according to the relationship of the keys.
Conventions
-----------
Keys are represented as strings of the form ``'(key) (mode)'``, e.g. ``'C#
major'`` or ``'Eb minor'``. The case of the key is ignored. Note that certain
key strings are equivalent, e.g. ``'C# major'`` and ``'Db major'``. The mode
may be specified as ``'major'``, ``'minor'``, or ``'other'``; the special key
``'X'`` denotes a key which cannot be categorized.
Metrics
-------
* :func:`mir_eval.key.weighted_score`: Heuristic scoring of the relation of two
keys.
'''
import collections
from . import util
KEY_TO_SEMITONE = {'c': 0, 'c#': 1, 'db': 1, 'd': 2, 'd#': 3, 'eb': 3, 'e': 4,
'f': 5, 'f#': 6, 'gb': 6, 'g': 7, 'g#': 8, 'ab': 8, 'a': 9,
'a#': 10, 'bb': 10, 'b': 11, 'x': None}
def validate_key(key):
"""Checks that a key is well-formatted, e.g. in the form ``'C# major'``.
The Key can be 'X' if it is not possible to categorize the Key and mode
can be 'other' if it can't be categorized as major or minor.
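    Examples
    --------
    >>> validate_key('C# major')  # well-formed; returns silently
    >>> validate_key('X')         # an uncategorized key is also accepted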
Parameters
----------
key : str
Key to verify
"""
    if len(key.split()) != 2 \
            and not (len(key.split()) == 1 and key.lower() == 'x'):
raise ValueError("'{}' is not in the form '(key) (mode)' "
"or 'X'".format(key))
if key.lower() != 'x':
key, mode = key.split()
if key.lower() == 'x':
raise ValueError(
"Mode {} is invalid; 'X' (Uncategorized) "
"doesn't have mode".format(mode))
if key.lower() not in KEY_TO_SEMITONE:
raise ValueError(
"Key {} is invalid; should be e.g. D or C# or Eb or "
"X (Uncategorized)".format(key))
if mode not in ['major', 'minor', 'other']:
raise ValueError(
"Mode '{}' is invalid; must be 'major', 'minor' or 'other'"
.format(mode))
def validate(reference_key, estimated_key):
"""Checks that the input annotations to a metric are valid key strings and
throws helpful errors if not.
Parameters
----------
reference_key : str
Reference key string.
estimated_key : str
Estimated key string.
"""
for key in [reference_key, estimated_key]:
validate_key(key)
def split_key_string(key):
"""Splits a key string (of the form, e.g. ``'C# major'``), into a tuple of
    ``(key, mode)`` where ``key`` is an integer representing the semitone
distance from C.
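    Examples
    --------
    >>> split_key_string('C# major')  # gives (1, 'major')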
Parameters
----------
key : str
String representing a key.
Returns
-------
key : int
Number of semitones above C.
mode : str
String representing the mode.
"""
if key.lower() != 'x':
key, mode = key.split()
else:
mode = None
return KEY_TO_SEMITONE[key.lower()], mode
def weighted_score(reference_key, estimated_key):
"""Computes a heuristic score which is weighted according to the
relationship of the reference and estimated key, as follows:
+------------------------------------------------------+-------+
| Relationship | Score |
+------------------------------------------------------+-------+
| Same key and mode | 1.0 |
+------------------------------------------------------+-------+
| Estimated key is a perfect fifth above reference key | 0.5 |
+------------------------------------------------------+-------+
| Relative major/minor (same key signature) | 0.3 |
+------------------------------------------------------+-------+
| Parallel major/minor (same key) | 0.2 |
+------------------------------------------------------+-------+
| Other | 0.0 |
+------------------------------------------------------+-------+
Examples
--------
>>> ref_key = mir_eval.io.load_key('ref.txt')
>>> est_key = mir_eval.io.load_key('est.txt')
>>> score = mir_eval.key.weighted_score(ref_key, est_key)
Parameters
----------
reference_key : str
Reference key string.
estimated_key : str
Estimated key string.
Returns
-------
score : float
Score representing how closely related the keys are.
"""
validate(reference_key, estimated_key)
reference_key, reference_mode = split_key_string(reference_key)
estimated_key, estimated_mode = split_key_string(estimated_key)
# If keys are the same, return 1.
if reference_key == estimated_key and reference_mode == estimated_mode:
return 1.
# If reference or estimated key are x and they are not the same key
# then the result is 'Other'.
if reference_key is None or estimated_key is None:
return 0.
# If keys are the same mode and a perfect fifth (differ by 7 semitones)
if (estimated_mode == reference_mode and
(estimated_key - reference_key) % 12 == 7):
return 0.5
# Estimated key is relative minor of reference key (9 semitones)
if (estimated_mode != reference_mode == 'major' and
(estimated_key - reference_key) % 12 == 9):
return 0.3
# Estimated key is relative major of reference key (3 semitones)
if (estimated_mode != reference_mode == 'minor' and
(estimated_key - reference_key) % 12 == 3):
return 0.3
# If keys are in different modes and parallel (same key name)
if estimated_mode != reference_mode and reference_key == estimated_key:
return 0.2
# Otherwise return 0
return 0.
def evaluate(reference_key, estimated_key, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Examples
--------
>>> ref_key = mir_eval.io.load_key('reference.txt')
>>> est_key = mir_eval.io.load_key('estimated.txt')
>>> scores = mir_eval.key.evaluate(ref_key, est_key)
Parameters
----------
    reference_key : str
        Reference key string.
    estimated_key : str
        Estimated key string.
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all metrics
scores = collections.OrderedDict()
scores['Weighted Score'] = util.filter_kwargs(
weighted_score, reference_key, estimated_key)
return scores
mir_eval-0.7/mir_eval/melody.py 0000664 0000000 0000000 00000077551 14203260312 0016627 0 ustar 00root root 0000000 0000000 # CREATED:2014-03-07 by Justin Salamon
'''
Melody extraction algorithms aim to produce a sequence of frequency values
corresponding to the pitch of the dominant melody from a musical
recording. For evaluation, an estimated pitch series is evaluated against a
reference based on whether the voicing (melody present or not) and the pitch
is correct (within some tolerance).
For a detailed explanation of the measures please refer to:
J. Salamon, E. Gomez, D. P. W. Ellis and G. Richard, "Melody Extraction
from Polyphonic Music Signals: Approaches, Applications and Challenges",
IEEE Signal Processing Magazine, 31(2):118-134, Mar. 2014.
and:
G. E. Poliner, D. P. W. Ellis, A. F. Ehmann, E. Gomez, S.
Streich, and B. Ong. "Melody transcription from music audio:
Approaches and evaluation", IEEE Transactions on Audio, Speech, and
Language Processing, 15(4):1247-1256, 2007.
For an explanation of the generalized measures (using non-binary voicings),
please refer to:
R. Bittner and J. Bosch, "Generalized Metrics for Single-F0 Estimation
Evaluation", International Society for Music Information Retrieval
Conference (ISMIR), 2019.
Conventions
-----------
Melody annotations are assumed to be given in the format of a 1d array of
frequency values which are accompanied by a 1d array of times denoting when
each frequency value occurs. In a reference melody time series, a frequency
value of 0 denotes "unvoiced". In an estimated melody time series, unvoiced
frames can be indicated either by 0 Hz or by a negative Hz value - negative
values represent the algorithm's pitch estimate for frames it has determined as
unvoiced, in case they are in fact voiced.
Metrics are computed using a sequence of reference and estimated pitches in
cents and voicing arrays, both of which are sampled to the same
timebase. The function :func:`mir_eval.melody.to_cent_voicing` can be used to
convert a sequence of estimated and reference times and frequency values in Hz
to voicing arrays and frequency arrays in the format required by the
metric functions. By default, the convention is to resample the estimated
melody time series to the reference melody time series' timebase.
Metrics
-------
* :func:`mir_eval.melody.voicing_measures`: Voicing measures, including the
recall rate (proportion of frames labeled as melody frames in the reference
that are estimated as melody frames) and the false alarm
rate (proportion of frames labeled as non-melody in the reference that are
mistakenly estimated as melody frames)
* :func:`mir_eval.melody.raw_pitch_accuracy`: Raw Pitch Accuracy, which
computes the proportion of melody frames in the reference for which the
frequency is considered correct (i.e. within half a semitone of the reference
frequency)
* :func:`mir_eval.melody.raw_chroma_accuracy`: Raw Chroma Accuracy, where the
estimated and reference frequency sequences are mapped onto a single octave
before computing the raw pitch accuracy
* :func:`mir_eval.melody.overall_accuracy`: Overall Accuracy, which computes
the proportion of all frames correctly estimated by the algorithm, including
  whether non-melody frames were labeled by the algorithm as non-melody
'''
import numpy as np
import scipy.interpolate
import collections
import warnings
from . import util
def validate_voicing(ref_voicing, est_voicing):
"""Checks that voicing inputs to a metric are in the correct format.
Parameters
----------
ref_voicing : np.ndarray
Reference voicing array
est_voicing : np.ndarray
Estimated voicing array
"""
if ref_voicing.size == 0:
warnings.warn("Reference voicing array is empty.")
if est_voicing.size == 0:
warnings.warn("Estimated voicing array is empty.")
if ref_voicing.sum() == 0:
warnings.warn("Reference melody has no voiced frames.")
if est_voicing.sum() == 0:
warnings.warn("Estimated melody has no voiced frames.")
# Make sure they're the same length
if ref_voicing.shape[0] != est_voicing.shape[0]:
raise ValueError('Reference and estimated voicing arrays should '
'be the same length.')
for voicing in [ref_voicing, est_voicing]:
# Make sure voicing is between 0 and 1
if np.logical_or(voicing < 0, voicing > 1).any():
raise ValueError('Voicing arrays must be between 0 and 1.')
def validate(ref_voicing, ref_cent, est_voicing, est_cent):
"""Checks that voicing and frequency arrays are well-formed. To be used in
conjunction with :func:`mir_eval.melody.validate_voicing`
Parameters
----------
ref_voicing : np.ndarray
Reference voicing array
ref_cent : np.ndarray
Reference pitch sequence in cents
est_voicing : np.ndarray
Estimated voicing array
est_cent : np.ndarray
Estimate pitch sequence in cents
"""
if ref_cent.size == 0:
warnings.warn("Reference frequency array is empty.")
if est_cent.size == 0:
warnings.warn("Estimated frequency array is empty.")
# Make sure they're the same length
if ref_voicing.shape[0] != ref_cent.shape[0] or \
est_voicing.shape[0] != est_cent.shape[0] or \
ref_cent.shape[0] != est_cent.shape[0]:
raise ValueError('All voicing and frequency arrays must have the '
'same length.')
def hz2cents(freq_hz, base_frequency=10.0):
"""Convert an array of frequency values in Hz to cents.
0 values are left in place.
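    Examples
    --------
    >>> # 440 Hz lies 1200 * log2(440 / 10) cents above the 10 Hz base
    >>> freq_cent = hz2cents(np.array([0., 440.]))  # the 0 entry stays 0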
Parameters
----------
freq_hz : np.ndarray
Array of frequencies in Hz.
base_frequency : float
Base frequency for conversion.
(Default value = 10.0)
Returns
-------
freq_cent : np.ndarray
Array of frequencies in cents, relative to base_frequency
"""
freq_cent = np.zeros(freq_hz.shape[0])
freq_nonz_ind = np.flatnonzero(freq_hz)
normalized_frequency = np.abs(freq_hz[freq_nonz_ind]) / base_frequency
freq_cent[freq_nonz_ind] = 1200.0 * np.log2(normalized_frequency)
return freq_cent
def freq_to_voicing(frequencies, voicing=None):
"""Convert from an array of frequency values to frequency array +
voice/unvoiced array
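    Examples
    --------
    >>> # Negative values mean "unvoiced, but this would be the pitch"
    >>> freqs, voicing = freq_to_voicing(np.array([440., -220., 0.]))
    >>> # freqs is [440., 220., 0.] and voicing is [1., 0., 0.]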
Parameters
----------
frequencies : np.ndarray
Array of frequencies. A frequency <= 0 indicates "unvoiced".
voicing : np.ndarray
Array of voicing values.
        Default None, which means the voicing is inferred from ``frequencies``:
frames with frequency <= 0.0 are considered "unvoiced"
frames with frequency > 0.0 are considered "voiced"
If specified, `voicing` is used as the voicing array, but
frequencies with value 0 are forced to have 0 voicing.
Voicing inferred by negative frequency values is ignored.
Returns
-------
frequencies : np.ndarray
Array of frequencies, all >= 0.
voiced : np.ndarray
Array of voicings between 0 and 1, same length as frequencies,
which indicates voiced or unvoiced
"""
if voicing is not None:
voicing[frequencies == 0] = 0
else:
voicing = (frequencies > 0).astype(float)
return np.abs(frequencies), voicing
def constant_hop_timebase(hop, end_time):
"""Generates a time series from 0 to ``end_time`` with times spaced ``hop``
apart
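    Examples
    --------
    >>> times = constant_hop_timebase(0.01, 0.05)  # [0., 0.01, ..., 0.05]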
Parameters
----------
hop : float
Spacing of samples in the time series
end_time : float
Time series will span ``[0, end_time]``
Returns
-------
times : np.ndarray
Generated timebase
"""
# Compute new timebase. Rounding/linspace is to avoid float problems.
end_time = np.round(end_time, 10)
times = np.linspace(0, hop * int(np.floor(end_time / hop)),
int(np.floor(end_time / hop)) + 1)
times = np.round(times, 10)
return times
def resample_melody_series(times, frequencies, voicing,
times_new, kind='linear'):
"""Resamples frequency and voicing time series to a new timescale. Maintains
any zero ("unvoiced") values in frequencies.
If ``times`` and ``times_new`` are equivalent, no resampling will be
performed.
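    Examples
    --------
    >>> # Illustrative sketch; `times`, `frequencies` and `voicing` are
    >>> # assumed to be loaded already; resample onto a 5 ms grid
    >>> times_new = np.arange(0, times.max(), 0.005)
    >>> freqs_new, voicing_new = resample_melody_series(
    ...     times, frequencies, voicing, times_new)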
Parameters
----------
times : np.ndarray
Times of each frequency value
frequencies : np.ndarray
Array of frequency values, >= 0
voicing : np.ndarray
Array which indicates voiced or unvoiced. This array may be binary
or have continuous values between 0 and 1.
times_new : np.ndarray
Times to resample frequency and voicing sequences to
kind : str
kind parameter to pass to scipy.interpolate.interp1d.
(Default value = 'linear')
Returns
-------
frequencies_resampled : np.ndarray
Frequency array resampled to new timebase
voicing_resampled : np.ndarray
Voicing array resampled to new timebase
"""
# If the timebases are already the same, no need to interpolate
if times.shape == times_new.shape and np.allclose(times, times_new):
return frequencies, voicing
# Warn when the delta between the original times is not constant,
# unless times[0] == 0. and frequencies[0] == frequencies[1] (see logic at
# the beginning of to_cent_voicing)
if not (np.allclose(np.diff(times), np.diff(times).mean()) or
(np.allclose(np.diff(times[1:]), np.diff(times[1:]).mean()) and
frequencies[0] == frequencies[1])):
warnings.warn(
"Non-uniform timescale passed to resample_melody_series. Pitch "
"will be linearly interpolated, which will result in undesirable "
"behavior if silences are indicated by missing values. Silences "
"should be indicated by nonpositive frequency values.")
# Round to avoid floating point problems
times = np.round(times, 10)
times_new = np.round(times_new, 10)
# Add in an additional sample if we'll be asking for a time too large
if times_new.max() > times.max():
times = np.append(times, times_new.max())
frequencies = np.append(frequencies, 0)
voicing = np.append(voicing, 0)
# We need to fix zero transitions if interpolation is not zero or nearest
if kind != 'zero' and kind != 'nearest':
# Fill in zero values with the last reported frequency
# to avoid erroneous values when resampling
frequencies_held = np.array(frequencies)
for n, frequency in enumerate(frequencies[1:]):
if frequency == 0:
frequencies_held[n + 1] = frequencies_held[n]
# Linearly interpolate frequencies
frequencies_resampled = scipy.interpolate.interp1d(times,
frequencies_held,
kind)(times_new)
# Retain zeros
frequency_mask = scipy.interpolate.interp1d(times,
frequencies,
'zero')(times_new)
frequencies_resampled *= (frequency_mask != 0)
else:
frequencies_resampled = scipy.interpolate.interp1d(times,
frequencies,
kind)(times_new)
# Use nearest-neighbor for voicing if it was used for frequencies
# if voicing is not binary, use linear interpolation
is_binary_voicing = np.all(
np.logical_or(np.equal(voicing, 0), np.equal(voicing, 1)))
if kind == 'nearest' or (kind == 'linear' and not is_binary_voicing):
voicing_resampled = scipy.interpolate.interp1d(times,
voicing,
kind)(times_new)
# otherwise, always use zeroth order
else:
voicing_resampled = scipy.interpolate.interp1d(times,
voicing,
'zero')(times_new)
return frequencies_resampled, voicing_resampled
def to_cent_voicing(ref_time, ref_freq, est_time, est_freq,
est_voicing=None, ref_reward=None, base_frequency=10.,
hop=None, kind='linear'):
"""Converts reference and estimated time/frequency (Hz) annotations to sampled
frequency (cent)/voicing arrays.
A zero frequency indicates "unvoiced".
If est_voicing is not provided, a negative frequency indicates:
"Predicted as unvoiced, but if it's voiced,
this is the frequency estimate".
If it is provided, negative frequency values are ignored, and the voicing
from est_voicing is directly used.
Parameters
----------
ref_time : np.ndarray
Time of each reference frequency value
ref_freq : np.ndarray
Array of reference frequency values
est_time : np.ndarray
Time of each estimated frequency value
est_freq : np.ndarray
Array of estimated frequency values
est_voicing : np.ndarray
Estimate voicing confidence.
Default None, which means the voicing is inferred from est_freq:
frames with frequency <= 0.0 are considered "unvoiced"
frames with frequency > 0.0 are considered "voiced"
ref_reward : np.ndarray
Reference voicing reward.
Default None, which means all frames are weighted equally.
base_frequency : float
Base frequency in Hz for conversion to cents
(Default value = 10.)
hop : float
Hop size, in seconds, to resample,
default None which means use ref_time
kind : str
kind parameter to pass to scipy.interpolate.interp1d.
(Default value = 'linear')
Returns
-------
ref_voicing : np.ndarray
Resampled reference voicing array
ref_cent : np.ndarray
Resampled reference frequency (cent) array
est_voicing : np.ndarray
Resampled estimated voicing array
est_cent : np.ndarray
Resampled estimated frequency (cent) array
"""
# Check if missing sample at time 0 and if so add one
if ref_time[0] > 0:
ref_time = np.insert(ref_time, 0, 0)
ref_freq = np.insert(ref_freq, 0, ref_freq[0])
if ref_reward is not None:
ref_reward = np.insert(ref_reward, 0, ref_reward[0])
if est_time[0] > 0:
est_time = np.insert(est_time, 0, 0)
est_freq = np.insert(est_freq, 0, est_freq[0])
if est_voicing is not None:
est_voicing = np.insert(est_voicing, 0, est_voicing[0])
# Get separated frequency array and voicing array
ref_freq, ref_voicing = freq_to_voicing(ref_freq, ref_reward)
est_freq, est_voicing = freq_to_voicing(est_freq, est_voicing)
# convert both sequences to cents
ref_cent = hz2cents(ref_freq, base_frequency)
est_cent = hz2cents(est_freq, base_frequency)
# If we received a hop, use it to resample both
if hop is not None:
# Resample to common time base
ref_cent, ref_voicing = resample_melody_series(
ref_time, ref_cent, ref_voicing,
constant_hop_timebase(hop, ref_time.max()), kind)
est_cent, est_voicing = resample_melody_series(
est_time, est_cent, est_voicing,
constant_hop_timebase(hop, est_time.max()), kind)
# Otherwise, only resample estimated to the reference time base
else:
est_cent, est_voicing = resample_melody_series(
est_time, est_cent, est_voicing, ref_time, kind)
# ensure the estimated sequence is the same length as the reference
len_diff = ref_cent.shape[0] - est_cent.shape[0]
if len_diff >= 0:
est_cent = np.append(est_cent, np.zeros(len_diff))
est_voicing = np.append(est_voicing, np.zeros(len_diff))
else:
est_cent = est_cent[:ref_cent.shape[0]]
est_voicing = est_voicing[:ref_voicing.shape[0]]
return (ref_voicing, ref_cent, est_voicing, est_cent)
def voicing_recall(ref_voicing, est_voicing):
"""Compute the voicing recall given two voicing
indicator sequences, one as reference (truth) and the other as the estimate
(prediction). The sequences must be of the same length.
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_time_series('est.txt')
>>> (ref_v, ref_c,
... est_v, est_c) = mir_eval.melody.to_cent_voicing(ref_time,
... ref_freq,
... est_time,
... est_freq)
>>> recall = mir_eval.melody.voicing_recall(ref_v, est_v)
Parameters
----------
ref_voicing : np.ndarray
Reference boolean voicing array
est_voicing : np.ndarray
Estimated boolean voicing array
Returns
-------
vx_recall : float
Voicing recall rate, the fraction of voiced frames in ref
indicated as voiced in est
"""
if ref_voicing.size == 0 or est_voicing.size == 0:
return 0.
ref_indicator = (ref_voicing > 0).astype(float)
if np.sum(ref_indicator) == 0:
return 1
return np.sum(est_voicing * ref_indicator) / np.sum(ref_indicator)
def voicing_false_alarm(ref_voicing, est_voicing):
"""Compute the voicing false alarm rates given two voicing
indicator sequences, one as reference (truth) and the other as the estimate
(prediction). The sequences must be of the same length.
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_time_series('est.txt')
>>> (ref_v, ref_c,
... est_v, est_c) = mir_eval.melody.to_cent_voicing(ref_time,
... ref_freq,
... est_time,
... est_freq)
>>> false_alarm = mir_eval.melody.voicing_false_alarm(ref_v, est_v)
Parameters
----------
ref_voicing : np.ndarray
Reference boolean voicing array
est_voicing : np.ndarray
Estimated boolean voicing array
Returns
-------
vx_false_alarm : float
Voicing false alarm rate, the fraction of unvoiced frames in ref
indicated as voiced in est
"""
if ref_voicing.size == 0 or est_voicing.size == 0:
return 0.
ref_indicator = (ref_voicing == 0).astype(float)
if np.sum(ref_indicator) == 0:
return 0
return np.sum(est_voicing * ref_indicator) / np.sum(ref_indicator)
def voicing_measures(ref_voicing, est_voicing):
"""Compute the voicing recall and false alarm rates given two voicing
indicator sequences, one as reference (truth) and the other as the estimate
(prediction). The sequences must be of the same length.
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_time_series('est.txt')
>>> (ref_v, ref_c,
... est_v, est_c) = mir_eval.melody.to_cent_voicing(ref_time,
... ref_freq,
... est_time,
... est_freq)
>>> recall, false_alarm = mir_eval.melody.voicing_measures(ref_v,
... est_v)
Parameters
----------
ref_voicing : np.ndarray
Reference boolean voicing array
est_voicing : np.ndarray
Estimated boolean voicing array
Returns
-------
vx_recall : float
Voicing recall rate, the fraction of voiced frames in ref
indicated as voiced in est
vx_false_alarm : float
Voicing false alarm rate, the fraction of unvoiced frames in ref
indicated as voiced in est
"""
validate_voicing(ref_voicing, est_voicing)
vx_recall = voicing_recall(ref_voicing, est_voicing)
vx_false_alm = voicing_false_alarm(ref_voicing, est_voicing)
return vx_recall, vx_false_alm
def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent,
cent_tolerance=50):
"""Compute the raw pitch accuracy given two pitch (frequency) sequences in
cents and matching voicing indicator sequences. The first pitch and voicing
arrays are treated as the reference (truth), and the second two as the
estimate (prediction). All 4 sequences must be of the same length.
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_time_series('est.txt')
>>> (ref_v, ref_c,
... est_v, est_c) = mir_eval.melody.to_cent_voicing(ref_time,
... ref_freq,
... est_time,
... est_freq)
>>> raw_pitch = mir_eval.melody.raw_pitch_accuracy(ref_v, ref_c,
... est_v, est_c)
Parameters
----------
ref_voicing : np.ndarray
Reference voicing array. When this array is non-binary, it is treated
as a 'reference reward', as in (Bittner & Bosch, 2019)
ref_cent : np.ndarray
Reference pitch sequence in cents
est_voicing : np.ndarray
Estimated voicing array
est_cent : np.ndarray
Estimate pitch sequence in cents
cent_tolerance : float
Maximum absolute deviation in cents for a frequency value to be
considered correct
(Default value = 50)
Returns
-------
raw_pitch : float
Raw pitch accuracy, the fraction of voiced frames in ref_cent for
        which est_cent provides a correct frequency value
(within cent_tolerance cents).
"""
validate_voicing(ref_voicing, est_voicing)
validate(ref_voicing, ref_cent, est_voicing, est_cent)
# When input arrays are empty, return 0 by special case
# If there are no voiced frames in reference, metric is 0
if ref_voicing.size == 0 or ref_voicing.sum() == 0 \
or ref_cent.size == 0 or est_cent.size == 0:
return 0.
# Raw pitch = the number of voiced frames in the reference for which the
# estimate provides a correct frequency value (within cent_tolerance cents)
# NB: voicing estimation is ignored in this measure
nonzero_freqs = np.logical_and(est_cent != 0, ref_cent != 0)
if sum(nonzero_freqs) == 0:
return 0.
freq_diff_cents = np.abs(ref_cent - est_cent)[nonzero_freqs]
correct_frequencies = freq_diff_cents < cent_tolerance
rpa = (
np.sum(ref_voicing[nonzero_freqs] * correct_frequencies) /
np.sum(ref_voicing)
)
return rpa
def raw_chroma_accuracy(ref_voicing, ref_cent, est_voicing, est_cent,
cent_tolerance=50):
"""Compute the raw chroma accuracy given two pitch (frequency) sequences
in cents and matching voicing indicator sequences. The first pitch and
voicing arrays are treated as the reference (truth), and the second two as
the estimate (prediction). All 4 sequences must be of the same length.
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_time_series('est.txt')
>>> (ref_v, ref_c,
... est_v, est_c) = mir_eval.melody.to_cent_voicing(ref_time,
... ref_freq,
... est_time,
... est_freq)
>>> raw_chroma = mir_eval.melody.raw_chroma_accuracy(ref_v, ref_c,
... est_v, est_c)
Parameters
----------
ref_voicing : np.ndarray
Reference voicing array. When this array is non-binary, it is treated
as a 'reference reward', as in (Bittner & Bosch, 2019)
ref_cent : np.ndarray
Reference pitch sequence in cents
est_voicing : np.ndarray
Estimated voicing array
est_cent : np.ndarray
Estimate pitch sequence in cents
cent_tolerance : float
Maximum absolute deviation in cents for a frequency value to be
considered correct
(Default value = 50)
Returns
-------
raw_chroma : float
Raw chroma accuracy, the fraction of voiced frames in ref_cent for
        which est_cent provides a correct frequency value (within
cent_tolerance cents), ignoring octave errors
"""
validate_voicing(ref_voicing, est_voicing)
validate(ref_voicing, ref_cent, est_voicing, est_cent)
# When input arrays are empty, return 0 by special case
# If there are no voiced frames in reference, metric is 0
if ref_voicing.size == 0 or ref_voicing.sum() == 0 \
or ref_cent.size == 0 or est_cent.size == 0:
return 0.
    # Raw chroma = same as raw pitch except that octave errors are ignored.
nonzero_freqs = np.logical_and(est_cent != 0, ref_cent != 0)
if sum(nonzero_freqs) == 0:
return 0.
freq_diff_cents = np.abs(ref_cent - est_cent)[nonzero_freqs]
octave = 1200.0 * np.floor(freq_diff_cents / 1200 + 0.5)
correct_chroma = np.abs(freq_diff_cents - octave) < cent_tolerance
rca = (
np.sum(ref_voicing[nonzero_freqs] * correct_chroma) /
np.sum(ref_voicing)
)
return rca
def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent,
cent_tolerance=50):
"""Compute the overall accuracy given two pitch (frequency) sequences
in cents and matching voicing indicator sequences. The first pitch and
voicing arrays are treated as the reference (truth), and the second two
as the estimate (prediction). All 4 sequences must be of the same length.
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_time_series('est.txt')
>>> (ref_v, ref_c,
... est_v, est_c) = mir_eval.melody.to_cent_voicing(ref_time,
... ref_freq,
... est_time,
... est_freq)
>>> overall_accuracy = mir_eval.melody.overall_accuracy(ref_v, ref_c,
... est_v, est_c)
Parameters
----------
ref_voicing : np.ndarray
Reference voicing array. When this array is non-binary, it is treated
as a 'reference reward', as in (Bittner & Bosch, 2019)
ref_cent : np.ndarray
Reference pitch sequence in cents
est_voicing : np.ndarray
Estimated voicing array
est_cent : np.ndarray
Estimate pitch sequence in cents
cent_tolerance : float
Maximum absolute deviation in cents for a frequency value to be
considered correct
(Default value = 50)
Returns
-------
overall_accuracy : float
        Overall accuracy, the total fraction of correctly estimated frames,
        where a voiced frame is correct if its frequency is within
        cent_tolerance cents and an unvoiced frame is correct if it is
        estimated as unvoiced.
"""
validate_voicing(ref_voicing, est_voicing)
validate(ref_voicing, ref_cent, est_voicing, est_cent)
# When input arrays are empty, return 0 by special case
if ref_voicing.size == 0 or est_voicing.size == 0 \
or ref_cent.size == 0 or est_cent.size == 0:
return 0.
nonzero_freqs = np.logical_and(est_cent != 0, ref_cent != 0)
freq_diff_cents = np.abs(ref_cent - est_cent)[nonzero_freqs]
correct_frequencies = freq_diff_cents < cent_tolerance
ref_binary = (ref_voicing > 0).astype(float)
n_frames = float(len(ref_voicing))
if np.sum(ref_voicing) == 0:
ratio = 0.0
else:
ratio = (np.sum(ref_binary) / np.sum(ref_voicing))
accuracy = (
(
ratio * np.sum(ref_voicing[nonzero_freqs] *
est_voicing[nonzero_freqs] *
correct_frequencies)
) +
np.sum((1.0 - ref_binary) * (1.0 - est_voicing))
) / n_frames
return accuracy
def evaluate(ref_time, ref_freq, est_time, est_freq,
est_voicing=None, ref_reward=None, **kwargs):
"""Evaluate two melody (predominant f0) transcriptions, where the first is
treated as the reference (ground truth) and the second as the estimate to
be evaluated (prediction).
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_time_series('est.txt')
>>> scores = mir_eval.melody.evaluate(ref_time, ref_freq,
... est_time, est_freq)
Parameters
----------
ref_time : np.ndarray
Time of each reference frequency value
ref_freq : np.ndarray
Array of reference frequency values
est_time : np.ndarray
Time of each estimated frequency value
est_freq : np.ndarray
Array of estimated frequency values
est_voicing : np.ndarray
Estimate voicing confidence.
Default None, which means the voicing is inferred from est_freq:
frames with frequency <= 0.0 are considered "unvoiced"
frames with frequency > 0.0 are considered "voiced"
ref_reward : np.ndarray
Reference pitch estimation reward.
Default None, which means all frames are weighted equally.
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
References
----------
.. [#] J. Salamon, E. Gomez, D. P. W. Ellis and G. Richard, "Melody
Extraction from Polyphonic Music Signals: Approaches, Applications
and Challenges", IEEE Signal Processing Magazine, 31(2):118-134,
Mar. 2014.
.. [#] G. E. Poliner, D. P. W. Ellis, A. F. Ehmann, E. Gomez, S.
Streich, and B. Ong. "Melody transcription from music audio:
Approaches and evaluation", IEEE Transactions on Audio, Speech, and
Language Processing, 15(4):1247-1256, 2007.
.. [#] R. Bittner and J. Bosch, "Generalized Metrics for Single-F0
Estimation Evaluation", International Society for Music Information
Retrieval Conference (ISMIR), 2019.
"""
# Convert to reference/estimated voicing/frequency (cent) arrays
(ref_voicing, ref_cent,
est_voicing, est_cent) = util.filter_kwargs(
to_cent_voicing, ref_time, ref_freq, est_time, est_freq,
est_voicing, ref_reward, **kwargs)
# Compute metrics
scores = collections.OrderedDict()
scores['Voicing Recall'] = util.filter_kwargs(voicing_recall,
ref_voicing,
est_voicing, **kwargs)
scores['Voicing False Alarm'] = util.filter_kwargs(voicing_false_alarm,
ref_voicing,
est_voicing, **kwargs)
scores['Raw Pitch Accuracy'] = util.filter_kwargs(raw_pitch_accuracy,
ref_voicing, ref_cent,
est_voicing, est_cent,
**kwargs)
scores['Raw Chroma Accuracy'] = util.filter_kwargs(raw_chroma_accuracy,
ref_voicing, ref_cent,
est_voicing, est_cent,
**kwargs)
scores['Overall Accuracy'] = util.filter_kwargs(overall_accuracy,
ref_voicing, ref_cent,
est_voicing, est_cent,
**kwargs)
return scores
mir_eval-0.7/mir_eval/multipitch.py 0000664 0000000 0000000 00000040641 14203260312 0017506 0 ustar 00root root 0000000 0000000 '''
The goal of multiple f0 (multipitch) estimation and tracking is to identify
all of the active fundamental frequencies in each time frame in a complex music
signal.
Conventions
-----------
Multipitch estimates are represented by a timebase and a corresponding list
of arrays of frequency estimates. Frequency estimates may have any number of
frequency values, including 0 (represented by an empty array). Time values are
in units of seconds and frequency estimates are in units of Hz.
The timebase of the estimate time series should ideally match the timebase of
the reference time series, but if this is not the case, the estimate time
series is resampled using a nearest neighbor interpolation to match the
reference. Time values in the estimate time series that are outside of the range
of the reference time series are given null (empty array) frequencies.
By default, a frequency is "correct" if it is within 0.5 semitones of a
reference frequency. Frequency values are compared by first mapping them to
log-2 semitone space, where the distance between semitones is constant.
Chroma-wrapped frequency values are computed by taking the log-2 frequency
values modulo 12 to map them down to a single octave. A chroma-wrapped
frequency estimate is correct if its single-octave value is within 0.5
semitones of the single-octave reference frequency.
The metrics are based on those described in
[#poliner2007]_ and [#bay2009]_.
Metrics
-------
* :func:`mir_eval.multipitch.metrics`: Precision, Recall, Accuracy,
Substitution, Miss, False Alarm, and Total Error scores based both on raw
frequency values and values mapped to a single octave (chroma).
References
----------
.. [#poliner2007] G. E. Poliner and D. P. W. Ellis, "A Discriminative
    Model for Polyphonic Piano Transcription", EURASIP Journal on Advances in
Signal Processing, 2007(1):154-163, Jan. 2007.
.. [#bay2009] Bay, M., Ehmann, A. F., & Downie, J. S. (2009). Evaluation of
Multiple-F0 Estimation and Tracking Systems. In ISMIR (pp. 315-320).
'''
import numpy as np
import collections
import scipy.interpolate
from . import util
import warnings
MAX_TIME = 30000. # The maximum allowable time stamp (seconds)
MAX_FREQ = 5000. # The maximum allowable frequency (Hz)
MIN_FREQ = 20. # The minimum allowable frequency (Hz)
def validate(ref_time, ref_freqs, est_time, est_freqs):
"""Checks that the time and frequency inputs are well-formed.
Parameters
----------
ref_time : np.ndarray
reference time stamps in seconds
ref_freqs : list of np.ndarray
reference frequencies in Hz
est_time : np.ndarray
estimate time stamps in seconds
est_freqs : list of np.ndarray
estimated frequencies in Hz
"""
util.validate_events(ref_time, max_time=MAX_TIME)
util.validate_events(est_time, max_time=MAX_TIME)
if ref_time.size == 0:
warnings.warn("Reference times are empty.")
if ref_time.ndim != 1:
raise ValueError("Reference times have invalid dimension")
if len(ref_freqs) == 0:
warnings.warn("Reference frequencies are empty.")
if est_time.size == 0:
warnings.warn("Estimated times are empty.")
if est_time.ndim != 1:
raise ValueError("Estimated times have invalid dimension")
if len(est_freqs) == 0:
warnings.warn("Estimated frequencies are empty.")
if ref_time.size != len(ref_freqs):
raise ValueError('Reference times and frequencies have unequal '
'lengths.')
if est_time.size != len(est_freqs):
raise ValueError('Estimate times and frequencies have unequal '
'lengths.')
for freq in ref_freqs:
util.validate_frequencies(freq, max_freq=MAX_FREQ, min_freq=MIN_FREQ,
allow_negatives=False)
for freq in est_freqs:
util.validate_frequencies(freq, max_freq=MAX_FREQ, min_freq=MIN_FREQ,
allow_negatives=False)
def resample_multipitch(times, frequencies, target_times):
"""Resamples multipitch time series to a new timescale. Values in
``target_times`` outside the range of ``times`` return no pitch estimate.
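    Examples
    --------
    >>> # Illustrative sketch; `est_time`, `est_freqs` and `ref_time` are
    >>> # assumed loaded; snap estimate frames onto the reference timebase
    >>> est_freqs_resampled = resample_multipitch(est_time, est_freqs,
    ...                                           ref_time)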
Parameters
----------
times : np.ndarray
Array of time stamps
frequencies : list of np.ndarray
List of np.ndarrays of frequency values
target_times : np.ndarray
Array of target time stamps
Returns
-------
frequencies_resampled : list of numpy arrays
Frequency list of lists resampled to new timebase
"""
if target_times.size == 0:
return []
if times.size == 0:
return [np.array([])]*len(target_times)
n_times = len(frequencies)
# scipy's interpolate doesn't handle ragged arrays. Instead, we interpolate
# the frequency index and then map back to the frequency values.
# This only works because we're using a nearest neighbor interpolator!
frequency_index = np.arange(0, n_times)
# times are already ordered so assume_sorted=True for efficiency
# since we're interpolating the index, fill_value is set to the first index
# that is out of range. We handle this in the next line.
new_frequency_index = scipy.interpolate.interp1d(
times, frequency_index, kind='nearest', bounds_error=False,
assume_sorted=True, fill_value=n_times)(target_times)
# create array of frequencies plus additional empty element at the end for
# target time stamps that are out of the interpolation range
freq_vals = frequencies + [np.array([])]
# map interpolated indices back to frequency values
frequencies_resampled = [
freq_vals[i] for i in new_frequency_index.astype(int)]
return frequencies_resampled
def frequencies_to_midi(frequencies, ref_frequency=440.0):
"""Converts frequencies to continuous MIDI values.
Parameters
----------
frequencies : list of np.ndarray
Original frequency values
ref_frequency : float
reference frequency in Hz.
Returns
-------
frequencies_midi : list of np.ndarray
Continuous MIDI frequency values.
"""
return [69.0 + 12.0*np.log2(freqs/ref_frequency) for freqs in frequencies]
def midi_to_chroma(frequencies_midi):
"""Wrap MIDI frequencies to a single octave (chroma).
Parameters
----------
frequencies_midi : list of np.ndarray
Continuous MIDI note frequency values.
Returns
-------
frequencies_chroma : list of np.ndarray
Midi values wrapped to one octave.
"""
return [np.mod(freqs, 12) for freqs in frequencies_midi]
def compute_num_freqs(frequencies):
"""Computes the number of frequencies for each time point.
Parameters
----------
frequencies : list of np.ndarray
Frequency values
Returns
-------
num_freqs : np.ndarray
Number of frequencies at each time point.
"""
return np.array([f.size for f in frequencies])
def compute_num_true_positives(ref_freqs, est_freqs, window=0.5, chroma=False):
"""Compute the number of true positives in an estimate given a reference.
    A frequency is correct if it is within a quartertone (0.5 semitones) of
    the reference frequency.
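    Examples
    --------
    >>> # One estimate within half a semitone of the single reference
    >>> tp = compute_num_true_positives([np.array([60.0])],
    ...                                 [np.array([60.2])])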
Parameters
----------
ref_freqs : list of np.ndarray
reference frequencies (MIDI)
est_freqs : list of np.ndarray
estimated frequencies (MIDI)
window : float
Window size, in semitones
chroma : bool
        If True, computes distances modulo 12 (one octave).
        If True, ``ref_freqs`` and ``est_freqs`` should be wrapped modulo 12.
Returns
-------
true_positives : np.ndarray
Array the same length as ref_freqs containing the number of true
positives.
"""
n_frames = len(ref_freqs)
true_positives = np.zeros((n_frames, ))
for i, (ref_frame, est_frame) in enumerate(zip(ref_freqs, est_freqs)):
if chroma:
# match chroma-wrapped frequency events
matching = util.match_events(
ref_frame, est_frame, window,
distance=util._outer_distance_mod_n)
else:
# match frequency events within tolerance window in semitones
matching = util.match_events(ref_frame, est_frame, window)
true_positives[i] = len(matching)
return true_positives
def compute_accuracy(true_positives, n_ref, n_est):
"""Compute accuracy metrics.
Parameters
----------
true_positives : np.ndarray
Array containing the number of true positives at each time point.
n_ref : np.ndarray
Array containing the number of reference frequencies at each time
point.
n_est : np.ndarray
Array containing the number of estimate frequencies at each time point.
Returns
-------
precision : float
``sum(true_positives)/sum(n_est)``
recall : float
``sum(true_positives)/sum(n_ref)``
acc : float
``sum(true_positives)/sum(n_est + n_ref - true_positives)``
"""
true_positive_sum = float(true_positives.sum())
n_est_sum = n_est.sum()
if n_est_sum > 0:
precision = true_positive_sum/n_est.sum()
else:
warnings.warn("Estimate frequencies are all empty.")
precision = 0.0
n_ref_sum = n_ref.sum()
if n_ref_sum > 0:
recall = true_positive_sum/n_ref.sum()
else:
warnings.warn("Reference frequencies are all empty.")
recall = 0.0
acc_denom = (n_est + n_ref - true_positives).sum()
if acc_denom > 0:
acc = true_positive_sum/acc_denom
else:
acc = 0.0
return precision, recall, acc
def compute_err_score(true_positives, n_ref, n_est):
"""Compute error score metrics.
Parameters
----------
true_positives : np.ndarray
Array containing the number of true positives at each time point.
n_ref : np.ndarray
Array containing the number of reference frequencies at each time
point.
n_est : np.ndarray
Array containing the number of estimate frequencies at each time point.
Returns
-------
e_sub : float
Substitution error
e_miss : float
Miss error
e_fa : float
False alarm error
e_tot : float
Total error
"""
n_ref_sum = float(n_ref.sum())
if n_ref_sum == 0:
warnings.warn("Reference frequencies are all empty.")
return 0., 0., 0., 0.
# Substitution error
e_sub = (np.min([n_ref, n_est], axis=0) - true_positives).sum()/n_ref_sum
# compute the max of (n_ref - n_est) and 0
e_miss_numerator = n_ref - n_est
e_miss_numerator[e_miss_numerator < 0] = 0
# Miss error
e_miss = e_miss_numerator.sum()/n_ref_sum
# compute the max of (n_est - n_ref) and 0
e_fa_numerator = n_est - n_ref
e_fa_numerator[e_fa_numerator < 0] = 0
# False alarm error
e_fa = e_fa_numerator.sum()/n_ref_sum
# total error
e_tot = (np.max([n_ref, n_est], axis=0) - true_positives).sum()/n_ref_sum
return e_sub, e_miss, e_fa, e_tot
def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs):
"""Compute multipitch metrics. All metrics are computed at the 'macro' level
such that the frame true positive/false positive/false negative rates are
summed across time and the metrics are computed on the combined values.
Examples
--------
>>> ref_time, ref_freqs = mir_eval.io.load_ragged_time_series(
... 'reference.txt')
>>> est_time, est_freqs = mir_eval.io.load_ragged_time_series(
... 'estimated.txt')
    >>> metrics_tuple = mir_eval.multipitch.metrics(
... ref_time, ref_freqs, est_time, est_freqs)
Parameters
----------
ref_time : np.ndarray
Time of each reference frequency value
ref_freqs : list of np.ndarray
List of np.ndarrays of reference frequency values
est_time : np.ndarray
Time of each estimated frequency value
est_freqs : list of np.ndarray
List of np.ndarrays of estimate frequency values
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
precision : float
Precision (TP/(TP + FP))
recall : float
Recall (TP/(TP + FN))
accuracy : float
Accuracy (TP/(TP + FP + FN))
e_sub : float
Substitution error
e_miss : float
Miss error
e_fa : float
False alarm error
e_tot : float
Total error
precision_chroma : float
Chroma precision
recall_chroma : float
Chroma recall
accuracy_chroma : float
Chroma accuracy
e_sub_chroma : float
Chroma substitution error
e_miss_chroma : float
Chroma miss error
e_fa_chroma : float
Chroma false alarm error
e_tot_chroma : float
Chroma total error
"""
validate(ref_time, ref_freqs, est_time, est_freqs)
    # resample est_freqs if est_time is different from ref_time
if est_time.size != ref_time.size or not np.allclose(est_time, ref_time):
warnings.warn("Estimate times not equal to reference times. "
"Resampling to common time base.")
est_freqs = resample_multipitch(est_time, est_freqs, ref_time)
# convert frequencies from Hz to continuous midi note number
ref_freqs_midi = frequencies_to_midi(ref_freqs)
est_freqs_midi = frequencies_to_midi(est_freqs)
# compute chroma wrapped midi number
ref_freqs_chroma = midi_to_chroma(ref_freqs_midi)
est_freqs_chroma = midi_to_chroma(est_freqs_midi)
    # count the number of occurrences
n_ref = compute_num_freqs(ref_freqs_midi)
n_est = compute_num_freqs(est_freqs_midi)
# compute the number of true positives
true_positives = util.filter_kwargs(
compute_num_true_positives, ref_freqs_midi, est_freqs_midi, **kwargs)
# compute the number of true positives ignoring octave mistakes
true_positives_chroma = util.filter_kwargs(
compute_num_true_positives, ref_freqs_chroma,
est_freqs_chroma, chroma=True, **kwargs)
# compute accuracy metrics
precision, recall, accuracy = compute_accuracy(
true_positives, n_ref, n_est)
# compute error metrics
e_sub, e_miss, e_fa, e_tot = compute_err_score(
true_positives, n_ref, n_est)
# compute accuracy metrics ignoring octave mistakes
precision_chroma, recall_chroma, accuracy_chroma = compute_accuracy(
true_positives_chroma, n_ref, n_est)
# compute error metrics ignoring octave mistakes
e_sub_chroma, e_miss_chroma, e_fa_chroma, e_tot_chroma = compute_err_score(
true_positives_chroma, n_ref, n_est)
return (precision, recall, accuracy, e_sub, e_miss, e_fa, e_tot,
precision_chroma, recall_chroma, accuracy_chroma, e_sub_chroma,
e_miss_chroma, e_fa_chroma, e_tot_chroma)
def evaluate(ref_time, ref_freqs, est_time, est_freqs, **kwargs):
"""Evaluate two multipitch (multi-f0) transcriptions, where the first is
treated as the reference (ground truth) and the second as the estimate to
be evaluated (prediction).
Examples
--------
>>> ref_time, ref_freq = mir_eval.io.load_ragged_time_series('ref.txt')
>>> est_time, est_freq = mir_eval.io.load_ragged_time_series('est.txt')
>>> scores = mir_eval.multipitch.evaluate(ref_time, ref_freq,
... est_time, est_freq)
Parameters
----------
ref_time : np.ndarray
Time of each reference frequency value
ref_freqs : list of np.ndarray
List of np.ndarrays of reference frequency values
est_time : np.ndarray
Time of each estimated frequency value
est_freqs : list of np.ndarray
List of np.ndarrays of estimate frequency values
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
scores = collections.OrderedDict()
(scores['Precision'],
scores['Recall'],
scores['Accuracy'],
scores['Substitution Error'],
scores['Miss Error'],
scores['False Alarm Error'],
scores['Total Error'],
scores['Chroma Precision'],
scores['Chroma Recall'],
scores['Chroma Accuracy'],
scores['Chroma Substitution Error'],
scores['Chroma Miss Error'],
scores['Chroma False Alarm Error'],
scores['Chroma Total Error']) = util.filter_kwargs(
metrics, ref_time, ref_freqs, est_time, est_freqs, **kwargs)
return scores
mir_eval-0.7/mir_eval/onset.py 0000664 0000000 0000000 00000010455 14203260312 0016454 0 ustar 00root root 0000000 0000000 '''
The goal of an onset detection algorithm is to automatically determine when
notes are played in a piece of music. The primary method used to evaluate
onset detectors is to first determine which estimated onsets are "correct",
where correctness is defined as being within a small window of a reference
onset.
Based in part on this script:
https://github.com/CPJKU/onset_detection/blob/master/onset_evaluation.py
Conventions
-----------
Onsets should be provided in the form of a 1-dimensional array of onset
times in seconds in increasing order.
Metrics
-------
* :func:`mir_eval.onset.f_measure`: Precision, Recall, and F-measure scores
  based on the number of estimated onsets which are sufficiently close to
reference onsets.
'''
import collections
from . import util
import warnings
# The maximum allowable onset time
MAX_TIME = 30000.
def validate(reference_onsets, estimated_onsets):
"""Checks that the input annotations to a metric look like valid onset time
arrays, and throws helpful errors if not.
Parameters
----------
reference_onsets : np.ndarray
reference onset locations, in seconds
estimated_onsets : np.ndarray
estimated onset locations, in seconds
"""
# If reference or estimated onsets are empty, warn because metric will be 0
if reference_onsets.size == 0:
warnings.warn("Reference onsets are empty.")
if estimated_onsets.size == 0:
warnings.warn("Estimated onsets are empty.")
for onsets in [reference_onsets, estimated_onsets]:
util.validate_events(onsets, MAX_TIME)
def f_measure(reference_onsets, estimated_onsets, window=.05):
"""Compute the F-measure of correct vs incorrectly predicted onsets.
"Corectness" is determined over a small window.
Examples
--------
>>> reference_onsets = mir_eval.io.load_events('reference.txt')
>>> estimated_onsets = mir_eval.io.load_events('estimated.txt')
>>> F, P, R = mir_eval.onset.f_measure(reference_onsets,
... estimated_onsets)
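    A purely numeric sketch with synthetic, hypothetical onset times
    (assumes ``numpy`` is imported as ``np``):
    >>> F, P, R = mir_eval.onset.f_measure(np.array([0.0, 1.0, 2.0]),
    ...                                    np.array([0.01, 1.10, 2.02]))
    >>> # 0.01 and 2.02 fall within the default 50 ms window of a
    >>> # reference onset, but 1.10 does not, so F = P = R = 2/3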
Parameters
----------
reference_onsets : np.ndarray
reference onset locations, in seconds
estimated_onsets : np.ndarray
estimated onset locations, in seconds
window : float
Window size, in seconds
(Default value = .05)
Returns
-------
f_measure : float
2*precision*recall/(precision + recall)
precision : float
(# true positives)/(# true positives + # false positives)
recall : float
(# true positives)/(# true positives + # false negatives)
"""
validate(reference_onsets, estimated_onsets)
# If either list is empty, return 0s
if reference_onsets.size == 0 or estimated_onsets.size == 0:
return 0., 0., 0.
# Compute the best-case matching between reference and estimated onset
# locations
matching = util.match_events(reference_onsets, estimated_onsets, window)
precision = float(len(matching))/len(estimated_onsets)
recall = float(len(matching))/len(reference_onsets)
# Compute F-measure and return all statistics
return util.f_measure(precision, recall), precision, recall
def evaluate(reference_onsets, estimated_onsets, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Examples
--------
>>> reference_onsets = mir_eval.io.load_events('reference.txt')
>>> estimated_onsets = mir_eval.io.load_events('estimated.txt')
>>> scores = mir_eval.onset.evaluate(reference_onsets,
... estimated_onsets)
Parameters
----------
reference_onsets : np.ndarray
reference onset locations, in seconds
estimated_onsets : np.ndarray
estimated onset locations, in seconds
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all metrics
scores = collections.OrderedDict()
(scores['F-measure'],
scores['Precision'],
scores['Recall']) = util.filter_kwargs(f_measure, reference_onsets,
estimated_onsets, **kwargs)
return scores
mir_eval-0.7/mir_eval/pattern.py 0000664 0000000 0000000 00000055770 14203260312 0017012 0 ustar 00root root 0000000 0000000 """
Pattern discovery involves the identification of musical patterns (i.e. short
fragments or melodic ideas that repeat at least twice) both from audio and
symbolic representations. The metrics used to evaluate pattern discovery
systems attempt to quantify the ability of the algorithm to not only determine
the present patterns in a piece, but also to find all of their occurrences.
Based on the methods described here:
T. Collins. MIREX task: Discovery of repeated themes & sections.
http://www.music-ir.org/mirex/wiki/2013:Discovery_of_Repeated_Themes_&_Sections,
2013.
Conventions
-----------
The input format can be automatically generated by calling
:func:`mir_eval.io.load_patterns`. This format is a nested list of
tuples: the outer list collects patterns, each of which is a list of
occurrences, and each occurrence is a list of MIDI onset tuples of
``(onset_time, midi_note)``.
A pattern is a list of occurrences. The first occurrence must be the prototype
of that pattern (i.e. the most representative of all the occurrences). An
occurrence is a list of tuples containing the onset time and the midi note
number.
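For example, a single pattern with two two-note occurrences could be written
out literally as follows (an illustrative sketch with made-up values)::
    patterns = [        # a list with one pattern, ...
        [                                # ... which has two occurrences:
            [(0.0, 60.0), (1.0, 62.0)],  # the prototype occurrence
            [(8.0, 60.0), (9.0, 62.0)],  # a later repetition
        ],
    ]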
Metrics
-------
* :func:`mir_eval.pattern.standard_FPR`: Strict metric in order to find the
possibly transposed patterns of exact length. This is the only metric that
considers transposed patterns.
* :func:`mir_eval.pattern.establishment_FPR`: Evaluates the amount of patterns
that were successfully identified by the estimated results, no matter how
many occurrences they found. In other words, this metric captures how the
algorithm successfully *established* that a pattern repeated at least twice,
and this pattern is also found in the reference annotation.
* :func:`mir_eval.pattern.occurrence_FPR`: Evaluation of how well an estimation
can effectively identify all the occurrences of the found patterns,
independently of how many patterns have been discovered. This metric has a
threshold parameter that indicates how similar two occurrences must be in
order to be considered equal. In MIREX, this evaluation is run twice, with
thresholds .75 and .5.
* :func:`mir_eval.pattern.three_layer_FPR`: Aims to evaluate the general
similarity between the reference and the estimations, combining both the
establishment of patterns and the retrieval of its occurrences in a single F1
score.
* :func:`mir_eval.pattern.first_n_three_layer_P`: Computes the three-layer
precision for the first N patterns only in order to measure the ability of
the algorithm to sort the identified patterns based on their relevance.
* :func:`mir_eval.pattern.first_n_target_proportion_R`: Computes the target
proportion recall for the first N patterns only in order to measure the
ability of the algorithm to sort the identified patterns based on their
relevance.
"""
import numpy as np
from . import util
import warnings
import collections
def _n_onset_midi(patterns):
"""Computes the number of onset_midi objects in a pattern
Parameters
----------
    patterns : list
        A list of patterns using the format returned by
        :func:`mir_eval.io.load_patterns()`
Returns
-------
n_onsets : int
Number of onsets within the pattern.
"""
return len([o_m for pat in patterns for occ in pat for o_m in occ])
def validate(reference_patterns, estimated_patterns):
"""Checks that the input annotations to a metric look like valid pattern
lists, and throws helpful errors if not.
Parameters
----------
reference_patterns : list
The reference patterns using the format returned by
:func:`mir_eval.io.load_patterns()`
estimated_patterns : list
The estimated patterns in the same format
    """
# Warn if pattern lists are empty
if _n_onset_midi(reference_patterns) == 0:
warnings.warn('Reference patterns are empty.')
if _n_onset_midi(estimated_patterns) == 0:
warnings.warn('Estimated patterns are empty.')
for patterns in [reference_patterns, estimated_patterns]:
for pattern in patterns:
if len(pattern) <= 0:
raise ValueError("Each pattern must contain at least one "
"occurrence.")
for occurrence in pattern:
for onset_midi in occurrence:
if len(onset_midi) != 2:
raise ValueError("The (onset, midi) tuple must "
"contain exactly 2 elements.")
def _occurrence_intersection(occ_P, occ_Q):
"""Computes the intersection between two occurrences.
Parameters
----------
occ_P : list of tuples
(onset, midi) pairs representing the reference occurrence.
occ_Q : list
second list of (onset, midi) tuples
Returns
-------
S : set
Set of the intersection between occ_P and occ_Q.
"""
set_P = set([tuple(onset_midi) for onset_midi in occ_P])
set_Q = set([tuple(onset_midi) for onset_midi in occ_Q])
return set_P & set_Q # Return the intersection
def _compute_score_matrix(P, Q, similarity_metric="cardinality_score"):
"""Computes the score matrix between the patterns P and Q.
Parameters
----------
P : list
Pattern containing a list of occurrences.
Q : list
Pattern containing a list of occurrences.
similarity_metric : str
A string representing the metric to be used
when computing the similarity matrix. Accepted values:
- "cardinality_score":
Count of the intersection between occurrences.
(Default value = "cardinality_score")
Returns
-------
sm : np.array
The score matrix between P and Q using the similarity_metric.
"""
sm = np.zeros((len(P), len(Q))) # The score matrix
for iP, occ_P in enumerate(P):
for iQ, occ_Q in enumerate(Q):
if similarity_metric == "cardinality_score":
denom = float(np.max([len(occ_P), len(occ_Q)]))
# Compute the score
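                # cardinality score: the size of the occurrence intersection,
                # normalized by the size of the larger occurrence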
sm[iP, iQ] = len(_occurrence_intersection(occ_P, occ_Q)) / \
denom
            # TODO: More scores: 'normalised matching score'
else:
raise ValueError("The similarity metric (%s) can only be: "
"'cardinality_score'.")
return sm
def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5):
"""Standard F1 Score, Precision and Recall.
This metric checks if the prototype patterns of the reference match
possible translated patterns in the prototype patterns of the estimations.
    Since the sizes of these prototypes must be equal, this metric is quite
    restrictive and it tends to be 0 in most of the 2013 MIREX results.
Examples
--------
>>> ref_patterns = mir_eval.io.load_patterns("ref_pattern.txt")
>>> est_patterns = mir_eval.io.load_patterns("est_pattern.txt")
>>> F, P, R = mir_eval.pattern.standard_FPR(ref_patterns, est_patterns)
Parameters
----------
reference_patterns : list
The reference patterns using the format returned by
:func:`mir_eval.io.load_patterns()`
estimated_patterns : list
The estimated patterns in the same format
tol : float
Tolerance level when comparing reference against estimation.
Default parameter is the one found in the original matlab code by
Tom Collins used for MIREX 2013.
(Default value = 1e-5)
Returns
-------
f_measure : float
The standard F1 Score
precision : float
The standard Precision
recall : float
The standard Recall
"""
validate(reference_patterns, estimated_patterns)
nP = len(reference_patterns) # Number of patterns in the reference
nQ = len(estimated_patterns) # Number of patterns in the estimation
k = 0 # Number of patterns that match
# If no patterns were provided, metric is zero
if _n_onset_midi(reference_patterns) == 0 or \
_n_onset_midi(estimated_patterns) == 0:
return 0., 0., 0.
# Find matches of the prototype patterns
for ref_pattern in reference_patterns:
P = np.asarray(ref_pattern[0]) # Get reference prototype
for est_pattern in estimated_patterns:
Q = np.asarray(est_pattern[0]) # Get estimation prototype
if len(P) != len(Q):
continue
# Check transposition given a certain tolerance
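            # (P matches Q under translation iff P - Q is a constant offset,
            # i.e. all successive differences of P - Q vanish)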
if (len(P) == len(Q) == 1 or
np.max(np.abs(np.diff(P - Q, axis=0))) < tol):
k += 1
break
# Compute the standard measures
precision = k / float(nQ)
recall = k / float(nP)
f_measure = util.f_measure(precision, recall)
return f_measure, precision, recall
def establishment_FPR(reference_patterns, estimated_patterns,
similarity_metric="cardinality_score"):
"""Establishment F1 Score, Precision and Recall.
Examples
--------
>>> ref_patterns = mir_eval.io.load_patterns("ref_pattern.txt")
>>> est_patterns = mir_eval.io.load_patterns("est_pattern.txt")
>>> F, P, R = mir_eval.pattern.establishment_FPR(ref_patterns,
... est_patterns)
Parameters
----------
reference_patterns : list
The reference patterns in the format returned by
:func:`mir_eval.io.load_patterns()`
estimated_patterns : list
The estimated patterns in the same format
similarity_metric : str
A string representing the metric to be used when computing the
similarity matrix. Accepted values:
- "cardinality_score": Count of the intersection
between occurrences.
(Default value = "cardinality_score")
Returns
-------
f_measure : float
The establishment F1 Score
precision : float
The establishment Precision
recall : float
The establishment Recall
"""
validate(reference_patterns, estimated_patterns)
nP = len(reference_patterns) # Number of elements in reference
nQ = len(estimated_patterns) # Number of elements in estimation
S = np.zeros((nP, nQ)) # Establishment matrix
# If no patterns were provided, metric is zero
if _n_onset_midi(reference_patterns) == 0 or \
_n_onset_midi(estimated_patterns) == 0:
return 0., 0., 0.
for iP, ref_pattern in enumerate(reference_patterns):
for iQ, est_pattern in enumerate(estimated_patterns):
s = _compute_score_matrix(ref_pattern, est_pattern,
similarity_metric)
S[iP, iQ] = np.max(s)
# Compute scores
precision = np.mean(np.max(S, axis=0))
recall = np.mean(np.max(S, axis=1))
f_measure = util.f_measure(precision, recall)
return f_measure, precision, recall
def occurrence_FPR(reference_patterns, estimated_patterns, thres=.75,
similarity_metric="cardinality_score"):
"""Establishment F1 Score, Precision and Recall.
Examples
--------
>>> ref_patterns = mir_eval.io.load_patterns("ref_pattern.txt")
>>> est_patterns = mir_eval.io.load_patterns("est_pattern.txt")
>>> F, P, R = mir_eval.pattern.occurrence_FPR(ref_patterns,
... est_patterns)
Parameters
----------
reference_patterns : list
The reference patterns in the format returned by
:func:`mir_eval.io.load_patterns()`
estimated_patterns : list
The estimated patterns in the same format
    thres : float
        How similar two occurrences must be in order to be considered
        equal
(Default value = .75)
similarity_metric : str
A string representing the metric to be used
when computing the similarity matrix. Accepted values:
- "cardinality_score": Count of the intersection
between occurrences.
(Default value = "cardinality_score")
Returns
-------
    f_measure : float
        The occurrence F1 Score
    precision : float
        The occurrence Precision
    recall : float
        The occurrence Recall
"""
validate(reference_patterns, estimated_patterns)
# Number of elements in reference
nP = len(reference_patterns)
# Number of elements in estimation
nQ = len(estimated_patterns)
# Occurrence matrix with Precision and recall in its last dimension
O_PR = np.zeros((nP, nQ, 2))
# Index of the values that are greater than the specified threshold
rel_idx = np.empty((0, 2), dtype=int)
# If no patterns were provided, metric is zero
if _n_onset_midi(reference_patterns) == 0 or \
_n_onset_midi(estimated_patterns) == 0:
return 0., 0., 0.
for iP, ref_pattern in enumerate(reference_patterns):
for iQ, est_pattern in enumerate(estimated_patterns):
s = _compute_score_matrix(ref_pattern, est_pattern,
similarity_metric)
if np.max(s) >= thres:
O_PR[iP, iQ, 0] = np.mean(np.max(s, axis=0))
O_PR[iP, iQ, 1] = np.mean(np.max(s, axis=1))
rel_idx = np.vstack((rel_idx, [iP, iQ]))
# Compute the scores
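    # Only pattern pairs whose best occurrence overlap reached the threshold
    # (those indexed by rel_idx) contribute to the averaged scores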
if len(rel_idx) == 0:
precision = 0
recall = 0
else:
P = O_PR[:, :, 0]
precision = np.mean(np.max(P[np.ix_(rel_idx[:, 0], rel_idx[:, 1])],
axis=0))
R = O_PR[:, :, 1]
recall = np.mean(np.max(R[np.ix_(rel_idx[:, 0], rel_idx[:, 1])],
axis=1))
f_measure = util.f_measure(precision, recall)
return f_measure, precision, recall
def three_layer_FPR(reference_patterns, estimated_patterns):
"""Three Layer F1 Score, Precision and Recall. As described by Meridith.
Examples
--------
>>> ref_patterns = mir_eval.io.load_patterns("ref_pattern.txt")
>>> est_patterns = mir_eval.io.load_patterns("est_pattern.txt")
>>> F, P, R = mir_eval.pattern.three_layer_FPR(ref_patterns,
... est_patterns)
Parameters
----------
reference_patterns : list
The reference patterns in the format returned by
:func:`mir_eval.io.load_patterns()`
estimated_patterns : list
The estimated patterns in the same format
Returns
-------
f_measure : float
The three-layer F1 Score
precision : float
The three-layer Precision
recall : float
The three-layer Recall
"""
validate(reference_patterns, estimated_patterns)
def compute_first_layer_PR(ref_occs, est_occs):
"""Computes the first layer Precision and Recall values given the
set of occurrences in the reference and the set of occurrences in the
estimation.
Parameters
----------
        ref_occs : list
            Occurrences of the reference pattern, as (onset, midi) pairs.
        est_occs : list
            Occurrences of the estimated pattern, in the same format.
        Returns
        -------
        precision : float
            The first layer Precision.
        recall : float
            The first layer Recall.
"""
# Find the length of the intersection between reference and estimation
s = len(_occurrence_intersection(ref_occs, est_occs))
        # Compute the first layer scores: precision is normalized by the
        # estimated occurrence size, recall by the reference occurrence size
        precision = s / float(len(est_occs))
        recall = s / float(len(ref_occs))
return precision, recall
def compute_second_layer_PR(ref_pattern, est_pattern):
"""Computes the second layer Precision and Recall values given the
set of occurrences in the reference and the set of occurrences in the
estimation.
Parameters
----------
        ref_pattern : list
            Reference pattern: a list of occurrences.
        est_pattern : list
            Estimated pattern: a list of occurrences.
        Returns
        -------
        precision : float
            The second layer Precision.
        recall : float
            The second layer Recall.
"""
# Compute the first layer scores
F_1 = compute_layer(ref_pattern, est_pattern)
# Compute the second layer scores
precision = np.mean(np.max(F_1, axis=0))
recall = np.mean(np.max(F_1, axis=1))
return precision, recall
def compute_layer(ref_elements, est_elements, layer=1):
"""Computes the F-measure matrix for a given layer. The reference and
        estimated elements can be either patterns or occurrences, depending
on the layer.
For layer 1, the elements must be occurrences.
For layer 2, the elements must be patterns.
Parameters
----------
        ref_elements : list
            Reference patterns or occurrences, depending on the layer.
        est_elements : list
            Estimated patterns or occurrences, depending on the layer.
        layer : int
            The layer to compute, either 1 (occurrences) or 2 (patterns).
            (Default value = 1)
        Returns
        -------
        F : np.ndarray, shape=(len(ref_elements), len(est_elements))
            The F-measure matrix for the given layer.
"""
if layer != 1 and layer != 2:
raise ValueError("Layer (%d) must be an integer between 1 and 2"
% layer)
nP = len(ref_elements) # Number of elements in reference
nQ = len(est_elements) # Number of elements in estimation
F = np.zeros((nP, nQ)) # F-measure matrix for the given layer
for iP in range(nP):
for iQ in range(nQ):
if layer == 1:
func = compute_first_layer_PR
elif layer == 2:
func = compute_second_layer_PR
# Compute layer scores
precision, recall = func(ref_elements[iP], est_elements[iQ])
F[iP, iQ] = util.f_measure(precision, recall)
return F
# If no patterns were provided, metric is zero
if _n_onset_midi(reference_patterns) == 0 or \
_n_onset_midi(estimated_patterns) == 0:
return 0., 0., 0.
# Compute the second layer (it includes the first layer)
F_2 = compute_layer(reference_patterns, estimated_patterns, layer=2)
# Compute the final scores (third layer)
precision_3 = np.mean(np.max(F_2, axis=0))
recall_3 = np.mean(np.max(F_2, axis=1))
f_measure_3 = util.f_measure(precision_3, recall_3)
return f_measure_3, precision_3, recall_3
def first_n_three_layer_P(reference_patterns, estimated_patterns, n=5):
"""First n three-layer precision.
    This metric is basically the same as the three-layer FPR, but it is only
    applied to the first n estimated patterns and it only returns the
    precision. In MIREX (and typically), n = 5.
Examples
--------
>>> ref_patterns = mir_eval.io.load_patterns("ref_pattern.txt")
>>> est_patterns = mir_eval.io.load_patterns("est_pattern.txt")
>>> P = mir_eval.pattern.first_n_three_layer_P(ref_patterns,
... est_patterns, n=5)
Parameters
----------
reference_patterns : list
The reference patterns in the format returned by
:func:`mir_eval.io.load_patterns()`
estimated_patterns : list
The estimated patterns in the same format
n : int
Number of patterns to consider from the estimated results, in
        the order they are listed
(Default value = 5)
Returns
-------
precision : float
The first n three-layer Precision
"""
validate(reference_patterns, estimated_patterns)
# If no patterns were provided, metric is zero
if _n_onset_midi(reference_patterns) == 0 or \
_n_onset_midi(estimated_patterns) == 0:
        return 0.
# Get only the first n patterns from the estimated results
fn_est_patterns = estimated_patterns[:min(len(estimated_patterns), n)]
# Compute the three-layer scores for the first n estimated patterns
F, P, R = three_layer_FPR(reference_patterns, fn_est_patterns)
return P # Return the precision only
def first_n_target_proportion_R(reference_patterns, estimated_patterns, n=5):
"""First n target proportion establishment recall metric.
    This metric is similar to the establishment FPR score, but it only
    takes into account the first n estimated patterns and it only
    outputs the Recall value.
Examples
--------
>>> ref_patterns = mir_eval.io.load_patterns("ref_pattern.txt")
>>> est_patterns = mir_eval.io.load_patterns("est_pattern.txt")
>>> R = mir_eval.pattern.first_n_target_proportion_R(
... ref_patterns, est_patterns, n=5)
Parameters
----------
reference_patterns : list
The reference patterns in the format returned by
:func:`mir_eval.io.load_patterns()`
estimated_patterns : list
The estimated patterns in the same format
n : int
Number of patterns to consider from the estimated results, in
        the order they are listed.
(Default value = 5)
Returns
-------
recall : float
The first n target proportion Recall.
"""
validate(reference_patterns, estimated_patterns)
# If no patterns were provided, metric is zero
if _n_onset_midi(reference_patterns) == 0 or \
_n_onset_midi(estimated_patterns) == 0:
        return 0.
# Get only the first n patterns from the estimated results
fn_est_patterns = estimated_patterns[:min(len(estimated_patterns), n)]
F, P, R = establishment_FPR(reference_patterns, fn_est_patterns)
return R
def evaluate(ref_patterns, est_patterns, **kwargs):
"""Load data and perform the evaluation.
Examples
--------
>>> ref_patterns = mir_eval.io.load_patterns("ref_pattern.txt")
>>> est_patterns = mir_eval.io.load_patterns("est_pattern.txt")
>>> scores = mir_eval.pattern.evaluate(ref_patterns, est_patterns)
Parameters
----------
ref_patterns : list
The reference patterns in the format returned by
:func:`mir_eval.io.load_patterns()`
est_patterns : list
The estimated patterns in the same format
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all the metrics
scores = collections.OrderedDict()
# Standard scores
scores['F'], scores['P'], scores['R'] = \
util.filter_kwargs(standard_FPR, ref_patterns, est_patterns, **kwargs)
# Establishment scores
scores['F_est'], scores['P_est'], scores['R_est'] = \
util.filter_kwargs(establishment_FPR, ref_patterns, est_patterns,
**kwargs)
# Occurrence scores
    # Force these values for thres (the parameter name in occurrence_FPR)
    kwargs['thres'] = .5
scores['F_occ.5'], scores['P_occ.5'], scores['R_occ.5'] = \
util.filter_kwargs(occurrence_FPR, ref_patterns, est_patterns,
**kwargs)
    kwargs['thres'] = .75
scores['F_occ.75'], scores['P_occ.75'], scores['R_occ.75'] = \
util.filter_kwargs(occurrence_FPR, ref_patterns, est_patterns,
**kwargs)
# Three-layer scores
scores['F_3'], scores['P_3'], scores['R_3'] = \
util.filter_kwargs(three_layer_FPR, ref_patterns, est_patterns,
**kwargs)
# First Five Patterns scores
# Set default value of n
if 'n' not in kwargs:
kwargs['n'] = 5
scores['FFP'] = util.filter_kwargs(first_n_three_layer_P, ref_patterns,
est_patterns, **kwargs)
scores['FFTP_est'] = \
util.filter_kwargs(first_n_target_proportion_R, ref_patterns,
est_patterns, **kwargs)
return scores
mir_eval-0.7/mir_eval/segment.py 0000664 0000000 0000000 00000136121 14203260312 0016765 0 ustar 00root root 0000000 0000000 # CREATED:2013-08-13 12:02:42 by Brian McFee
'''
Evaluation criteria for structural segmentation fall into two categories:
boundary annotation and structural annotation. Boundary annotation is the task
of predicting the times at which structural changes occur, such as when a verse
transitions to a refrain. Metrics for boundary annotation compare estimated
segment boundaries to reference boundaries. Structural annotation is the task
of assigning labels to detected segments. The estimated labels may be
arbitrary strings, such as A, B, or C, and they need not describe functional
concepts. Metrics for structural annotation are similar to those used for
clustering data.
Conventions
-----------
Both boundary and structural annotation metrics require two dimensional arrays
with two columns, one for boundary start times and one for boundary end times.
Structural annotation further requires lists of reference and estimated segment
labels whose lengths equal the number of rows in the
corresponding list of boundary edges. In both tasks, we assume that
annotations express a partitioning of the track into intervals. The function
:func:`mir_eval.util.adjust_intervals` can be used to pad or crop the segment
boundaries to span the duration of the entire track.
Metrics
-------
* :func:`mir_eval.segment.detection`: An estimated boundary is considered
correct if it falls within a window around a reference boundary
[#turnbull2007]_
* :func:`mir_eval.segment.deviation`: Computes the median absolute time
difference from a reference boundary to its nearest estimated boundary, and
vice versa [#turnbull2007]_
* :func:`mir_eval.segment.pairwise`: For classifying pairs of sampled time
instants as belonging to the same structural component [#levy2008]_
* :func:`mir_eval.segment.rand_index`: Clusters reference and estimated
annotations and compares them by the Rand Index
* :func:`mir_eval.segment.ari`: Computes the Rand index, adjusted for chance
* :func:`mir_eval.segment.nce`: Interprets sampled reference and estimated
labels as samples of random variables :math:`Y_R, Y_E` from which the
conditional entropy of :math:`Y_R` given :math:`Y_E` (Under-Segmentation) and
:math:`Y_E` given :math:`Y_R` (Over-Segmentation) are estimated
[#lukashevich2008]_
* :func:`mir_eval.segment.mutual_information`: Computes the standard,
normalized, and adjusted mutual information of sampled reference and
estimated segments
* :func:`mir_eval.segment.vmeasure`: Computes the V-Measure, which is similar
to the conditional entropy metrics, but uses the marginal distributions
as normalization rather than the maximum entropy distribution
[#rosenberg2007]_
References
----------
.. [#turnbull2007] Turnbull, D., Lanckriet, G. R., Pampalk, E.,
& Goto, M. A Supervised Approach for Detecting Boundaries in Music
Using Difference Features and Boosting. In ISMIR (pp. 51-54).
.. [#levy2008] Levy, M., & Sandler, M.
Structural segmentation of musical audio by constrained clustering.
IEEE transactions on audio, speech, and language processing, 16(2),
318-326.
.. [#lukashevich2008] Lukashevich, H. M.
Towards Quantitative Measures of Evaluating Song Segmentation.
In ISMIR (pp. 375-380).
.. [#rosenberg2007] Rosenberg, A., & Hirschberg, J.
V-Measure: A Conditional Entropy-Based External Cluster Evaluation
Measure.
In EMNLP-CoNLL (Vol. 7, pp. 410-420).
'''
import collections
import warnings
import numpy as np
import scipy.stats
import scipy.sparse
import scipy.misc
import scipy.special
from . import util
def validate_boundary(reference_intervals, estimated_intervals, trim):
"""Checks that the input annotations to a segment boundary estimation
metric (i.e. one that only takes in segment intervals) look like valid
segment times, and throws helpful errors if not.
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
trim : bool
will the start and end events be trimmed?
"""
if trim:
# If we're trimming, then we need at least 2 intervals
min_size = 2
else:
# If we're not trimming, then we only need one interval
min_size = 1
if len(reference_intervals) < min_size:
warnings.warn("Reference intervals are empty.")
if len(estimated_intervals) < min_size:
warnings.warn("Estimated intervals are empty.")
for intervals in [reference_intervals, estimated_intervals]:
util.validate_intervals(intervals)
def validate_structure(reference_intervals, reference_labels,
estimated_intervals, estimated_labels):
"""Checks that the input annotations to a structure estimation metric (i.e.
one that takes in both segment boundaries and their labels) look like valid
segment times and labels, and throws helpful errors if not.
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
reference_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
"""
for (intervals, labels) in [(reference_intervals, reference_labels),
(estimated_intervals, estimated_labels)]:
util.validate_intervals(intervals)
if intervals.shape[0] != len(labels):
raise ValueError('Number of intervals does not match number '
'of labels')
# Check only when intervals are non-empty
if intervals.size > 0:
# Make sure intervals start at 0
if not np.allclose(intervals.min(), 0.0):
raise ValueError('Segment intervals do not start at 0')
if reference_intervals.size == 0:
warnings.warn("Reference intervals are empty.")
if estimated_intervals.size == 0:
warnings.warn("Estimated intervals are empty.")
# Check only when intervals are non-empty
if reference_intervals.size > 0 and estimated_intervals.size > 0:
if not np.allclose(reference_intervals.max(),
estimated_intervals.max()):
raise ValueError('End times do not match')
def detection(reference_intervals, estimated_intervals,
window=0.5, beta=1.0, trim=False):
"""Boundary detection hit-rate.
    A hit is counted whenever a reference boundary is within ``window`` of an
estimated boundary. Note that each boundary is matched at most once: this
is achieved by computing the size of a maximal matching between reference
and estimated boundary points, subject to the window constraint.
Examples
--------
>>> ref_intervals, _ = mir_eval.io.load_labeled_intervals('ref.lab')
>>> est_intervals, _ = mir_eval.io.load_labeled_intervals('est.lab')
>>> # With 0.5s windowing
>>> P05, R05, F05 = mir_eval.segment.detection(ref_intervals,
... est_intervals,
... window=0.5)
>>> # With 3s windowing
>>> P3, R3, F3 = mir_eval.segment.detection(ref_intervals,
... est_intervals,
... window=3)
>>> # Ignoring hits for the beginning and end of track
>>> P, R, F = mir_eval.segment.detection(ref_intervals,
... est_intervals,
... window=0.5,
... trim=True)
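    >>> # A purely numeric sketch with synthetic, hypothetical intervals
    >>> # (assumes numpy is imported as np):
    >>> ref = np.array([[0., 1.], [1., 3.], [3., 4.]])  # boundaries 0, 1, 3, 4
    >>> est = np.array([[0., 1.1], [1.1, 4.]])          # boundaries 0, 1.1, 4
    >>> P, R, F = mir_eval.segment.detection(ref, est, window=0.5)
    >>> # all three estimated boundaries match within 0.5 s,
    >>> # so P = 1.0 and R = 3/4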
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
window : float > 0
        size of the window of 'correctness' around reference boundaries
(in seconds)
(Default value = 0.5)
beta : float > 0
weighting constant for F-measure.
(Default value = 1.0)
trim : boolean
if ``True``, the first and last boundary times are ignored.
Typically, these denote start (0) and end-markers.
(Default value = False)
Returns
-------
precision : float
precision of estimated predictions
recall : float
        recall of reference boundaries
f_measure : float
F-measure (weighted harmonic mean of ``precision`` and ``recall``)
"""
validate_boundary(reference_intervals, estimated_intervals, trim)
# Convert intervals to boundaries
reference_boundaries = util.intervals_to_boundaries(reference_intervals)
estimated_boundaries = util.intervals_to_boundaries(estimated_intervals)
# Suppress the first and last intervals
if trim:
reference_boundaries = reference_boundaries[1:-1]
estimated_boundaries = estimated_boundaries[1:-1]
# If we have no boundaries, we get no score.
if len(reference_boundaries) == 0 or len(estimated_boundaries) == 0:
return 0.0, 0.0, 0.0
matching = util.match_events(reference_boundaries,
estimated_boundaries,
window)
precision = float(len(matching)) / len(estimated_boundaries)
recall = float(len(matching)) / len(reference_boundaries)
f_measure = util.f_measure(precision, recall, beta=beta)
return precision, recall, f_measure
def deviation(reference_intervals, estimated_intervals, trim=False):
"""Compute the median deviations between reference
and estimated boundary times.
Examples
--------
>>> ref_intervals, _ = mir_eval.io.load_labeled_intervals('ref.lab')
>>> est_intervals, _ = mir_eval.io.load_labeled_intervals('est.lab')
    >>> r_to_e, e_to_r = mir_eval.segment.deviation(ref_intervals,
... est_intervals)
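    >>> # A purely numeric sketch with synthetic, hypothetical intervals
    >>> # (assumes numpy is imported as np):
    >>> ref = np.array([[0., 2.], [2., 6.], [6., 10.]])  # boundaries 0, 2, 6, 10
    >>> est = np.array([[0., 3.], [3., 10.]])            # boundaries 0, 3, 10
    >>> r_to_e, e_to_r = mir_eval.segment.deviation(ref, est)
    >>> # per-reference gaps are [0, 1, 3, 0] (median 0.5) and per-estimate
    >>> # gaps are [0, 1, 0] (median 0), so r_to_e = 0.5 and e_to_r = 0.0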
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_intervals` or
:func:`mir_eval.io.load_labeled_intervals`.
trim : boolean
if ``True``, the first and last intervals are ignored.
Typically, these denote start (0.0) and end-of-track markers.
(Default value = False)
Returns
-------
reference_to_estimated : float
median time from each reference boundary to the
closest estimated boundary
estimated_to_reference : float
median time from each estimated boundary to the
closest reference boundary
"""
validate_boundary(reference_intervals, estimated_intervals, trim)
# Convert intervals to boundaries
reference_boundaries = util.intervals_to_boundaries(reference_intervals)
estimated_boundaries = util.intervals_to_boundaries(estimated_intervals)
# Suppress the first and last intervals
if trim:
reference_boundaries = reference_boundaries[1:-1]
estimated_boundaries = estimated_boundaries[1:-1]
# If we have no boundaries, we get no score.
if len(reference_boundaries) == 0 or len(estimated_boundaries) == 0:
return np.nan, np.nan
dist = np.abs(np.subtract.outer(reference_boundaries,
estimated_boundaries))
estimated_to_reference = np.median(dist.min(axis=0))
reference_to_estimated = np.median(dist.min(axis=1))
return reference_to_estimated, estimated_to_reference
def pairwise(reference_intervals, reference_labels,
estimated_intervals, estimated_labels,
frame_size=0.1, beta=1.0):
"""Frame-clustering segmentation evaluation by pair-wise agreement.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> # Trim or pad the estimate to match reference timing
>>> (ref_intervals,
... ref_labels) = mir_eval.util.adjust_intervals(ref_intervals,
... ref_labels,
... t_min=0)
>>> (est_intervals,
... est_labels) = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, t_min=0, t_max=ref_intervals.max())
    >>> precision, recall, f = mir_eval.segment.pairwise(ref_intervals,
... ref_labels,
... est_intervals,
... est_labels)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
reference_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
frame_size : float > 0
length (in seconds) of frames for clustering
(Default value = 0.1)
beta : float > 0
beta value for F-measure
(Default value = 1.0)
Returns
-------
precision : float > 0
Precision of detecting whether frames belong in the same cluster
recall : float > 0
Recall of detecting whether frames belong in the same cluster
f : float > 0
F-measure of detecting whether frames belong in the same cluster
"""
validate_structure(reference_intervals, reference_labels,
estimated_intervals, estimated_labels)
# Check for empty annotations. Don't need to check labels because
# validate_structure makes sure they're the same size as intervals
if reference_intervals.size == 0 or estimated_intervals.size == 0:
return 0., 0., 0.
# Generate the cluster labels
y_ref = util.intervals_to_samples(reference_intervals,
reference_labels,
sample_size=frame_size)[-1]
y_ref = util.index_labels(y_ref)[0]
# Map to index space
y_est = util.intervals_to_samples(estimated_intervals,
estimated_labels,
sample_size=frame_size)[-1]
y_est = util.index_labels(y_est)[0]
# Build the reference label agreement matrix
agree_ref = np.equal.outer(y_ref, y_ref)
# Count the unique pairs
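    # (subtract the diagonal of trivial self-pairs, then halve because the
    # symmetric agreement matrix counts each unordered pair twice)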
n_agree_ref = (agree_ref.sum() - len(y_ref)) / 2.0
# Repeat for estimate
agree_est = np.equal.outer(y_est, y_est)
n_agree_est = (agree_est.sum() - len(y_est)) / 2.0
# Find where they agree
matches = np.logical_and(agree_ref, agree_est)
n_matches = (matches.sum() - len(y_ref)) / 2.0
precision = n_matches / n_agree_est
recall = n_matches / n_agree_ref
f_measure = util.f_measure(precision, recall, beta=beta)
return precision, recall, f_measure
def rand_index(reference_intervals, reference_labels,
estimated_intervals, estimated_labels,
frame_size=0.1, beta=1.0):
"""(Non-adjusted) Rand index.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> # Trim or pad the estimate to match reference timing
>>> (ref_intervals,
... ref_labels) = mir_eval.util.adjust_intervals(ref_intervals,
... ref_labels,
... t_min=0)
>>> (est_intervals,
... est_labels) = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, t_min=0, t_max=ref_intervals.max())
    >>> rand_index = mir_eval.segment.rand_index(ref_intervals,
... ref_labels,
... est_intervals,
... est_labels)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
reference_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
frame_size : float > 0
length (in seconds) of frames for clustering
(Default value = 0.1)
beta : float > 0
beta value for F-measure
(Default value = 1.0)
Returns
-------
rand_index : float > 0
Rand index
"""
validate_structure(reference_intervals, reference_labels,
estimated_intervals, estimated_labels)
# Check for empty annotations. Don't need to check labels because
# validate_structure makes sure they're the same size as intervals
if reference_intervals.size == 0 or estimated_intervals.size == 0:
        return 0.
# Generate the cluster labels
y_ref = util.intervals_to_samples(reference_intervals,
reference_labels,
sample_size=frame_size)[-1]
y_ref = util.index_labels(y_ref)[0]
# Map to index space
y_est = util.intervals_to_samples(estimated_intervals,
estimated_labels,
sample_size=frame_size)[-1]
y_est = util.index_labels(y_est)[0]
# Build the reference label agreement matrix
agree_ref = np.equal.outer(y_ref, y_ref)
# Repeat for estimate
agree_est = np.equal.outer(y_est, y_est)
# Find where they agree
matches_pos = np.logical_and(agree_ref, agree_est)
# Find where they disagree
matches_neg = np.logical_and(~agree_ref, ~agree_est)
n_pairs = len(y_ref) * (len(y_ref) - 1) / 2.0
n_matches_pos = (matches_pos.sum() - len(y_ref)) / 2.0
n_matches_neg = matches_neg.sum() / 2.0
rand = (n_matches_pos + n_matches_neg) / n_pairs
return rand
def _contingency_matrix(reference_indices, estimated_indices):
"""Computes the contingency matrix of a true labeling vs an estimated one.
Parameters
----------
reference_indices : np.ndarray
Array of reference indices
estimated_indices : np.ndarray
Array of estimated indices
Returns
-------
contingency_matrix : np.ndarray
Contingency matrix, shape=(#reference indices, #estimated indices)
.. note:: Based on sklearn.metrics.cluster.contingency_matrix
"""
ref_classes, ref_class_idx = np.unique(reference_indices,
return_inverse=True)
est_classes, est_class_idx = np.unique(estimated_indices,
return_inverse=True)
n_ref_classes = ref_classes.shape[0]
n_est_classes = est_classes.shape[0]
# Using coo_matrix is faster than histogram2d
return scipy.sparse.coo_matrix((np.ones(ref_class_idx.shape[0]),
(ref_class_idx, est_class_idx)),
shape=(n_ref_classes, n_est_classes),
                                   dtype=int).toarray()
def _adjusted_rand_index(reference_indices, estimated_indices):
"""Compute the Rand index, adjusted for change.
Parameters
----------
reference_indices : np.ndarray
Array of reference indices
estimated_indices : np.ndarray
Array of estimated indices
Returns
-------
ari : float
Adjusted Rand index
.. note:: Based on sklearn.metrics.cluster.adjusted_rand_score
"""
n_samples = len(reference_indices)
ref_classes = np.unique(reference_indices)
est_classes = np.unique(estimated_indices)
# Special limit cases: no clustering since the data is not split;
# or trivial clustering where each document is assigned a unique cluster.
# These are perfect matches hence return 1.0.
if (ref_classes.shape[0] == est_classes.shape[0] == 1 or
ref_classes.shape[0] == est_classes.shape[0] == 0 or
(ref_classes.shape[0] == est_classes.shape[0] ==
len(reference_indices))):
return 1.0
contingency = _contingency_matrix(reference_indices, estimated_indices)
# Compute the ARI using the contingency data
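    # ARI = (Index - ExpectedIndex) / (MaxIndex - ExpectedIndex), where the
    # index is the pair count sum_ij C(n_ij, 2), its expectation under
    # independent marginals is sum_c C(a_c, 2) * sum_k C(b_k, 2) / C(n, 2),
    # and the maximum is taken as the mean of the two marginal pair counts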
sum_comb_c = sum(scipy.special.comb(n_c, 2, exact=1) for n_c in
contingency.sum(axis=1))
sum_comb_k = sum(scipy.special.comb(n_k, 2, exact=1) for n_k in
contingency.sum(axis=0))
sum_comb = sum((scipy.special.comb(n_ij, 2, exact=1) for n_ij in
contingency.flatten()))
prod_comb = (sum_comb_c * sum_comb_k)/float(scipy.special.comb(n_samples,
2))
mean_comb = (sum_comb_k + sum_comb_c)/2.
return (sum_comb - prod_comb)/(mean_comb - prod_comb)
def ari(reference_intervals, reference_labels,
estimated_intervals, estimated_labels,
frame_size=0.1):
"""Adjusted Rand Index (ARI) for frame clustering segmentation evaluation.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> # Trim or pad the estimate to match reference timing
>>> (ref_intervals,
... ref_labels) = mir_eval.util.adjust_intervals(ref_intervals,
... ref_labels,
... t_min=0)
>>> (est_intervals,
... est_labels) = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, t_min=0, t_max=ref_intervals.max())
    >>> ari_score = mir_eval.segment.ari(ref_intervals, ref_labels,
... est_intervals, est_labels)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
reference_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
frame_size : float > 0
length (in seconds) of frames for clustering
(Default value = 0.1)
Returns
-------
ari_score : float > 0
Adjusted Rand index between segmentations.
"""
validate_structure(reference_intervals, reference_labels,
estimated_intervals, estimated_labels)
# Check for empty annotations. Don't need to check labels because
# validate_structure makes sure they're the same size as intervals
if reference_intervals.size == 0 or estimated_intervals.size == 0:
        return 0.
# Generate the cluster labels
y_ref = util.intervals_to_samples(reference_intervals,
reference_labels,
sample_size=frame_size)[-1]
y_ref = util.index_labels(y_ref)[0]
# Map to index space
y_est = util.intervals_to_samples(estimated_intervals,
estimated_labels,
sample_size=frame_size)[-1]
y_est = util.index_labels(y_est)[0]
return _adjusted_rand_index(y_ref, y_est)
def _mutual_info_score(reference_indices, estimated_indices, contingency=None):
"""Compute the mutual information between two sequence labelings.
Parameters
----------
reference_indices : np.ndarray
Array of reference indices
estimated_indices : np.ndarray
Array of estimated indices
contingency : np.ndarray
Pre-computed contingency matrix. If None, one will be computed.
(Default value = None)
Returns
-------
mi : float
Mutual information
.. note:: Based on sklearn.metrics.cluster.mutual_info_score
"""
if contingency is None:
contingency = _contingency_matrix(reference_indices,
estimated_indices).astype(float)
contingency_sum = np.sum(contingency)
pi = np.sum(contingency, axis=1)
pj = np.sum(contingency, axis=0)
outer = np.outer(pi, pj)
nnz = contingency != 0.0
# normalized contingency
contingency_nm = contingency[nnz]
log_contingency_nm = np.log(contingency_nm)
contingency_nm /= contingency_sum
    # log(a / b) should be calculated as log(a) - log(b) to avoid
    # possible loss of precision
log_outer = -np.log(outer[nnz]) + np.log(pi.sum()) + np.log(pj.sum())
mi = (contingency_nm * (log_contingency_nm - np.log(contingency_sum)) +
contingency_nm * log_outer)
return mi.sum()
def _entropy(labels):
"""Calculates the entropy for a labeling.
Parameters
----------
labels : list-like
List of labels.
Returns
-------
entropy : float
Entropy of the labeling.
.. note:: Based on sklearn.metrics.cluster.entropy
"""
if len(labels) == 0:
return 1.0
label_idx = np.unique(labels, return_inverse=True)[1]
    pi = np.bincount(label_idx).astype(float)
pi = pi[pi > 0]
pi_sum = np.sum(pi)
    # log(a / b) should be calculated as log(a) - log(b) to avoid
    # possible loss of precision
return -np.sum((pi / pi_sum) * (np.log(pi) - np.log(pi_sum)))
def _adjusted_mutual_info_score(reference_indices, estimated_indices):
"""Compute the mutual information between two sequence labelings, adjusted for
chance.
Parameters
----------
reference_indices : np.ndarray
Array of reference indices
estimated_indices : np.ndarray
Array of estimated indices
Returns
-------
ami : float <= 1.0
Mutual information
.. note:: Based on sklearn.metrics.cluster.adjusted_mutual_info_score
and sklearn.metrics.cluster.expected_mutual_info_score
"""
n_samples = len(reference_indices)
ref_classes = np.unique(reference_indices)
est_classes = np.unique(estimated_indices)
# Special limit cases: no clustering since the data is not split.
# This is a perfect match hence return 1.0.
if (ref_classes.shape[0] == est_classes.shape[0] == 1 or
ref_classes.shape[0] == est_classes.shape[0] == 0):
return 1.0
contingency = _contingency_matrix(reference_indices,
estimated_indices).astype(float)
# Calculate the MI for the two clusterings
mi = _mutual_info_score(reference_indices, estimated_indices,
contingency=contingency)
# The following code is based on
# sklearn.metrics.cluster.expected_mutual_information
R, C = contingency.shape
N = float(n_samples)
a = np.sum(contingency, axis=1).astype(np.int32)
b = np.sum(contingency, axis=0).astype(np.int32)
# There are three major terms to the EMI equation, which are multiplied to
# and then summed over varying nij values.
# While nijs[0] will never be used, having it simplifies the indexing.
nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype='float')
    # Stops divide-by-zero warnings; since nijs[0] is never used, this is safe.
nijs[0] = 1
# term1 is nij / N
term1 = nijs / N
# term2 is log((N*nij) / (a * b)) == log(N * nij) - log(a * b)
# term2 uses the outer product
log_ab_outer = np.log(np.outer(a, b))
# term2 uses N * nij
log_Nnij = np.log(N * nijs)
    # term3 is large and involves many factorials. Calculate these in log
    # space to stop overflows.
gln_a = scipy.special.gammaln(a + 1)
gln_b = scipy.special.gammaln(b + 1)
gln_Na = scipy.special.gammaln(N - a + 1)
gln_Nb = scipy.special.gammaln(N - b + 1)
gln_N = scipy.special.gammaln(N + 1)
gln_nij = scipy.special.gammaln(nijs + 1)
# start and end values for nij terms for each summation.
start = np.array([[v - N + w for w in b] for v in a], dtype='int')
start = np.maximum(start, 1)
end = np.minimum(np.resize(a, (C, R)).T, np.resize(b, (R, C))) + 1
# emi itself is a summation over the various values.
emi = 0
for i in range(R):
for j in range(C):
for nij in range(start[i, j], end[i, j]):
term2 = log_Nnij[nij] - log_ab_outer[i, j]
# Numerators are positive, denominators are negative.
gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] -
gln_N - gln_nij[nij] -
scipy.special.gammaln(a[i] - nij + 1) -
scipy.special.gammaln(b[j] - nij + 1) -
scipy.special.gammaln(N - a[i] - b[j] + nij + 1))
term3 = np.exp(gln)
emi += (term1[nij] * term2 * term3)
# Calculate entropy for each labeling
h_true, h_pred = _entropy(reference_indices), _entropy(estimated_indices)
ami = (mi - emi) / (max(h_true, h_pred) - emi)
return ami
def _normalized_mutual_info_score(reference_indices, estimated_indices):
"""Compute the mutual information between two sequence labelings, adjusted for
chance.
Parameters
----------
reference_indices : np.ndarray
Array of reference indices
estimated_indices : np.ndarray
Array of estimated indices
Returns
-------
nmi : float <= 1.0
Normalized mutual information
.. note:: Based on sklearn.metrics.cluster.normalized_mutual_info_score
"""
ref_classes = np.unique(reference_indices)
est_classes = np.unique(estimated_indices)
# Special limit cases: no clustering since the data is not split.
# This is a perfect match hence return 1.0.
if (ref_classes.shape[0] == est_classes.shape[0] == 1 or
ref_classes.shape[0] == est_classes.shape[0] == 0):
return 1.0
contingency = _contingency_matrix(reference_indices,
estimated_indices).astype(float)
# Calculate the MI for the two clusterings
mi = _mutual_info_score(reference_indices, estimated_indices,
contingency=contingency)
    # Calculate entropy for each labeling
h_true, h_pred = _entropy(reference_indices), _entropy(estimated_indices)
nmi = mi / max(np.sqrt(h_true * h_pred), 1e-10)
return nmi
def mutual_information(reference_intervals, reference_labels,
estimated_intervals, estimated_labels,
frame_size=0.1):
"""Frame-clustering segmentation: mutual information metrics.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> # Trim or pad the estimate to match reference timing
>>> (ref_intervals,
... ref_labels) = mir_eval.util.adjust_intervals(ref_intervals,
... ref_labels,
... t_min=0)
>>> (est_intervals,
... est_labels) = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, t_min=0, t_max=ref_intervals.max())
    >>> mi, ami, nmi = mir_eval.segment.mutual_information(ref_intervals,
... ref_labels,
... est_intervals,
... est_labels)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
reference_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
frame_size : float > 0
length (in seconds) of frames for clustering
(Default value = 0.1)
Returns
-------
MI : float > 0
Mutual information between segmentations
AMI : float
Adjusted mutual information between segmentations.
NMI : float > 0
        Normalized mutual information between segmentations
"""
validate_structure(reference_intervals, reference_labels,
estimated_intervals, estimated_labels)
# Check for empty annotations. Don't need to check labels because
# validate_structure makes sure they're the same size as intervals
if reference_intervals.size == 0 or estimated_intervals.size == 0:
return 0., 0., 0.
# Generate the cluster labels
y_ref = util.intervals_to_samples(reference_intervals,
reference_labels,
sample_size=frame_size)[-1]
y_ref = util.index_labels(y_ref)[0]
# Map to index space
y_est = util.intervals_to_samples(estimated_intervals,
estimated_labels,
sample_size=frame_size)[-1]
y_est = util.index_labels(y_est)[0]
# Mutual information
mutual_info = _mutual_info_score(y_ref, y_est)
# Adjusted mutual information
adj_mutual_info = _adjusted_mutual_info_score(y_ref, y_est)
# Normalized mutual information
norm_mutual_info = _normalized_mutual_info_score(y_ref, y_est)
return mutual_info, adj_mutual_info, norm_mutual_info
def nce(reference_intervals, reference_labels, estimated_intervals,
estimated_labels, frame_size=0.1, beta=1.0, marginal=False):
"""Frame-clustering segmentation: normalized conditional entropy
    Computes the conditional entropy of cluster assignments, normalized by
    the maximum entropy.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> # Trim or pad the estimate to match reference timing
>>> (ref_intervals,
... ref_labels) = mir_eval.util.adjust_intervals(ref_intervals,
... ref_labels,
... t_min=0)
>>> (est_intervals,
... est_labels) = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, t_min=0, t_max=ref_intervals.max())
    >>> S_over, S_under, S_F = mir_eval.segment.nce(ref_intervals,
... ref_labels,
... est_intervals,
... est_labels)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
reference_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
frame_size : float > 0
length (in seconds) of frames for clustering
(Default value = 0.1)
beta : float > 0
beta for F-measure
(Default value = 1.0)
marginal : bool
If `False`, normalize conditional entropy by uniform entropy.
If `True`, normalize conditional entropy by the marginal entropy.
(Default value = False)
Returns
-------
S_over
Over-clustering score:
- For `marginal=False`, ``1 - H(y_est | y_ref) / log(|y_est|)``
- For `marginal=True`, ``1 - H(y_est | y_ref) / H(y_est)``
If `|y_est|==1`, then `S_over` will be 0.
S_under
Under-clustering score:
- For `marginal=False`, ``1 - H(y_ref | y_est) / log(|y_ref|)``
- For `marginal=True`, ``1 - H(y_ref | y_est) / H(y_ref)``
If `|y_ref|==1`, then `S_under` will be 0.
S_F
F-measure for (S_over, S_under)
"""
validate_structure(reference_intervals, reference_labels,
estimated_intervals, estimated_labels)
# Check for empty annotations. Don't need to check labels because
# validate_structure makes sure they're the same size as intervals
if reference_intervals.size == 0 or estimated_intervals.size == 0:
return 0., 0., 0.
# Generate the cluster labels
y_ref = util.intervals_to_samples(reference_intervals,
reference_labels,
sample_size=frame_size)[-1]
y_ref = util.index_labels(y_ref)[0]
# Map to index space
y_est = util.intervals_to_samples(estimated_intervals,
estimated_labels,
sample_size=frame_size)[-1]
y_est = util.index_labels(y_est)[0]
# Make the contingency table: shape = (n_ref, n_est)
contingency = _contingency_matrix(y_ref, y_est).astype(float)
# Normalize by the number of frames
contingency = contingency / len(y_ref)
# Compute the marginals
p_est = contingency.sum(axis=0)
p_ref = contingency.sum(axis=1)
# H(true | prediction) = sum_j P[estimated = j] *
# sum_i P[true = i | estimated = j] log P[true = i | estimated = j]
# entropy sums over axis=0, which is true labels
true_given_est = p_est.dot(scipy.stats.entropy(contingency, base=2))
pred_given_ref = p_ref.dot(scipy.stats.entropy(contingency.T, base=2))
if marginal:
# Normalize conditional entropy by marginal entropy
z_ref = scipy.stats.entropy(p_ref, base=2)
z_est = scipy.stats.entropy(p_est, base=2)
else:
z_ref = np.log2(contingency.shape[0])
z_est = np.log2(contingency.shape[1])
score_under = 0.0
if z_ref > 0:
score_under = 1. - true_given_est / z_ref
score_over = 0.0
if z_est > 0:
score_over = 1. - pred_given_ref / z_est
f_measure = util.f_measure(score_over, score_under, beta=beta)
return score_over, score_under, f_measure
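# --- Illustrative sketch (not part of the original module): demonstrates
# the conditional-entropy convention used in nce() above. For a 2-D array,
# scipy.stats.entropy normalizes and sums over axis=0, so each column of
# the contingency table yields H(reference | estimate = j), and the dot
# product with the estimate marginal gives H(reference | estimate). The
# helper name `_demo_conditional_entropy` is hypothetical.
def _demo_conditional_entropy():
    import numpy as np
    import scipy.stats
    # contingency[i, j] = P(reference = i, estimate = j)
    contingency = np.array([[0.4, 0.1],
                            [0.1, 0.4]])
    p_est = contingency.sum(axis=0)
    # column-wise entropies: H(reference | estimate = j) for each j
    h_ref_given_j = scipy.stats.entropy(contingency, base=2)
    return p_est.dot(h_ref_given_j)  # -> ~0.722 bits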
def vmeasure(reference_intervals, reference_labels, estimated_intervals,
estimated_labels, frame_size=0.1, beta=1.0):
"""Frame-clustering segmentation: v-measure
    Computes the conditional entropy of cluster assignments, normalized by
    the marginal entropy.
This is equivalent to `nce(..., marginal=True)`.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> # Trim or pad the estimate to match reference timing
>>> (ref_intervals,
... ref_labels) = mir_eval.util.adjust_intervals(ref_intervals,
... ref_labels,
... t_min=0)
>>> (est_intervals,
... est_labels) = mir_eval.util.adjust_intervals(
... est_intervals, est_labels, t_min=0, t_max=ref_intervals.max())
    >>> V_precision, V_recall, V_F = mir_eval.segment.vmeasure(ref_intervals,
... ref_labels,
... est_intervals,
... est_labels)
Parameters
----------
reference_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
reference_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
estimated_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
frame_size : float > 0
length (in seconds) of frames for clustering
(Default value = 0.1)
beta : float > 0
beta for F-measure
(Default value = 1.0)
Returns
-------
V_precision
Over-clustering score:
``1 - H(y_est | y_ref) / H(y_est)``
If `|y_est|==1`, then `V_precision` will be 0.
V_recall
Under-clustering score:
``1 - H(y_ref | y_est) / H(y_ref)``
If `|y_ref|==1`, then `V_recall` will be 0.
V_F
F-measure for (V_precision, V_recall)
"""
return nce(reference_intervals, reference_labels,
estimated_intervals, estimated_labels,
frame_size=frame_size, beta=beta,
marginal=True)
def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Examples
--------
>>> (ref_intervals,
... ref_labels) = mir_eval.io.load_labeled_intervals('ref.lab')
>>> (est_intervals,
... est_labels) = mir_eval.io.load_labeled_intervals('est.lab')
>>> scores = mir_eval.segment.evaluate(ref_intervals, ref_labels,
... est_intervals, est_labels)
Parameters
----------
ref_intervals : np.ndarray, shape=(n, 2)
reference segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
ref_labels : list, shape=(n,)
reference segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
est_intervals : np.ndarray, shape=(m, 2)
estimated segment intervals, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
est_labels : list, shape=(m,)
estimated segment labels, in the format returned by
:func:`mir_eval.io.load_labeled_intervals`.
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Adjust timespan of estimations relative to ground truth
ref_intervals, ref_labels = \
util.adjust_intervals(ref_intervals, labels=ref_labels, t_min=0.0)
est_intervals, est_labels = \
util.adjust_intervals(est_intervals, labels=est_labels, t_min=0.0,
t_max=ref_intervals.max())
# Now compute all the metrics
scores = collections.OrderedDict()
# Boundary detection
# Force these values for window
kwargs['window'] = .5
scores['Precision@0.5'], scores['Recall@0.5'], scores['F-measure@0.5'] = \
util.filter_kwargs(detection, ref_intervals, est_intervals, **kwargs)
kwargs['window'] = 3.0
scores['Precision@3.0'], scores['Recall@3.0'], scores['F-measure@3.0'] = \
util.filter_kwargs(detection, ref_intervals, est_intervals, **kwargs)
# Boundary deviation
scores['Ref-to-est deviation'], scores['Est-to-ref deviation'] = \
util.filter_kwargs(deviation, ref_intervals, est_intervals, **kwargs)
# Pairwise clustering
(scores['Pairwise Precision'],
scores['Pairwise Recall'],
scores['Pairwise F-measure']) = util.filter_kwargs(pairwise,
ref_intervals,
ref_labels,
est_intervals,
est_labels, **kwargs)
# Rand index
scores['Rand Index'] = util.filter_kwargs(rand_index, ref_intervals,
ref_labels, est_intervals,
est_labels, **kwargs)
# Adjusted rand index
scores['Adjusted Rand Index'] = util.filter_kwargs(ari, ref_intervals,
ref_labels,
est_intervals,
est_labels, **kwargs)
# Mutual information metrics
(scores['Mutual Information'],
scores['Adjusted Mutual Information'],
scores['Normalized Mutual Information']) = \
util.filter_kwargs(mutual_information, ref_intervals, ref_labels,
est_intervals, est_labels, **kwargs)
# Conditional entropy metrics
scores['NCE Over'], scores['NCE Under'], scores['NCE F-measure'] = \
util.filter_kwargs(nce, ref_intervals, ref_labels, est_intervals,
est_labels, **kwargs)
# V-measure metrics
scores['V Precision'], scores['V Recall'], scores['V-measure'] = \
util.filter_kwargs(vmeasure, ref_intervals, ref_labels, est_intervals,
est_labels, **kwargs)
return scores
mir_eval-0.7/mir_eval/separation.py 0000664 0000000 0000000 00000114132 14203260312 0017466 0 ustar 00root root 0000000 0000000 # -*- coding: utf-8 -*-
'''
Source separation algorithms attempt to extract recordings of individual
sources from a recording of a mixture of sources. Evaluation methods for
source separation compare the extracted sources against the reference sources
and attempt to measure the perceptual quality of the separation.
See also the bss_eval MATLAB toolbox:
http://bass-db.gforge.inria.fr/bss_eval/
Conventions
-----------
An audio signal is expected to be in the format of a 1-dimensional array where
the entries are the samples of the audio signal. When providing a group of
estimated or reference sources, they should be provided in a 2-dimensional
array, where the first dimension corresponds to the source number and the
second corresponds to the samples.
Metrics
-------
* :func:`mir_eval.separation.bss_eval_sources`: Computes the bss_eval_sources
metrics from bss_eval, which optionally optimally match the estimated sources
to the reference sources and measure the distortion and artifacts present in
the estimated sources as well as the interference between them.
* :func:`mir_eval.separation.bss_eval_sources_framewise`: Computes the
bss_eval_sources metrics on a frame-by-frame basis.
* :func:`mir_eval.separation.bss_eval_images`: Computes the bss_eval_images
metrics from bss_eval, which includes the metrics in
:func:`mir_eval.separation.bss_eval_sources` plus the image to spatial
distortion ratio.
* :func:`mir_eval.separation.bss_eval_images_framewise`: Computes the
bss_eval_images metrics on a frame-by-frame basis.
References
----------
.. [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric
Févotte, "Performance measurement in blind audio source separation," IEEE
Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006.
'''
import numpy as np
import scipy.fftpack
from scipy.linalg import toeplitz
from scipy.signal import fftconvolve
import collections
import itertools
import warnings
from . import util
# The maximum allowable number of sources (prevents insane computational load)
MAX_SOURCES = 100
def validate(reference_sources, estimated_sources):
"""Checks that the input data to a metric are valid, and throws helpful
errors if not.
Parameters
----------
reference_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing true sources
estimated_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing estimated sources
"""
if reference_sources.shape != estimated_sources.shape:
raise ValueError('The shape of estimated sources and the true '
'sources should match. reference_sources.shape '
'= {}, estimated_sources.shape '
'= {}'.format(reference_sources.shape,
estimated_sources.shape))
if reference_sources.ndim > 3 or estimated_sources.ndim > 3:
        raise ValueError('The number of dimensions is too high (must be '
                         'at most 3). reference_sources.ndim = {}, '
'estimated_sources.ndim '
'= {}'.format(reference_sources.ndim,
estimated_sources.ndim))
if reference_sources.size == 0:
warnings.warn("reference_sources is empty, should be of size "
"(nsrc, nsample). sdr, sir, sar, and perm will all "
"be empty np.ndarrays")
elif _any_source_silent(reference_sources):
raise ValueError('All the reference sources should be non-silent (not '
'all-zeros), but at least one of the reference '
'sources is all 0s, which introduces ambiguity to the'
' evaluation. (Otherwise we can add infinitely many '
'all-zero sources.)')
if estimated_sources.size == 0:
warnings.warn("estimated_sources is empty, should be of size "
"(nsrc, nsample). sdr, sir, sar, and perm will all "
"be empty np.ndarrays")
elif _any_source_silent(estimated_sources):
raise ValueError('All the estimated sources should be non-silent (not '
'all-zeros), but at least one of the estimated '
'sources is all 0s. Since we require each reference '
'source to be non-silent, having a silent estimated '
'source will result in an underdetermined system.')
if (estimated_sources.shape[0] > MAX_SOURCES or
reference_sources.shape[0] > MAX_SOURCES):
raise ValueError('The supplied matrices should be of shape (nsrc,'
' nsampl) but reference_sources.shape[0] = {} and '
'estimated_sources.shape[0] = {} which is greater '
'than mir_eval.separation.MAX_SOURCES = {}. To '
'override this check, set '
'mir_eval.separation.MAX_SOURCES to a '
'larger value.'.format(reference_sources.shape[0],
estimated_sources.shape[0],
MAX_SOURCES))
def _any_source_silent(sources):
"""Returns true if the parameter sources has any silent first dimensions"""
return np.any(np.all(np.sum(
sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1))
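# --- Illustrative sketch (not part of the original module): for 2-D input
# the channel axes reduce over an empty tuple (a no-op), so the check above
# simply flags any row that is identically zero. The helper name
# `_demo_silent_check` is hypothetical, added purely for illustration.
def _demo_silent_check():
    import numpy as np
    sources = np.array([[1., 0., -1.],   # active source
                        [0., 0., 0.]])   # all-zero (silent) source
    collapsed = np.sum(sources, axis=tuple(range(2, sources.ndim)))
    return np.any(np.all(collapsed == 0, axis=1))  # -> True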
def bss_eval_sources(reference_sources, estimated_sources,
compute_permutation=True):
"""
Ordering and measurement of the separation quality for estimated source
signals in terms of filtered true source, interference and artifacts.
The decomposition allows a time-invariant filter distortion of length
512, as described in Section III.B of [#vincent2006performance]_.
Passing ``False`` for ``compute_permutation`` will improve the computation
performance of the evaluation; however, it is not always appropriate and
is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_sources.
Examples
--------
>>> # reference_sources[n] should be an ndarray of samples of the
>>> # n'th reference source
>>> # estimated_sources[n] should be the same for the n'th estimated
>>> # source
>>> (sdr, sir, sar,
... perm) = mir_eval.separation.bss_eval_sources(reference_sources,
... estimated_sources)
Parameters
----------
reference_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing true sources (must have same shape as
estimated_sources)
estimated_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing estimated sources (must have same shape as
reference_sources)
compute_permutation : bool, optional
compute permutation of estimate/source combinations (True by default)
Returns
-------
sdr : np.ndarray, shape=(nsrc,)
vector of Signal to Distortion Ratios (SDR)
sir : np.ndarray, shape=(nsrc,)
vector of Source to Interference Ratios (SIR)
sar : np.ndarray, shape=(nsrc,)
vector of Sources to Artifacts Ratios (SAR)
perm : np.ndarray, shape=(nsrc,)
vector containing the best ordering of estimated sources in
the mean SIR sense (estimated source number ``perm[j]`` corresponds to
true source number ``j``). Note: ``perm`` will be ``[0, 1, ...,
nsrc-1]`` if ``compute_permutation`` is ``False``.
References
----------
.. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau
Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik
Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign
(2007-2010): Achievements and remaining challenges", Signal Processing,
92, pp. 1928-1936, 2012.
"""
# make sure the input is of shape (nsrc, nsampl)
if estimated_sources.ndim == 1:
estimated_sources = estimated_sources[np.newaxis, :]
if reference_sources.ndim == 1:
reference_sources = reference_sources[np.newaxis, :]
validate(reference_sources, estimated_sources)
# If empty matrices were supplied, return empty lists (special case)
if reference_sources.size == 0 or estimated_sources.size == 0:
return np.array([]), np.array([]), np.array([]), np.array([])
nsrc = estimated_sources.shape[0]
# does user desire permutations?
if compute_permutation:
# compute criteria for all possible pair matches
sdr = np.empty((nsrc, nsrc))
sir = np.empty((nsrc, nsrc))
sar = np.empty((nsrc, nsrc))
for jest in range(nsrc):
for jtrue in range(nsrc):
s_true, e_spat, e_interf, e_artif = \
_bss_decomp_mtifilt(reference_sources,
estimated_sources[jest],
jtrue, 512)
sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \
_bss_source_crit(s_true, e_spat, e_interf, e_artif)
# select the best ordering
perms = list(itertools.permutations(list(range(nsrc))))
mean_sir = np.empty(len(perms))
dum = np.arange(nsrc)
for (i, perm) in enumerate(perms):
mean_sir[i] = np.mean(sir[perm, dum])
popt = perms[np.argmax(mean_sir)]
idx = (popt, dum)
return (sdr[idx], sir[idx], sar[idx], np.asarray(popt))
else:
# compute criteria for only the simple correspondence
# (estimate 1 is estimate corresponding to reference source 1, etc.)
sdr = np.empty(nsrc)
sir = np.empty(nsrc)
sar = np.empty(nsrc)
for j in range(nsrc):
s_true, e_spat, e_interf, e_artif = \
_bss_decomp_mtifilt(reference_sources,
estimated_sources[j],
j, 512)
sdr[j], sir[j], sar[j] = \
_bss_source_crit(s_true, e_spat, e_interf, e_artif)
# return the default permutation for compatibility
popt = np.arange(nsrc)
return (sdr, sir, sar, popt)
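# --- Illustrative sketch (not part of the original module): the permutation
# search above picks the estimate-to-reference assignment that maximizes
# mean SIR. A toy version with a 2x2 SIR matrix; the helper name
# `_demo_best_permutation` is hypothetical, added purely for illustration.
def _demo_best_permutation():
    import itertools
    import numpy as np
    # sir[jest, jtrue]: SIR of estimate jest measured against reference jtrue
    sir = np.array([[2., 9.],
                    [8., 1.]])
    dum = np.arange(2)
    perms = list(itertools.permutations(range(2)))
    mean_sir = [np.mean(sir[np.asarray(perm), dum]) for perm in perms]
    return perms[int(np.argmax(mean_sir))]  # -> (1, 0)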
def bss_eval_sources_framewise(reference_sources, estimated_sources,
window=30*44100, hop=15*44100,
compute_permutation=False):
"""Framewise computation of bss_eval_sources
Please be aware that this function does not compute permutations (by
default) on the possible relations between reference_sources and
estimated_sources due to the dangers of a changing permutation. Therefore
(by default), it assumes that ``reference_sources[i]`` corresponds to
``estimated_sources[i]``. To enable computing permutations please set
``compute_permutation`` to be ``True`` and check that the returned ``perm``
is identical for all windows.
NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated
using only a single window or are shorter than the window length, the
result of :func:`mir_eval.separation.bss_eval_sources` called on
``reference_sources`` and ``estimated_sources`` (with the
``compute_permutation`` parameter passed to
:func:`mir_eval.separation.bss_eval_sources`) is returned.
Examples
--------
>>> # reference_sources[n] should be an ndarray of samples of the
>>> # n'th reference source
>>> # estimated_sources[n] should be the same for the n'th estimated
>>> # source
>>> (sdr, sir, sar,
... perm) = mir_eval.separation.bss_eval_sources_framewise(
    ... reference_sources,
... estimated_sources)
Parameters
----------
reference_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing true sources (must have the same shape as
``estimated_sources``)
estimated_sources : np.ndarray, shape=(nsrc, nsampl)
matrix containing estimated sources (must have the same shape as
``reference_sources``)
window : int, optional
Window length for framewise evaluation (default value is 30s at a
sample rate of 44.1kHz)
hop : int, optional
Hop size for framewise evaluation (default value is 15s at a
sample rate of 44.1kHz)
compute_permutation : bool, optional
compute permutation of estimate/source combinations for all windows
(False by default)
Returns
-------
sdr : np.ndarray, shape=(nsrc, nframes)
vector of Signal to Distortion Ratios (SDR)
sir : np.ndarray, shape=(nsrc, nframes)
vector of Source to Interference Ratios (SIR)
sar : np.ndarray, shape=(nsrc, nframes)
vector of Sources to Artifacts Ratios (SAR)
perm : np.ndarray, shape=(nsrc, nframes)
vector containing the best ordering of estimated sources in
the mean SIR sense (estimated source number ``perm[j]`` corresponds to
true source number ``j``). Note: ``perm`` will be ``range(nsrc)`` for
all windows if ``compute_permutation`` is ``False``
"""
# make sure the input is of shape (nsrc, nsampl)
if estimated_sources.ndim == 1:
estimated_sources = estimated_sources[np.newaxis, :]
if reference_sources.ndim == 1:
reference_sources = reference_sources[np.newaxis, :]
validate(reference_sources, estimated_sources)
# If empty matrices were supplied, return empty lists (special case)
if reference_sources.size == 0 or estimated_sources.size == 0:
return np.array([]), np.array([]), np.array([]), np.array([])
nsrc = reference_sources.shape[0]
nwin = int(
np.floor((reference_sources.shape[1] - window + hop) / hop)
)
# if fewer than 2 windows would be evaluated, return the sources result
if nwin < 2:
result = bss_eval_sources(reference_sources,
estimated_sources,
compute_permutation)
return [np.expand_dims(score, -1) for score in result]
# compute the criteria across all windows
sdr = np.empty((nsrc, nwin))
sir = np.empty((nsrc, nwin))
sar = np.empty((nsrc, nwin))
perm = np.empty((nsrc, nwin))
# k iterates across all the windows
for k in range(nwin):
win_slice = slice(k * hop, k * hop + window)
ref_slice = reference_sources[:, win_slice]
est_slice = estimated_sources[:, win_slice]
# check for a silent frame
if (not _any_source_silent(ref_slice) and
not _any_source_silent(est_slice)):
sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources(
ref_slice, est_slice, compute_permutation
)
else:
# if we have a silent frame set results as np.nan
sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan
return sdr, sir, sar, perm
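# --- Illustrative sketch (not part of the original module): the framewise
# loop above evaluates nwin = floor((nsampl - window + hop) / hop) complete
# windows, each starting at k * hop. The helper name `_demo_window_count`
# is hypothetical, added purely for illustration.
def _demo_window_count(nsampl=100, window=40, hop=20):
    import numpy as np
    nwin = int(np.floor((nsampl - window + hop) / hop))
    slices = [(k * hop, k * hop + window) for k in range(nwin)]
    return nwin, slices  # -> (4, [(0, 40), (20, 60), (40, 80), (60, 100)])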
def bss_eval_images(reference_sources, estimated_sources,
compute_permutation=True):
"""Implementation of the bss_eval_images function from the
BSS_EVAL Matlab toolbox.
Ordering and measurement of the separation quality for estimated source
signals in terms of filtered true source, interference and artifacts.
This method also provides the ISR measure.
The decomposition allows a time-invariant filter distortion of length
512, as described in Section III.B of [#vincent2006performance]_.
Passing ``False`` for ``compute_permutation`` will improve the computation
performance of the evaluation; however, it is not always appropriate and
is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_images.
Examples
--------
>>> # reference_sources[n] should be an ndarray of samples of the
>>> # n'th reference source
>>> # estimated_sources[n] should be the same for the n'th estimated
>>> # source
>>> (sdr, isr, sir, sar,
... perm) = mir_eval.separation.bss_eval_images(reference_sources,
... estimated_sources)
Parameters
----------
reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
matrix containing true sources
estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
matrix containing estimated sources
compute_permutation : bool, optional
compute permutation of estimate/source combinations (True by default)
Returns
-------
sdr : np.ndarray, shape=(nsrc,)
vector of Signal to Distortion Ratios (SDR)
isr : np.ndarray, shape=(nsrc,)
vector of source Image to Spatial distortion Ratios (ISR)
sir : np.ndarray, shape=(nsrc,)
vector of Source to Interference Ratios (SIR)
sar : np.ndarray, shape=(nsrc,)
vector of Sources to Artifacts Ratios (SAR)
perm : np.ndarray, shape=(nsrc,)
vector containing the best ordering of estimated sources in
the mean SIR sense (estimated source number ``perm[j]`` corresponds to
        true source number ``j``). Note: ``perm`` will be ``[0, 1, ...,
        nsrc-1]`` if ``compute_permutation`` is ``False``.
References
----------
.. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau
Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik
Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign
(2007-2010): Achievements and remaining challenges", Signal Processing,
92, pp. 1928-1936, 2012.
"""
# make sure the input has 3 dimensions
# assuming input is in shape (nsampl) or (nsrc, nsampl)
estimated_sources = np.atleast_3d(estimated_sources)
reference_sources = np.atleast_3d(reference_sources)
# we will ensure input doesn't have more than 3 dimensions in validate
validate(reference_sources, estimated_sources)
# If empty matrices were supplied, return empty lists (special case)
if reference_sources.size == 0 or estimated_sources.size == 0:
return np.array([]), np.array([]), np.array([]), \
np.array([]), np.array([])
# determine size parameters
nsrc = estimated_sources.shape[0]
nsampl = estimated_sources.shape[1]
nchan = estimated_sources.shape[2]
# does the user desire permutation?
if compute_permutation:
# compute criteria for all possible pair matches
sdr = np.empty((nsrc, nsrc))
isr = np.empty((nsrc, nsrc))
sir = np.empty((nsrc, nsrc))
sar = np.empty((nsrc, nsrc))
for jest in range(nsrc):
for jtrue in range(nsrc):
s_true, e_spat, e_interf, e_artif = \
_bss_decomp_mtifilt_images(
reference_sources,
np.reshape(
estimated_sources[jest],
(nsampl, nchan),
order='F'
),
jtrue,
512
)
sdr[jest, jtrue], isr[jest, jtrue], \
sir[jest, jtrue], sar[jest, jtrue] = \
_bss_image_crit(s_true, e_spat, e_interf, e_artif)
# select the best ordering
perms = list(itertools.permutations(range(nsrc)))
mean_sir = np.empty(len(perms))
dum = np.arange(nsrc)
for (i, perm) in enumerate(perms):
mean_sir[i] = np.mean(sir[perm, dum])
popt = perms[np.argmax(mean_sir)]
idx = (popt, dum)
return (sdr[idx], isr[idx], sir[idx], sar[idx], np.asarray(popt))
else:
# compute criteria for only the simple correspondence
# (estimate 1 is estimate corresponding to reference source 1, etc.)
sdr = np.empty(nsrc)
isr = np.empty(nsrc)
sir = np.empty(nsrc)
sar = np.empty(nsrc)
        Gj = [0] * nsrc  # prepare G matrices with zeros
G = np.zeros(1)
for j in range(nsrc):
# save G matrix to avoid recomputing it every call
s_true, e_spat, e_interf, e_artif, Gj_temp, G = \
_bss_decomp_mtifilt_images(reference_sources,
np.reshape(estimated_sources[j],
(nsampl, nchan),
order='F'),
j, 512, Gj[j], G)
Gj[j] = Gj_temp
sdr[j], isr[j], sir[j], sar[j] = \
_bss_image_crit(s_true, e_spat, e_interf, e_artif)
# return the default permutation for compatibility
popt = np.arange(nsrc)
return (sdr, isr, sir, sar, popt)
def bss_eval_images_framewise(reference_sources, estimated_sources,
window=30*44100, hop=15*44100,
compute_permutation=False):
"""Framewise computation of bss_eval_images
Please be aware that this function does not compute permutations (by
default) on the possible relations between ``reference_sources`` and
``estimated_sources`` due to the dangers of a changing permutation.
Therefore (by default), it assumes that ``reference_sources[i]``
corresponds to ``estimated_sources[i]``. To enable computing permutations
please set ``compute_permutation`` to be ``True`` and check that the
returned ``perm`` is identical for all windows.
NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated
using only a single window or are shorter than the window length, the
result of ``bss_eval_images`` called on ``reference_sources`` and
``estimated_sources`` (with the ``compute_permutation`` parameter passed to
``bss_eval_images``) is returned
Examples
--------
>>> # reference_sources[n] should be an ndarray of samples of the
>>> # n'th reference source
>>> # estimated_sources[n] should be the same for the n'th estimated
>>> # source
>>> (sdr, isr, sir, sar,
... perm) = mir_eval.separation.bss_eval_images_framewise(
    ... reference_sources,
    ... estimated_sources,
    ... window,
    ... hop)
Parameters
----------
reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
matrix containing true sources (must have the same shape as
``estimated_sources``)
estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
matrix containing estimated sources (must have the same shape as
``reference_sources``)
window : int
Window length for framewise evaluation
hop : int
Hop size for framewise evaluation
compute_permutation : bool, optional
compute permutation of estimate/source combinations for all windows
(False by default)
Returns
-------
sdr : np.ndarray, shape=(nsrc, nframes)
vector of Signal to Distortion Ratios (SDR)
isr : np.ndarray, shape=(nsrc, nframes)
vector of source Image to Spatial distortion Ratios (ISR)
sir : np.ndarray, shape=(nsrc, nframes)
vector of Source to Interference Ratios (SIR)
sar : np.ndarray, shape=(nsrc, nframes)
vector of Sources to Artifacts Ratios (SAR)
perm : np.ndarray, shape=(nsrc, nframes)
vector containing the best ordering of estimated sources in
the mean SIR sense (estimated source number perm[j] corresponds to
true source number j)
Note: perm will be range(nsrc) for all windows if compute_permutation
is False
"""
# make sure the input has 3 dimensions
# assuming input is in shape (nsampl) or (nsrc, nsampl)
estimated_sources = np.atleast_3d(estimated_sources)
reference_sources = np.atleast_3d(reference_sources)
# we will ensure input doesn't have more than 3 dimensions in validate
validate(reference_sources, estimated_sources)
# If empty matrices were supplied, return empty lists (special case)
if reference_sources.size == 0 or estimated_sources.size == 0:
        return np.array([]), np.array([]), np.array([]), \
            np.array([]), np.array([])
nsrc = reference_sources.shape[0]
nwin = int(
np.floor((reference_sources.shape[1] - window + hop) / hop)
)
# if fewer than 2 windows would be evaluated, return the images result
if nwin < 2:
result = bss_eval_images(reference_sources,
estimated_sources,
compute_permutation)
return [np.expand_dims(score, -1) for score in result]
# compute the criteria across all windows
sdr = np.empty((nsrc, nwin))
isr = np.empty((nsrc, nwin))
sir = np.empty((nsrc, nwin))
sar = np.empty((nsrc, nwin))
perm = np.empty((nsrc, nwin))
# k iterates across all the windows
for k in range(nwin):
win_slice = slice(k * hop, k * hop + window)
ref_slice = reference_sources[:, win_slice, :]
est_slice = estimated_sources[:, win_slice, :]
# check for a silent frame
if (not _any_source_silent(ref_slice) and
not _any_source_silent(est_slice)):
sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = \
bss_eval_images(
ref_slice, est_slice, compute_permutation
)
else:
# if we have a silent frame set results as np.nan
            sdr[:, k] = isr[:, k] = sir[:, k] = np.nan
            sar[:, k] = perm[:, k] = np.nan
return sdr, isr, sir, sar, perm
def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen):
"""Decomposition of an estimated source image into four components
representing respectively the true source image, spatial (or filtering)
distortion, interference and artifacts, derived from the true source
images using multichannel time-invariant filters.
"""
nsampl = estimated_source.size
# decomposition
# true source image
s_true = np.hstack((reference_sources[j], np.zeros(flen - 1)))
# spatial (or filtering) distortion
e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source,
flen) - s_true
# interference
e_interf = _project(reference_sources,
estimated_source, flen) - s_true - e_spat
# artifacts
e_artif = -s_true - e_spat - e_interf
e_artif[:nsampl] += estimated_source
return (s_true, e_spat, e_interf, e_artif)
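# --- Illustrative note (not part of the original module): by construction
# the four components above telescope, so
#     s_true + e_spat + e_interf + e_artif
# reproduces the zero-padded estimated source exactly; the energy ratios in
# _bss_source_crit below are computed from these components.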
def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen,
Gj=None, G=None):
"""Decomposition of an estimated source image into four components
representing respectively the true source image, spatial (or filtering)
distortion, interference and artifacts, derived from the true source
images using multichannel time-invariant filters.
Adapted version to work with multichannel sources.
Improved performance can be gained by passing Gj and G parameters initially
as all zeros. These parameters store the results from the computation of
the G matrix in _project_images and then return them for subsequent calls
    to this function. This only works when not computing permutations.
"""
nsampl = np.shape(estimated_source)[0]
nchan = np.shape(estimated_source)[1]
# are we saving the Gj and G parameters?
saveg = Gj is not None and G is not None
# decomposition
# true source image
s_true = np.hstack((np.reshape(reference_sources[j],
(nsampl, nchan),
order="F").transpose(),
np.zeros((nchan, flen - 1))))
# spatial (or filtering) distortion
if saveg:
e_spat, Gj = _project_images(reference_sources[j, np.newaxis, :],
estimated_source, flen, Gj)
else:
e_spat = _project_images(reference_sources[j, np.newaxis, :],
estimated_source, flen)
e_spat = e_spat - s_true
# interference
if saveg:
e_interf, G = _project_images(reference_sources,
estimated_source, flen, G)
else:
e_interf = _project_images(reference_sources,
estimated_source, flen)
e_interf = e_interf - s_true - e_spat
# artifacts
e_artif = -s_true - e_spat - e_interf
e_artif[:, :nsampl] += estimated_source.transpose()
# return Gj and G only if they were passed in
if saveg:
return (s_true, e_spat, e_interf, e_artif, Gj, G)
else:
return (s_true, e_spat, e_interf, e_artif)
def _project(reference_sources, estimated_source, flen):
"""Least-squares projection of estimated source on the subspace spanned by
delayed versions of reference sources, with delays between 0 and flen-1
"""
nsrc = reference_sources.shape[0]
nsampl = reference_sources.shape[1]
# computing coefficients of least squares problem via FFT ##
# zero padding and FFT of input data
reference_sources = np.hstack((reference_sources,
np.zeros((nsrc, flen - 1))))
estimated_source = np.hstack((estimated_source, np.zeros(flen - 1)))
n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.)))
sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
sef = scipy.fftpack.fft(estimated_source, n=n_fft)
# inner products between delayed versions of reference_sources
G = np.zeros((nsrc * flen, nsrc * flen))
for i in range(nsrc):
for j in range(nsrc):
ssf = sf[i] * np.conj(sf[j])
ssf = np.real(scipy.fftpack.ifft(ssf))
ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
r=ssf[:flen])
G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss
G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T
# inner products between estimated_source and delayed versions of
# reference_sources
D = np.zeros(nsrc * flen)
for i in range(nsrc):
ssef = sf[i] * np.conj(sef)
ssef = np.real(scipy.fftpack.ifft(ssef))
D[i * flen: (i+1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1]))
# Computing projection
# Distortion filters
try:
C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F')
except np.linalg.linalg.LinAlgError:
C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F')
# Filtering
sproj = np.zeros(nsampl + flen - 1)
for i in range(nsrc):
sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1]
return sproj
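# --- Illustrative sketch (not part of the original module): verifies on a
# toy signal that the FFT/Toeplitz construction in _project reproduces the
# Gram matrix of delayed copies of a reference. The helper name
# `_demo_projection_gram` is hypothetical, added purely for illustration.
def _demo_projection_gram(flen=3):
    import numpy as np
    import scipy.fftpack
    from scipy.linalg import toeplitz
    s = np.array([1.0, 2.0, -1.0, 0.5])
    nsampl = s.size
    padded = np.hstack((s, np.zeros(flen - 1)))
    n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.)))
    sf = scipy.fftpack.fft(padded, n=n_fft)
    # circular autocorrelation; the zero padding makes it linear for lags
    # up to flen - 1
    ssf = np.real(scipy.fftpack.ifft(sf * np.conj(sf)))
    G_fft = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), r=ssf[:flen])
    # direct computation: inner products of delayed copies of s
    delayed = np.array([np.hstack((np.zeros(d), s, np.zeros(flen - 1 - d)))
                        for d in range(flen)])
    return np.allclose(G_fft, delayed.dot(delayed.T))  # -> True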
def _project_images(reference_sources, estimated_source, flen, G=None):
"""Least-squares projection of estimated source on the subspace spanned by
delayed versions of reference sources, with delays between 0 and flen-1.
Passing G as all zeros will populate the G matrix and return it so it can
be passed into the next call to avoid recomputing G (this will only works
if not computing permutations).
"""
nsrc = reference_sources.shape[0]
nsampl = reference_sources.shape[1]
nchan = reference_sources.shape[2]
reference_sources = np.reshape(np.transpose(reference_sources, (2, 0, 1)),
(nchan*nsrc, nsampl), order='F')
# computing coefficients of least squares problem via FFT ##
# zero padding and FFT of input data
reference_sources = np.hstack((reference_sources,
np.zeros((nchan*nsrc, flen - 1))))
estimated_source = \
np.hstack((estimated_source.transpose(), np.zeros((nchan, flen - 1))))
n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.)))
sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
sef = scipy.fftpack.fft(estimated_source, n=n_fft)
# inner products between delayed versions of reference_sources
if G is None:
saveg = False
G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen))
for i in range(nchan * nsrc):
for j in range(i+1):
ssf = sf[i] * np.conj(sf[j])
ssf = np.real(scipy.fftpack.ifft(ssf))
ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
r=ssf[:flen])
G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss
G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T
else: # avoid recomputing G (only works if no permutation is desired)
saveg = True # return G
if np.all(G == 0): # only compute G if passed as 0
G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen))
for i in range(nchan * nsrc):
for j in range(i+1):
ssf = sf[i] * np.conj(sf[j])
ssf = np.real(scipy.fftpack.ifft(ssf))
ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
r=ssf[:flen])
G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss
G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T
# inner products between estimated_source and delayed versions of
# reference_sources
D = np.zeros((nchan * nsrc * flen, nchan))
for k in range(nchan * nsrc):
for i in range(nchan):
ssef = sf[k] * np.conj(sef[i])
ssef = np.real(scipy.fftpack.ifft(ssef))
D[k * flen: (k+1) * flen, i] = \
np.hstack((ssef[0], ssef[-1:-flen:-1])).transpose()
# Computing projection
# Distortion filters
try:
C = np.linalg.solve(G, D).reshape(flen, nchan*nsrc, nchan, order='F')
except np.linalg.linalg.LinAlgError:
C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan*nsrc, nchan,
order='F')
# Filtering
sproj = np.zeros((nchan, nsampl + flen - 1))
for k in range(nchan * nsrc):
for i in range(nchan):
sproj[i] += fftconvolve(C[:, k, i].transpose(),
reference_sources[k])[:nsampl + flen - 1]
# return G only if it was passed in
if saveg:
return sproj, G
else:
return sproj
def _bss_source_crit(s_true, e_spat, e_interf, e_artif):
"""Measurement of the separation quality for a given source in terms of
filtered true source, interference and artifacts.
"""
# energy ratios
s_filt = s_true + e_spat
sdr = _safe_db(np.sum(s_filt**2), np.sum((e_interf + e_artif)**2))
sir = _safe_db(np.sum(s_filt**2), np.sum(e_interf**2))
sar = _safe_db(np.sum((s_filt + e_interf)**2), np.sum(e_artif**2))
return (sdr, sir, sar)
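# --- Illustrative sketch (not part of the original module): with toy
# component signals, the criteria above reduce to plain energy-ratio dB
# values. The helper name `_demo_source_crit` is hypothetical, added purely
# for illustration (it skips the zero-denominator guard of _safe_db below).
def _demo_source_crit():
    import numpy as np
    s_filt = np.array([1.0, -1.0])    # filtered true source
    e_interf = np.array([0.1, 0.0])   # interference
    e_artif = np.array([0.0, 0.2])    # artifacts
    sdr = 10 * np.log10(np.sum(s_filt**2) /
                        np.sum((e_interf + e_artif)**2))
    sir = 10 * np.log10(np.sum(s_filt**2) / np.sum(e_interf**2))
    sar = 10 * np.log10(np.sum((s_filt + e_interf)**2) /
                        np.sum(e_artif**2))
    return sdr, sir, sar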
def _bss_image_crit(s_true, e_spat, e_interf, e_artif):
"""Measurement of the separation quality for a given image in terms of
filtered true source, spatial error, interference and artifacts.
"""
# energy ratios
sdr = _safe_db(np.sum(s_true**2), np.sum((e_spat+e_interf+e_artif)**2))
isr = _safe_db(np.sum(s_true**2), np.sum(e_spat**2))
sir = _safe_db(np.sum((s_true+e_spat)**2), np.sum(e_interf**2))
sar = _safe_db(np.sum((s_true+e_spat+e_interf)**2), np.sum(e_artif**2))
return (sdr, isr, sir, sar)
def _safe_db(num, den):
"""Properly handle the potential +Inf db SIR, instead of raising a
RuntimeWarning. Only denominator is checked because the numerator can never
be 0.
"""
if den == 0:
return np.Inf
return 10 * np.log10(num / den)
def evaluate(reference_sources, estimated_sources, **kwargs):
"""Compute all metrics for the given reference and estimated signals.
NOTE: This will always compute :func:`mir_eval.separation.bss_eval_images`
for any valid input and will additionally compute
:func:`mir_eval.separation.bss_eval_sources` for valid input with fewer
than 3 dimensions.
Examples
--------
>>> # reference_sources[n] should be an ndarray of samples of the
>>> # n'th reference source
>>> # estimated_sources[n] should be the same for the n'th estimated source
>>> scores = mir_eval.separation.evaluate(reference_sources,
... estimated_sources)
Parameters
----------
reference_sources : np.ndarray, shape=(nsrc, nsampl[, nchan])
matrix containing true sources
estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan])
matrix containing estimated sources
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all the metrics
scores = collections.OrderedDict()
sdr, isr, sir, sar, perm = util.filter_kwargs(
bss_eval_images,
reference_sources,
estimated_sources,
**kwargs
)
scores['Images - Source to Distortion'] = sdr.tolist()
scores['Images - Image to Spatial'] = isr.tolist()
scores['Images - Source to Interference'] = sir.tolist()
scores['Images - Source to Artifact'] = sar.tolist()
scores['Images - Source permutation'] = perm.tolist()
sdr, isr, sir, sar, perm = util.filter_kwargs(
bss_eval_images_framewise,
reference_sources,
estimated_sources,
**kwargs
)
scores['Images Frames - Source to Distortion'] = sdr.tolist()
scores['Images Frames - Image to Spatial'] = isr.tolist()
scores['Images Frames - Source to Interference'] = sir.tolist()
scores['Images Frames - Source to Artifact'] = sar.tolist()
scores['Images Frames - Source permutation'] = perm.tolist()
# Verify we can compute sources on this input
if reference_sources.ndim < 3 and estimated_sources.ndim < 3:
sdr, sir, sar, perm = util.filter_kwargs(
bss_eval_sources_framewise,
reference_sources,
estimated_sources,
**kwargs
)
scores['Sources Frames - Source to Distortion'] = sdr.tolist()
scores['Sources Frames - Source to Interference'] = sir.tolist()
scores['Sources Frames - Source to Artifact'] = sar.tolist()
scores['Sources Frames - Source permutation'] = perm.tolist()
sdr, sir, sar, perm = util.filter_kwargs(
bss_eval_sources,
reference_sources,
estimated_sources,
**kwargs
)
scores['Sources - Source to Distortion'] = sdr.tolist()
scores['Sources - Source to Interference'] = sir.tolist()
scores['Sources - Source to Artifact'] = sar.tolist()
scores['Sources - Source permutation'] = perm.tolist()
return scores
mir_eval-0.7/mir_eval/sonify.py 0000664 0000000 0000000 00000025215 14203260312 0016633 0 ustar 00root root 0000000 0000000 '''
Methods which sonify annotations for "evaluation by ear".
All functions return a raw signal at the specified sampling rate.
'''
import numpy as np
from numpy.lib.stride_tricks import as_strided
from scipy.interpolate import interp1d
from . import util
from . import chord
def clicks(times, fs, click=None, length=None):
"""Returns a signal with the signal 'click' placed at each specified time
Parameters
----------
times : np.ndarray
times to place clicks, in seconds
fs : int
desired sampling rate of the output signal
click : np.ndarray
click signal, defaults to a 1 kHz blip
length : int
desired number of samples in the output signal,
defaults to ``times.max()*fs + click.shape[0] + 1``
Returns
-------
click_signal : np.ndarray
Synthesized click signal
"""
# Create default click signal
if click is None:
# 1 kHz tone, 100ms
click = np.sin(2*np.pi*np.arange(fs*.1)*1000/(1.*fs))
# Exponential decay
click *= np.exp(-np.arange(fs*.1)/(fs*.01))
# Set default length
if length is None:
length = int(times.max()*fs + click.shape[0] + 1)
# Pre-allocate click signal
click_signal = np.zeros(length)
# Place clicks
for time in times:
# Compute the boundaries of the click
start = int(time*fs)
end = start + click.shape[0]
# Make sure we don't try to output past the end of the signal
if start >= length:
break
if end >= length:
click_signal[start:] = click[:length - start]
break
# Normally, just add a click here
click_signal[start:end] = click
return click_signal
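# --- Illustrative sketch (not part of the original module): a minimal use
# of clicks() with the default 1 kHz, 100 ms click. The helper name
# `_demo_clicks` is hypothetical, added purely for illustration.
def _demo_clicks():
    import numpy as np
    signal = clicks(np.array([0.5, 1.0]), fs=8000)
    # default length = times.max()*fs + click length + 1 = 8000 + 800 + 1
    return signal.shape  # -> (8801,)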
def time_frequency(gram, frequencies, times, fs, function=np.sin, length=None,
n_dec=1):
"""Reverse synthesis of a time-frequency representation of a signal
Parameters
----------
gram : np.ndarray
``gram[n, m]`` is the magnitude of ``frequencies[n]``
from ``times[m]`` to ``times[m + 1]``
Non-positive magnitudes are interpreted as silence.
frequencies : np.ndarray
array of size ``gram.shape[0]`` denoting the frequency of
each row of gram
times : np.ndarray, shape= ``(gram.shape[1],)`` or ``(gram.shape[1], 2)``
Either the start time of each column in the gram,
or the time interval corresponding to each column.
fs : int
desired sampling rate of the output signal
function : function
function to use to synthesize notes, should be :math:`2\pi`-periodic
length : int
desired number of samples in the output signal,
defaults to ``times[-1]*fs``
n_dec : int
        the number of decimals used to approximate each sonified frequency.
Defaults to 1 decimal place. Higher precision will be slower.
Returns
-------
output : np.ndarray
synthesized version of the piano roll
"""
# Default value for length
if times.ndim == 1:
# Convert to intervals
times = util.boundaries_to_intervals(times)
if length is None:
length = int(times[-1, 1] * fs)
times, _ = util.adjust_intervals(times, t_max=length)
# Truncate times so that the shape matches gram
n_times = gram.shape[1]
times = times[:n_times]
def _fast_synthesize(frequency):
"""A faster way to synthesize a signal.
Generate one cycle, and simulate arbitrary repetitions
using array indexing tricks.
"""
# hack so that we can ensure an integer number of periods and samples
        # round frequency to n_dec decimals, s.t. 10**n_dec * frequency
        # will be an integer
        frequency = np.round(frequency, n_dec)
        # Generate 10**n_dec * frequency periods at this frequency
        # Equivalent to n_samples = int(n_periods * fs / frequency)
        # n_periods = 10**n_dec * frequency is the smallest integer that
        # guarantees that n_samples will be an integer, since
        # 10**n_dec * frequency is an integer after the rounding above
n_samples = int(10.0**n_dec * fs)
short_signal = function(2.0 * np.pi * np.arange(n_samples) *
frequency / fs)
# Calculate the number of loops we need to fill the duration
n_repeats = int(np.ceil(length/float(short_signal.shape[0])))
# Simulate tiling the short buffer by using stride tricks
long_signal = as_strided(short_signal,
shape=(n_repeats, len(short_signal)),
strides=(0, short_signal.itemsize))
# Use a flatiter to simulate a long 1D buffer
return long_signal.flat
def _const_interpolator(value):
"""Return a function that returns `value`
no matter the input.
"""
def __interpolator(x):
return value
return __interpolator
# Threshold the tfgram to remove non-positive values
gram = np.maximum(gram, 0)
# Pre-allocate output signal
output = np.zeros(length)
time_centers = np.mean(times, axis=1) * float(fs)
for n, frequency in enumerate(frequencies):
# Get a waveform of length samples at this frequency
wave = _fast_synthesize(frequency)
# Interpolate the values in gram over the time grid
if len(time_centers) > 1:
gram_interpolator = interp1d(
time_centers, gram[n, :],
kind='linear', bounds_error=False,
fill_value=(gram[n, 0], gram[n, -1]))
# If only one time point, create constant interpolator
else:
gram_interpolator = _const_interpolator(gram[n, 0])
# Scale each time interval by the piano roll magnitude
for m, (start, end) in enumerate((times * fs).astype(int)):
# Clip the timings to make sure the indices are valid
start, end = max(start, 0), min(end, length)
# add to waveform
output[start:end] += (
wave[start:end] * gram_interpolator(np.arange(start, end)))
# Normalize, but only if there's non-zero values
norm = np.abs(output).max()
if norm >= np.finfo(output.dtype).tiny:
output /= norm
return output
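# --- Illustrative sketch (not part of the original module): the stride
# trick used by _fast_synthesize above repeats one cycle buffer without
# copying memory, by giving the repeat axis a stride of zero. The helper
# name `_demo_stride_tiling` is hypothetical, added purely for illustration.
def _demo_stride_tiling():
    import numpy as np
    from numpy.lib.stride_tricks import as_strided
    short = np.arange(4.0)
    tiled = as_strided(short, shape=(3, short.size),
                       strides=(0, short.itemsize))
    return tiled.ravel()  # -> [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]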
def pitch_contour(times, frequencies, fs, amplitudes=None, function=np.sin,
length=None, kind='linear'):
'''Sonify a pitch contour.
Parameters
----------
times : np.ndarray
time indices for each frequency measurement, in seconds
frequencies : np.ndarray
frequency measurements, in Hz.
Non-positive measurements will be interpreted as un-voiced samples.
fs : int
desired sampling rate of the output signal
amplitudes : np.ndarray
        amplitude measurements, nonnegative
defaults to ``np.ones((length,))``
function : function
function to use to synthesize notes, should be :math:`2\pi`-periodic
length : int
desired number of samples in the output signal,
defaults to ``max(times)*fs``
kind : str
Interpolation mode for the frequency and amplitude values.
See: ``scipy.interpolate.interp1d`` for valid settings.
Returns
-------
output : np.ndarray
synthesized version of the pitch contour
'''
fs = float(fs)
if length is None:
length = int(times.max() * fs)
# Squash the negative frequencies.
# wave(0) = 0, so clipping here will un-voice the corresponding instants
frequencies = np.maximum(frequencies, 0.0)
# Build a frequency interpolator
f_interp = interp1d(times * fs, 2 * np.pi * frequencies / fs, kind=kind,
fill_value=0.0, bounds_error=False, copy=False)
# Estimate frequency at sample points
f_est = f_interp(np.arange(length))
if amplitudes is None:
a_est = np.ones((length, ))
else:
# build an amplitude interpolator
a_interp = interp1d(
times * fs, amplitudes, kind=kind,
fill_value=0.0, bounds_error=False, copy=False)
a_est = a_interp(np.arange(length))
# Sonify the waveform
return a_est * function(np.cumsum(f_est))
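# --- Illustrative sketch (not part of the original module): the phase
# accumulation used by pitch_contour above. Integrating 2*pi*f/fs per
# sample with np.cumsum keeps the phase continuous across frequency
# changes, avoiding clicks at the transition. The helper name
# `_demo_phase_accumulation` is hypothetical, added purely for illustration.
def _demo_phase_accumulation():
    import numpy as np
    fs = 8000.0
    # instantaneous frequency: 220 Hz for 0.5 s, then 440 Hz for 0.5 s
    f_est = np.concatenate([np.full(4000, 220.0), np.full(4000, 440.0)])
    return np.sin(np.cumsum(2 * np.pi * f_est / fs))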
def chroma(chromagram, times, fs, **kwargs):
"""Reverse synthesis of a chromagram (semitone matrix)
Parameters
----------
chromagram : np.ndarray, shape=(12, times.shape[0])
Chromagram matrix, where each row represents a semitone [C->Bb]
i.e., ``chromagram[3, j]`` is the magnitude of D# from ``times[j]`` to
``times[j + 1]``
    times : np.ndarray, shape=(chromagram.shape[1],) or (chromagram.shape[1], 2)
Either the start time of each column in the chromagram,
or the time interval corresponding to each column.
fs : int
Sampling rate to synthesize audio data at
kwargs
Additional keyword arguments to pass to
:func:`mir_eval.sonify.time_frequency`
Returns
-------
output : np.ndarray
Synthesized chromagram
"""
# We'll just use time_frequency with a Shepard tone-gram
# To create the Shepard tone-gram, we copy the chromagram across 7 octaves
n_octaves = 7
# starting from C2
base_note = 24
# and weight each octave by a normal distribution
# The normal distribution has mean 72 (one octave above middle C)
# and std 6 (one half octave)
mean = 72
std = 6
notes = np.arange(12*n_octaves) + base_note
shepard_weight = np.exp(-(notes - mean)**2./(2.*std**2.))
# Copy the chromagram matrix vertically n_octaves times
gram = np.tile(chromagram.T, n_octaves).T
# This fixes issues if the supplied chromagram is int type
gram = gram.astype(float)
    # Apply Shepard weighting
gram *= shepard_weight.reshape(-1, 1)
# Compute frequencies
frequencies = 440.0*(2.0**((notes - 69)/12.0))
return time_frequency(gram, frequencies, times, fs, **kwargs)
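# --- Illustrative sketch (not part of the original module): the MIDI-to-Hz
# mapping used above, 440 * 2**((note - 69) / 12). The helper name
# `_demo_midi_to_hz` is hypothetical, added purely for illustration.
def _demo_midi_to_hz(note=60):
    # note 60 (middle C) -> ~261.63 Hz
    return 440.0 * (2.0**((note - 69) / 12.0))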
def chords(chord_labels, intervals, fs, **kwargs):
"""Synthesizes chord labels
Parameters
----------
chord_labels : list of str
List of chord label strings.
intervals : np.ndarray, shape=(len(chord_labels), 2)
Start and end times of each chord label
fs : int
Sampling rate to synthesize at
kwargs
Additional keyword arguments to pass to
:func:`mir_eval.sonify.time_frequency`
Returns
-------
output : np.ndarray
Synthesized chord labels
"""
util.validate_intervals(intervals)
# Convert from labels to chroma
roots, interval_bitmaps, _ = chord.encode_many(chord_labels)
chromagram = np.array([np.roll(interval_bitmap, root)
for (interval_bitmap, root)
in zip(interval_bitmaps, roots)]).T
return chroma(chromagram, intervals, fs, **kwargs)
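# --- Illustrative sketch (not part of the original module): rolling a
# pitch-class bitmap to its root, as chords() does above. A major-triad
# bitmap rolled to root 2 (D) activates the D major pitch classes. The
# helper name `_demo_roll_bitmap` is hypothetical, added for illustration.
def _demo_roll_bitmap():
    import numpy as np
    major = np.zeros(12, dtype=int)
    major[[0, 4, 7]] = 1          # root, major third, perfect fifth
    return np.flatnonzero(np.roll(major, 2))  # -> [2, 6, 9] (D, F#, A)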
mir_eval-0.7/mir_eval/tempo.py 0000664 0000000 0000000 00000012344 14203260312 0016447 0 ustar 00root root 0000000 0000000 '''
The goal of a tempo estimation algorithm is to automatically detect the tempo
of a piece of music, measured in beats per minute (BPM).
See http://www.music-ir.org/mirex/wiki/2014:Audio_Tempo_Estimation for a
description of the task and evaluation criteria.
Conventions
-----------
Reference and estimated tempi should be non-negative, and provided in ascending
order as a numpy array of length 2.
The weighting value from the reference must be a float in the range [0, 1].
Metrics
-------
* :func:`mir_eval.tempo.detection`: Relative error, hits, and weighted
precision of tempo estimation.
'''
import warnings
import numpy as np
import collections
from . import util
def validate_tempi(tempi, reference=True):
"""Checks that there are two non-negative tempi.
For a reference value, at least one tempo has to be greater than zero.
Parameters
----------
tempi : np.ndarray
length-2 array of tempo, in bpm
reference : bool
indicates a reference value
"""
if tempi.size != 2:
raise ValueError('tempi must have exactly two values')
if not np.all(np.isfinite(tempi)) or np.any(tempi < 0):
raise ValueError('tempi={} must be non-negative numbers'.format(tempi))
if reference and np.all(tempi == 0):
raise ValueError('reference tempi={} must have one'
' value greater than zero'.format(tempi))
def validate(reference_tempi, reference_weight, estimated_tempi):
"""Checks that the input annotations to a metric look like valid tempo
annotations.
Parameters
----------
reference_tempi : np.ndarray
reference tempo values, in bpm
reference_weight : float
perceptual weight of slow vs fast in reference
estimated_tempi : np.ndarray
estimated tempo values, in bpm
"""
validate_tempi(reference_tempi, reference=True)
validate_tempi(estimated_tempi, reference=False)
if reference_weight < 0 or reference_weight > 1:
raise ValueError('Reference weight must lie in range [0, 1]')
def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08):
"""Compute the tempo detection accuracy metric.
Parameters
----------
reference_tempi : np.ndarray, shape=(2,)
Two non-negative reference tempi
    reference_weight : float in [0, 1]
The relative strength of ``reference_tempi[0]`` vs
``reference_tempi[1]``.
estimated_tempi : np.ndarray, shape=(2,)
Two non-negative estimated tempi.
tol : float in [0, 1]:
The maximum allowable deviation from a reference tempo to
count as a hit.
``|est_t - ref_t| <= tol * ref_t``
(Default value = 0.08)
Returns
-------
p_score : float in [0, 1]
Weighted average of recalls:
``reference_weight * hits[0] + (1 - reference_weight) * hits[1]``
one_correct : bool
True if at least one reference tempo was correctly estimated
both_correct : bool
True if both reference tempi were correctly estimated
Raises
------
ValueError
If the input tempi are ill-formed
If the reference weight is not in the range [0, 1]
If ``tol < 0`` or ``tol > 1``.
"""
validate(reference_tempi, reference_weight, estimated_tempi)
if tol < 0 or tol > 1:
raise ValueError('invalid tolerance {}: must lie in the range '
'[0, 1]'.format(tol))
if tol == 0.:
warnings.warn('A tolerance of 0.0 may not '
'lead to the results you expect.')
hits = [False, False]
for i, ref_t in enumerate(reference_tempi):
if ref_t > 0:
# Compute the relative error for this reference tempo
f_ref_t = float(ref_t)
relative_error = np.min(np.abs(ref_t - estimated_tempi) / f_ref_t)
# Count the hits
hits[i] = relative_error <= tol
p_score = reference_weight * hits[0] + (1.0-reference_weight) * hits[1]
one_correct = bool(np.max(hits))
both_correct = bool(np.min(hits))
return p_score, one_correct, both_correct
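# --- Illustrative sketch (not part of the original module): the hit
# criterion used in detection() above, |est - ref| <= tol * ref, checked
# against the closest estimate. The helper name `_demo_tempo_hit` is
# hypothetical, added purely for illustration.
def _demo_tempo_hit(ref_t=120.0, estimated_tempi=(118.0, 240.0), tol=0.08):
    import numpy as np
    estimated_tempi = np.asarray(estimated_tempi, dtype=float)
    relative_error = np.min(np.abs(ref_t - estimated_tempi) / ref_t)
    return relative_error <= tol  # -> True (118 is within 8% of 120)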
def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Parameters
----------
reference_tempi : np.ndarray, shape=(2,)
Two non-negative reference tempi
    reference_weight : float in [0, 1]
The relative strength of ``reference_tempi[0]`` vs
``reference_tempi[1]``.
estimated_tempi : np.ndarray, shape=(2,)
Two non-negative estimated tempi.
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all metrics
scores = collections.OrderedDict()
(scores['P-score'],
scores['One-correct'],
scores['Both-correct']) = util.filter_kwargs(detection, reference_tempi,
reference_weight,
estimated_tempi,
**kwargs)
return scores
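# --- Illustrative sketch (not part of the original module): typical use of
# evaluate() with made-up tempo values. The helper name
# `_demo_tempo_evaluate` is hypothetical, added purely for illustration.
def _demo_tempo_evaluate():
    import numpy as np
    scores = evaluate(np.array([60., 120.]), 0.5, np.array([59., 121.]))
    return scores  # -> P-score 1.0, One-correct True, Both-correct True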
mir_eval-0.7/mir_eval/transcription.py 0000664 0000000 0000000 00000110340 14203260312 0020215 0 ustar 00root root 0000000 0000000 '''
The aim of a transcription algorithm is to produce a symbolic representation of
a recorded piece of music in the form of a set of discrete notes. There are
different ways to represent notes symbolically. Here we use the piano-roll
convention, meaning each note has a start time, a duration (or end time), and
a single, constant, pitch value. Pitch values can be quantized (e.g. to a
semitone grid tuned to 440 Hz), but do not have to be. Also, the transcription
can contain the notes of a single instrument or voice (for example the melody),
or the notes of all instruments/voices in the recording. This module is
instrument agnostic: all notes in the estimate are compared against all notes
in the reference.
There are many metrics for evaluating transcription algorithms. Here we limit
ourselves to the most simple and commonly used: given two sets of notes, we
count how many estimated notes match the reference, and how many do not. Based
on these counts we compute the precision, recall, f-measure and overlap ratio
of the estimate given the reference. The default criteria for considering two
notes to be a match are adopted from the `MIREX Multiple fundamental frequency
estimation and tracking, Note Tracking subtask (task 2)
`_:
"This subtask is evaluated in two different ways. In the first setup , a
returned note is assumed correct if its onset is within +-50ms of a reference
note and its F0 is within +- quarter tone of the corresponding reference note,
ignoring the returned offset values. In the second setup, on top of the above
requirements, a correct returned note is required to have an offset value
within 20% of the reference note's duration around the reference note's
offset, or within 50ms whichever is larger."
In short, we compute precision, recall, f-measure and overlap ratio, once
without taking offsets into account, and the second time with.
For further details see Salamon, 2013 (page 186), and references therein:
Salamon, J. (2013). Melody Extraction from Polyphonic Music Signals.
Ph.D. thesis, Universitat Pompeu Fabra, Barcelona, Spain, 2013.
IMPORTANT NOTE: the evaluation code in ``mir_eval`` contains several important
differences with respect to the code used in MIREX 2015 for the Note Tracking
subtask on the Su dataset (henceforth "MIREX"):
1. ``mir_eval`` uses bipartite graph matching to find the optimal pairing of
reference notes to estimated notes. MIREX uses a greedy matching algorithm,
which can produce sub-optimal note matching. This will result in
``mir_eval``'s metrics being slightly higher compared to MIREX.
2. MIREX rounds down the onset and offset times of each note to 2 decimal
points using ``new_time = 0.01 * floor(time*100)``. ``mir_eval`` rounds down
the note onset and offset times to 4 decimal points. This will bring our
metrics down a notch compared to the MIREX results.
3. In the MIREX wiki, the criterion for matching offsets is that they must be
within ``0.2 * ref_duration`` **or 0.05 seconds from each other, whichever
is greater** (i.e. ``offset_dif <= max(0.2 * ref_duration, 0.05)``). The
MIREX code however only uses a threshold of ``0.2 * ref_duration``, without
the 0.05 second minimum. Since ``mir_eval`` does include this minimum, it
might produce slightly higher results compared to MIREX.
This means that differences 1 and 3 bring ``mir_eval``'s metrics up compared to
MIREX, whilst 2 brings them down. Based on internal testing, overall the effect
of these three differences is that the Precision, Recall and F-measure returned
by ``mir_eval`` will be higher compared to MIREX by about 1%-2%.
Finally, note that different evaluation scripts have been used for the Multi-F0
Note Tracking task in MIREX over the years. In particular, some scripts used
``<`` for matching onsets, offsets, and pitch values, whilst the others used
``<=`` for these checks. ``mir_eval`` provides both options: by default the
latter (``<=``) is used, but you can set ``strict=True`` when calling
:func:`mir_eval.transcription.precision_recall_f1_overlap()` in which case
``<`` will be used. The default value (``strict=False``) is the same as that
used in MIREX 2015 for the Note Tracking subtask on the Su dataset.
Conventions
-----------
Notes should be provided in the form of an interval array and a pitch array.
The interval array contains two columns, one for note onsets and the second
for note offsets (each row represents a single note). The pitch array contains
one column with the corresponding note pitch values (one value per note),
represented by their fundamental frequency (f0) in Hertz.
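For example, three notes might be encoded as follows (a hypothetical
sketch, assuming ``numpy`` is imported as ``np``; values are in seconds
and Hertz):
>>> intervals = np.array([[0.10, 0.50], [0.55, 1.00], [1.10, 1.80]])
>>> pitches = np.array([220.0, 246.9, 261.6])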
Metrics
-------
* :func:`mir_eval.transcription.precision_recall_f1_overlap`: The precision,
recall, F-measure, and Average Overlap Ratio of the note transcription,
where an estimated note is considered correct if its pitch, onset and
(optionally) offset are sufficiently close to a reference note.
* :func:`mir_eval.transcription.onset_precision_recall_f1`: The precision,
recall and F-measure of the note transcription, where an estimated note is
considered correct if its onset is sufficiently close to a reference note's
onset. That is, these metrics are computed taking only note onsets into
account, meaning two notes could be matched even if they have very different
pitch values.
* :func:`mir_eval.transcription.offset_precision_recall_f1`: The precision,
recall and F-measure of the note transcription, where an estimated note is
considered correct if its offset is sufficiently close to a reference note's
offset. That is, these metrics are computed taking only note offsets into
account, meaning two notes could be matched even if they have very different
pitch values.
'''
import numpy as np
import collections
from . import util
import warnings
# The number of decimals to keep for onset/offset threshold checks
N_DECIMALS = 4
def validate(ref_intervals, ref_pitches, est_intervals, est_pitches):
"""Checks that the input annotations to a metric look like time intervals
and a pitch list, and throws helpful errors if not.
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
"""
# Validate intervals
validate_intervals(ref_intervals, est_intervals)
# Make sure intervals and pitches match in length
if not ref_intervals.shape[0] == ref_pitches.shape[0]:
raise ValueError('Reference intervals and pitches have different '
'lengths.')
if not est_intervals.shape[0] == est_pitches.shape[0]:
raise ValueError('Estimated intervals and pitches have different '
'lengths.')
# Make sure all pitch values are positive
if ref_pitches.size > 0 and np.min(ref_pitches) <= 0:
raise ValueError("Reference contains at least one non-positive pitch "
"value")
if est_pitches.size > 0 and np.min(est_pitches) <= 0:
raise ValueError("Estimate contains at least one non-positive pitch "
"value")
def validate_intervals(ref_intervals, est_intervals):
"""Checks that the input annotations to a metric look like time intervals,
and throws helpful errors if not.
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
"""
# If reference or estimated notes are empty, warn
if ref_intervals.size == 0:
warnings.warn("Reference notes are empty.")
if est_intervals.size == 0:
warnings.warn("Estimated notes are empty.")
# Validate intervals
util.validate_intervals(ref_intervals)
util.validate_intervals(est_intervals)
def match_note_offsets(ref_intervals, est_intervals, offset_ratio=0.2,
offset_min_tolerance=0.05, strict=False):
"""Compute a maximum matching between reference and estimated notes,
only taking note offsets into account.
Given two note sequences represented by ``ref_intervals`` and
``est_intervals`` (see :func:`mir_eval.io.load_valued_intervals`), we seek
the largest set of correspondences ``(i, j)`` such that the offset of
reference note ``i`` has to be within ``offset_tolerance`` of the offset of
estimated note ``j``, where ``offset_tolerance`` is equal to
``offset_ratio`` times the reference note's duration, i.e. ``offset_ratio
* ref_duration[i]`` where ``ref_duration[i] = ref_intervals[i, 1] -
ref_intervals[i, 0]``. If the resulting ``offset_tolerance`` is less than
``offset_min_tolerance`` (50 ms by default) then ``offset_min_tolerance``
is used instead.
Every reference note is matched against at most one estimated note.
Note there are separate functions :func:`match_note_onsets` and
:func:`match_notes` for matching notes based on onsets only or based on
onset, offset, and pitch, respectively. This is because the rules for
matching note onsets and matching note offsets are different.
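    Examples
    --------
    A minimal sketch with hypothetical interval arrays; each row is an
    ``(onset, offset)`` pair in seconds, and only the first pair of
    offsets falls within the tolerance:
    >>> ref_intervals = np.array([[0.0, 1.0], [2.0, 3.0]])
    >>> est_intervals = np.array([[0.1, 1.1], [2.0, 2.5]])
    >>> matching = mir_eval.transcription.match_note_offsets(
    ...     ref_intervals, est_intervals)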
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
offset_ratio : float > 0
The ratio of the reference note's duration used to define the
``offset_tolerance``. Default is 0.2 (20%), meaning the
``offset_tolerance`` will equal the ``ref_duration * 0.2``, or 0.05 (50
ms), whichever is greater.
offset_min_tolerance : float > 0
The minimum tolerance for offset matching. See ``offset_ratio``
description for an explanation of how the offset tolerance is
determined.
strict : bool
If ``strict=False`` (the default), threshold checks for offset
matching are performed using ``<=`` (less than or equal). If
``strict=True``, the threshold checks are performed using ``<`` (less
than).
Returns
-------
matching : list of tuples
A list of matched reference and estimated notes.
``matching[i] == (i, j)`` where reference note ``i`` matches estimated
note ``j``.
"""
# set the comparison function
if strict:
cmp_func = np.less
else:
cmp_func = np.less_equal
# check for offset matches
offset_distances = np.abs(np.subtract.outer(ref_intervals[:, 1],
est_intervals[:, 1]))
# Round distances to a target precision to avoid the situation where
# if the distance is exactly 50ms (and strict=False) it erroneously
# doesn't match the notes because of precision issues.
offset_distances = np.around(offset_distances, decimals=N_DECIMALS)
ref_durations = util.intervals_to_durations(ref_intervals)
offset_tolerances = np.maximum(offset_ratio * ref_durations,
offset_min_tolerance)
offset_hit_matrix = (
cmp_func(offset_distances, offset_tolerances.reshape(-1, 1)))
# check for hits
hits = np.where(offset_hit_matrix)
# Construct the graph input
# Flip graph so that 'matching' is a list of tuples where the first item
# in each tuple is the reference note index, and the second item is the
# estimated note index.
G = {}
for ref_i, est_i in zip(*hits):
if est_i not in G:
G[est_i] = []
G[est_i].append(ref_i)
# Compute the maximum matching
matching = sorted(util._bipartite_match(G).items())
return matching
def match_note_onsets(ref_intervals, est_intervals, onset_tolerance=0.05,
strict=False):
"""Compute a maximum matching between reference and estimated notes,
only taking note onsets into account.
Given two note sequences represented by ``ref_intervals`` and
``est_intervals`` (see :func:`mir_eval.io.load_valued_intervals`), we seek
the largest set of correspondences ``(i,j)`` such that the onset of
reference note ``i`` is within ``onset_tolerance`` of the onset of
estimated note ``j``.
Every reference note is matched against at most one estimated note.
Note there are separate functions :func:`match_note_offsets` and
:func:`match_notes` for matching notes based on offsets only or based on
onset, offset, and pitch, respectively. This is because the rules for
matching note onsets and matching note offsets are different.
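    Examples
    --------
    A minimal sketch with hypothetical interval arrays; only the first
    pair of onsets lies within the default 50 ms tolerance:
    >>> ref_intervals = np.array([[0.0, 1.0], [2.0, 3.0]])
    >>> est_intervals = np.array([[0.04, 1.0], [2.5, 3.0]])
    >>> matching = mir_eval.transcription.match_note_onsets(
    ...     ref_intervals, est_intervals)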
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
onset_tolerance : float > 0
The tolerance for an estimated note's onset deviating from the
reference note's onset, in seconds. Default is 0.05 (50 ms).
strict : bool
If ``strict=False`` (the default), threshold checks for onset matching
are performed using ``<=`` (less than or equal). If ``strict=True``,
the threshold checks are performed using ``<`` (less than).
Returns
-------
matching : list of tuples
A list of matched reference and estimated notes.
``matching[i] == (i, j)`` where reference note ``i`` matches estimated
note ``j``.
"""
# set the comparison function
if strict:
cmp_func = np.less
else:
cmp_func = np.less_equal
# check for onset matches
onset_distances = np.abs(np.subtract.outer(ref_intervals[:, 0],
est_intervals[:, 0]))
# Round distances to a target precision to avoid the situation where
# if the distance is exactly 50ms (and strict=False) it erroneously
# doesn't match the notes because of precision issues.
onset_distances = np.around(onset_distances, decimals=N_DECIMALS)
onset_hit_matrix = cmp_func(onset_distances, onset_tolerance)
# find hits
hits = np.where(onset_hit_matrix)
# Construct the graph input
# Flip graph so that 'matching' is a list of tuples where the first item
# in each tuple is the reference note index, and the second item is the
# estimated note index.
G = {}
for ref_i, est_i in zip(*hits):
if est_i not in G:
G[est_i] = []
G[est_i].append(ref_i)
# Compute the maximum matching
matching = sorted(util._bipartite_match(G).items())
return matching
def match_notes(ref_intervals, ref_pitches, est_intervals, est_pitches,
onset_tolerance=0.05, pitch_tolerance=50.0, offset_ratio=0.2,
offset_min_tolerance=0.05, strict=False):
"""Compute a maximum matching between reference and estimated notes,
subject to onset, pitch and (optionally) offset constraints.
Given two note sequences represented by ``ref_intervals``, ``ref_pitches``,
``est_intervals`` and ``est_pitches``
(see :func:`mir_eval.io.load_valued_intervals`), we seek the largest set
of correspondences ``(i, j)`` such that:
1. The onset of reference note ``i`` is within ``onset_tolerance`` of the
onset of estimated note ``j``.
2. The pitch of reference note ``i`` is within ``pitch_tolerance`` of the
pitch of estimated note ``j``.
3. If ``offset_ratio`` is not ``None``, the offset of reference note ``i``
has to be within ``offset_tolerance`` of the offset of estimated note
``j``, where ``offset_tolerance`` is equal to ``offset_ratio`` times the
reference note's duration, i.e. ``offset_ratio * ref_duration[i]`` where
``ref_duration[i] = ref_intervals[i, 1] - ref_intervals[i, 0]``. If the
resulting ``offset_tolerance`` is less than 0.05 (50 ms), 0.05 is used
instead.
4. If ``offset_ratio`` is ``None``, note offsets are ignored, and only
criteria 1 and 2 are taken into consideration.
Every reference note is matched against at most one estimated note.
This is useful for computing precision/recall metrics for note
transcription.
Note there are separate functions :func:`match_note_onsets` and
:func:`match_note_offsets` for matching notes based on onsets only or based
on offsets only, respectively.
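    Examples
    --------
    A minimal sketch with hypothetical intervals (seconds) and pitches
    (Hertz); the first note pair satisfies all three criteria, the second
    fails the onset check:
    >>> ref_intervals = np.array([[0.0, 1.0], [2.0, 3.0]])
    >>> ref_pitches = np.array([440.0, 220.0])
    >>> est_intervals = np.array([[0.01, 0.95], [2.3, 3.0]])
    >>> est_pitches = np.array([441.0, 220.0])
    >>> matching = mir_eval.transcription.match_notes(
    ...     ref_intervals, ref_pitches, est_intervals, est_pitches)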
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
onset_tolerance : float > 0
The tolerance for an estimated note's onset deviating from the
reference note's onset, in seconds. Default is 0.05 (50 ms).
pitch_tolerance : float > 0
The tolerance for an estimated note's pitch deviating from the
reference note's pitch, in cents. Default is 50.0 (50 cents).
offset_ratio : float > 0 or None
The ratio of the reference note's duration used to define the
offset_tolerance. Default is 0.2 (20%), meaning the
``offset_tolerance`` will equal the ``ref_duration * 0.2``, or 0.05 (50
ms), whichever is greater. If ``offset_ratio`` is set to ``None``,
offsets are ignored in the matching.
offset_min_tolerance : float > 0
The minimum tolerance for offset matching. See offset_ratio description
for an explanation of how the offset tolerance is determined. Note:
this parameter only influences the results if ``offset_ratio`` is not
``None``.
strict : bool
If ``strict=False`` (the default), threshold checks for onset, offset,
and pitch matching are performed using ``<=`` (less than or equal). If
``strict=True``, the threshold checks are performed using ``<`` (less
than).
Returns
-------
matching : list of tuples
A list of matched reference and estimated notes.
``matching[i] == (i, j)`` where reference note ``i`` matches estimated
note ``j``.
"""
# set the comparison function
if strict:
cmp_func = np.less
else:
cmp_func = np.less_equal
# check for onset matches
onset_distances = np.abs(np.subtract.outer(ref_intervals[:, 0],
est_intervals[:, 0]))
# Round distances to a target precision to avoid the situation where
# if the distance is exactly 50ms (and strict=False) it erroneously
# doesn't match the notes because of precision issues.
onset_distances = np.around(onset_distances, decimals=N_DECIMALS)
onset_hit_matrix = cmp_func(onset_distances, onset_tolerance)
# check for pitch matches
pitch_distances = np.abs(1200*np.subtract.outer(np.log2(ref_pitches),
np.log2(est_pitches)))
pitch_hit_matrix = cmp_func(pitch_distances, pitch_tolerance)
# check for offset matches if offset_ratio is not None
if offset_ratio is not None:
offset_distances = np.abs(np.subtract.outer(ref_intervals[:, 1],
est_intervals[:, 1]))
# Round distances to a target precision to avoid the situation where
# if the distance is exactly 50ms (and strict=False) it erroneously
# doesn't match the notes because of precision issues.
offset_distances = np.around(offset_distances, decimals=N_DECIMALS)
ref_durations = util.intervals_to_durations(ref_intervals)
offset_tolerances = np.maximum(offset_ratio * ref_durations,
offset_min_tolerance)
offset_hit_matrix = (
cmp_func(offset_distances, offset_tolerances.reshape(-1, 1)))
else:
offset_hit_matrix = True
# check for overall matches
note_hit_matrix = onset_hit_matrix * pitch_hit_matrix * offset_hit_matrix
hits = np.where(note_hit_matrix)
# Construct the graph input
# Flip graph so that 'matching' is a list of tuples where the first item
# in each tuple is the reference note index, and the second item is the
# estimated note index.
G = {}
for ref_i, est_i in zip(*hits):
if est_i not in G:
G[est_i] = []
G[est_i].append(ref_i)
# Compute the maximum matching
matching = sorted(util._bipartite_match(G).items())
return matching
def precision_recall_f1_overlap(ref_intervals, ref_pitches, est_intervals,
est_pitches, onset_tolerance=0.05,
pitch_tolerance=50.0, offset_ratio=0.2,
offset_min_tolerance=0.05, strict=False,
beta=1.0):
"""Compute the Precision, Recall and F-measure of correct vs incorrectly
transcribed notes, and the Average Overlap Ratio for correctly transcribed
notes (see :func:`average_overlap_ratio`). "Correctness" is determined
based on note onset, pitch and (optionally) offset: an estimated note is
assumed correct if its onset is within +-50ms of a reference note and its
pitch (F0) is within +- quarter tone (50 cents) of the corresponding
reference note. If ``offset_ratio`` is ``None``, note offsets are ignored
in the comparison. Otherwise, on top of the above requirements, a correct
returned note is required to have an offset value within 20% (by default,
adjustable via the ``offset_ratio`` parameter) of the reference note's
duration around the reference note's offset, or within
``offset_min_tolerance`` (50 ms by default), whichever is larger.
Examples
--------
>>> ref_intervals, ref_pitches = mir_eval.io.load_valued_intervals(
... 'reference.txt')
>>> est_intervals, est_pitches = mir_eval.io.load_valued_intervals(
... 'estimated.txt')
>>> (precision,
... recall,
... f_measure) = mir_eval.transcription.precision_recall_f1_overlap(
... ref_intervals, ref_pitches, est_intervals, est_pitches)
>>> (precision_no_offset,
... recall_no_offset,
... f_measure_no_offset) = (
... mir_eval.transcription.precision_recall_f1_overlap(
... ref_intervals, ref_pitches, est_intervals, est_pitches,
... offset_ratio=None))
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
onset_tolerance : float > 0
The tolerance for an estimated note's onset deviating from the
reference note's onset, in seconds. Default is 0.05 (50 ms).
pitch_tolerance : float > 0
The tolerance for an estimated note's pitch deviating from the
reference note's pitch, in cents. Default is 50.0 (50 cents).
offset_ratio : float > 0 or None
The ratio of the reference note's duration used to define the
offset_tolerance. Default is 0.2 (20%), meaning the
``offset_tolerance`` will equal the ``ref_duration * 0.2``, or
``offset_min_tolerance`` (0.05 by default, i.e. 50 ms), whichever is
greater. If ``offset_ratio`` is set to ``None``, offsets are ignored in
the evaluation.
offset_min_tolerance : float > 0
The minimum tolerance for offset matching. See ``offset_ratio``
description for an explanation of how the offset tolerance is
determined. Note: this parameter only influences the results if
``offset_ratio`` is not ``None``.
strict : bool
If ``strict=False`` (the default), threshold checks for onset, offset,
and pitch matching are performed using ``<=`` (less than or equal). If
``strict=True``, the threshold checks are performed using ``<`` (less
than).
beta : float > 0
Weighting factor for f-measure (default value = 1.0).
Returns
-------
precision : float
The computed precision score
recall : float
The computed recall score
f_measure : float
The computed F-measure score
avg_overlap_ratio : float
The computed Average Overlap Ratio score
"""
validate(ref_intervals, ref_pitches, est_intervals, est_pitches)
    # When reference or estimated notes are empty,
    # metrics are undefined, return 0's
if len(ref_pitches) == 0 or len(est_pitches) == 0:
return 0., 0., 0., 0.
matching = match_notes(ref_intervals, ref_pitches, est_intervals,
est_pitches, onset_tolerance=onset_tolerance,
pitch_tolerance=pitch_tolerance,
offset_ratio=offset_ratio,
offset_min_tolerance=offset_min_tolerance,
strict=strict)
precision = float(len(matching))/len(est_pitches)
recall = float(len(matching))/len(ref_pitches)
f_measure = util.f_measure(precision, recall, beta=beta)
avg_overlap_ratio = average_overlap_ratio(ref_intervals, est_intervals,
matching)
return precision, recall, f_measure, avg_overlap_ratio
def average_overlap_ratio(ref_intervals, est_intervals, matching):
"""Compute the Average Overlap Ratio between a reference and estimated
note transcription. Given a reference and corresponding estimated note,
their overlap ratio (OR) is defined as the ratio between the duration of
the time segment in which the two notes overlap and the time segment
spanned by the two notes combined (earliest onset to latest offset):
>>> OR = ((min(ref_offset, est_offset) - max(ref_onset, est_onset)) /
... (max(ref_offset, est_offset) - min(ref_onset, est_onset)))
The Average Overlap Ratio (AOR) is given by the mean OR computed over all
matching reference and estimated notes. The metric goes from 0 (worst) to 1
(best).
Note: this function assumes the matching of reference and estimated notes
(see :func:`match_notes`) has already been performed and is provided by the
``matching`` parameter. Furthermore, it is highly recommended to validate
the intervals (see :func:`validate_intervals`) before calling this
function, otherwise it is possible (though unlikely) for this function to
attempt a divide-by-zero operation.
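    Examples
    --------
    A worked sketch: a reference note spanning 0-1 s matched to an
    estimate spanning 0.5-1.5 s overlaps for 0.5 s over a combined span
    of 1.5 s, giving an overlap ratio of 1/3:
    >>> ref_intervals = np.array([[0.0, 1.0]])
    >>> est_intervals = np.array([[0.5, 1.5]])
    >>> aor = mir_eval.transcription.average_overlap_ratio(
    ...     ref_intervals, est_intervals, [(0, 0)])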
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
matching : list of tuples
A list of matched reference and estimated notes.
``matching[i] == (i, j)`` where reference note ``i`` matches estimated
note ``j``.
Returns
-------
avg_overlap_ratio : float
The computed Average Overlap Ratio score
"""
ratios = []
for match in matching:
ref_int = ref_intervals[match[0]]
est_int = est_intervals[match[1]]
overlap_ratio = (
(min(ref_int[1], est_int[1]) - max(ref_int[0], est_int[0])) /
(max(ref_int[1], est_int[1]) - min(ref_int[0], est_int[0])))
ratios.append(overlap_ratio)
if len(ratios) == 0:
return 0
else:
return np.mean(ratios)
def onset_precision_recall_f1(ref_intervals, est_intervals,
onset_tolerance=0.05, strict=False, beta=1.0):
"""Compute the Precision, Recall and F-measure of note onsets: an estimated
onset is considered correct if it is within +-50ms of a reference onset.
Note that this metric completely ignores note offset and note pitch. This
means an estimated onset will be considered correct if it matches a
reference onset, even if the onsets come from notes with completely
different pitches (i.e. notes that would not match with
:func:`match_notes`).
Examples
--------
>>> ref_intervals, _ = mir_eval.io.load_valued_intervals(
... 'reference.txt')
>>> est_intervals, _ = mir_eval.io.load_valued_intervals(
... 'estimated.txt')
>>> (onset_precision,
... onset_recall,
... onset_f_measure) = mir_eval.transcription.onset_precision_recall_f1(
... ref_intervals, est_intervals)
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
onset_tolerance : float > 0
The tolerance for an estimated note's onset deviating from the
reference note's onset, in seconds. Default is 0.05 (50 ms).
strict : bool
If ``strict=False`` (the default), threshold checks for onset matching
are performed using ``<=`` (less than or equal). If ``strict=True``,
the threshold checks are performed using ``<`` (less than).
beta : float > 0
Weighting factor for f-measure (default value = 1.0).
Returns
-------
precision : float
The computed precision score
recall : float
The computed recall score
f_measure : float
The computed F-measure score
"""
validate_intervals(ref_intervals, est_intervals)
    # When reference or estimated notes are empty,
    # metrics are undefined, return 0's
if len(ref_intervals) == 0 or len(est_intervals) == 0:
return 0., 0., 0.
matching = match_note_onsets(ref_intervals, est_intervals,
onset_tolerance=onset_tolerance,
strict=strict)
onset_precision = float(len(matching))/len(est_intervals)
onset_recall = float(len(matching))/len(ref_intervals)
onset_f_measure = util.f_measure(onset_precision, onset_recall, beta=beta)
return onset_precision, onset_recall, onset_f_measure
def offset_precision_recall_f1(ref_intervals, est_intervals, offset_ratio=0.2,
offset_min_tolerance=0.05, strict=False,
beta=1.0):
"""Compute the Precision, Recall and F-measure of note offsets: an
estimated offset is considered correct if it is within +-50ms (or 20% of
the ref note duration, whichever is greater) of a reference offset. Note
that this metric completely ignores note onsets and note pitch. This means
an estimated offset will be considered correct if it matches a
reference offset, even if the offsets come from notes with completely
different pitches (i.e. notes that would not match with
:func:`match_notes`).
Examples
--------
>>> ref_intervals, _ = mir_eval.io.load_valued_intervals(
... 'reference.txt')
>>> est_intervals, _ = mir_eval.io.load_valued_intervals(
... 'estimated.txt')
>>> (offset_precision,
... offset_recall,
... offset_f_measure) = mir_eval.transcription.offset_precision_recall_f1(
... ref_intervals, est_intervals)
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
offset_ratio : float > 0 or None
The ratio of the reference note's duration used to define the
offset_tolerance. Default is 0.2 (20%), meaning the
``offset_tolerance`` will equal the ``ref_duration * 0.2``, or
``offset_min_tolerance`` (0.05 by default, i.e. 50 ms), whichever is
greater.
offset_min_tolerance : float > 0
The minimum tolerance for offset matching. See ``offset_ratio``
description for an explanation of how the offset tolerance is
determined.
strict : bool
If ``strict=False`` (the default), threshold checks for onset matching
are performed using ``<=`` (less than or equal). If ``strict=True``,
the threshold checks are performed using ``<`` (less than).
beta : float > 0
Weighting factor for f-measure (default value = 1.0).
Returns
-------
precision : float
The computed precision score
recall : float
The computed recall score
f_measure : float
The computed F-measure score
"""
validate_intervals(ref_intervals, est_intervals)
    # When reference or estimated notes are empty,
    # metrics are undefined, return 0's
if len(ref_intervals) == 0 or len(est_intervals) == 0:
return 0., 0., 0.
matching = match_note_offsets(ref_intervals, est_intervals,
offset_ratio=offset_ratio,
offset_min_tolerance=offset_min_tolerance,
strict=strict)
offset_precision = float(len(matching))/len(est_intervals)
offset_recall = float(len(matching))/len(ref_intervals)
offset_f_measure = util.f_measure(offset_precision, offset_recall,
beta=beta)
return offset_precision, offset_recall, offset_f_measure
def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Examples
--------
>>> ref_intervals, ref_pitches = mir_eval.io.load_valued_intervals(
... 'reference.txt')
>>> est_intervals, est_pitches = mir_eval.io.load_valued_intervals(
... 'estimate.txt')
>>> scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches,
... est_intervals, est_pitches)
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all the metrics
scores = collections.OrderedDict()
# Precision, recall and f-measure taking note offsets into account
kwargs.setdefault('offset_ratio', 0.2)
orig_offset_ratio = kwargs['offset_ratio']
if kwargs['offset_ratio'] is not None:
(scores['Precision'],
scores['Recall'],
scores['F-measure'],
scores['Average_Overlap_Ratio']) = util.filter_kwargs(
precision_recall_f1_overlap, ref_intervals, ref_pitches,
est_intervals, est_pitches, **kwargs)
# Precision, recall and f-measure NOT taking note offsets into account
kwargs['offset_ratio'] = None
(scores['Precision_no_offset'],
scores['Recall_no_offset'],
scores['F-measure_no_offset'],
scores['Average_Overlap_Ratio_no_offset']) = (
util.filter_kwargs(precision_recall_f1_overlap,
ref_intervals, ref_pitches,
est_intervals, est_pitches, **kwargs))
# onset-only metrics
(scores['Onset_Precision'],
scores['Onset_Recall'],
scores['Onset_F-measure']) = (
util.filter_kwargs(onset_precision_recall_f1,
ref_intervals, est_intervals, **kwargs))
# offset-only metrics
kwargs['offset_ratio'] = orig_offset_ratio
if kwargs['offset_ratio'] is not None:
(scores['Offset_Precision'],
scores['Offset_Recall'],
scores['Offset_F-measure']) = (
util.filter_kwargs(offset_precision_recall_f1,
ref_intervals, est_intervals, **kwargs))
return scores
mir_eval-0.7/mir_eval/transcription_velocity.py 0000664 0000000 0000000 00000040517 14203260312 0022143 0 ustar 00root root 0000000 0000000 """
Transcription evaluation, as defined in :mod:`mir_eval.transcription`, does not
take into account the velocities of reference and estimated notes. This
submodule implements a variant of
:func:`mir_eval.transcription.precision_recall_f1_overlap` which
additionally considers note velocity when determining whether a note is
correctly transcribed. This is done by defining a new function
:func:`mir_eval.transcription_velocity.match_notes` which first calls
:func:`mir_eval.transcription.match_notes` to get a note matching based on
onset, offset, and pitch. Then, we follow the evaluation procedure described in
[#hawthorne2018onsets]_ to test whether an estimated note should be considered
correct:
1. Reference velocities are re-scaled to the range [0, 1].
2. A linear regression is performed to estimate global scale and offset
parameters which minimize the L2 distance between matched estimated and
(rescaled) reference notes.
3. The scale and offset parameters are used to rescale estimated
velocities.
4. An estimated/reference note pair which has been matched according to the
onset, offset, and pitch is further only considered correct if the rescaled
velocities are within a predefined threshold, defaulting to 0.1.
:func:`mir_eval.transcription_velocity.match_notes` is used to define a new
variant :func:`mir_eval.transcription_velocity.precision_recall_f1_overlap`
which considers velocity.
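As a rough sketch of steps 1-3 above (with hypothetical matched velocity
arrays; ``rcond=None`` simply requests the default least-squares
behavior):
>>> ref_v = np.array([40., 80., 120.])
>>> est_v = np.array([50., 90., 110.])
>>> ref_v = (ref_v - ref_v.min()) / max(1, ref_v.max() - ref_v.min())
>>> A = np.vstack([est_v, np.ones(len(est_v))]).T
>>> slope, intercept = np.linalg.lstsq(A, ref_v, rcond=None)[0]
>>> est_v = slope * est_v + intercept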
Conventions
-----------
This submodule follows the conventions of :mod:`mir_eval.transcription` and
additionally requires velocities to be provided as MIDI velocities in the range
[0, 127].
Metrics
-------
* :func:`mir_eval.transcription_velocity.precision_recall_f1_overlap`: The
precision, recall, F-measure, and Average Overlap Ratio of the note
transcription, where an estimated note is considered correct if its pitch,
onset, velocity and (optionally) offset are sufficiently close to a reference
note.
References
----------
.. [#hawthorne2018onsets] Curtis Hawthorne, Erich Elsen, Jialin Song, Adam
Roberts, Ian Simon, Colin Raffel, Jesse Engel, Sageev Oore, and Douglas
Eck, "Onsets and Frames: Dual-Objective Piano Transcription", Proceedings
of the 19th International Society for Music Information Retrieval
Conference, 2018.
"""
import collections
import numpy as np
from . import transcription
from . import util
def validate(ref_intervals, ref_pitches, ref_velocities, est_intervals,
est_pitches, est_velocities):
"""Checks that the input annotations have valid time intervals, pitches,
and velocities, and throws helpful errors if not.
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
ref_velocities : np.ndarray, shape=(n,)
Array of MIDI velocities (i.e. between 0 and 127) of reference notes
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
est_velocities : np.ndarray, shape=(m,)
Array of MIDI velocities (i.e. between 0 and 127) of estimated notes
"""
transcription.validate(ref_intervals, ref_pitches, est_intervals,
est_pitches)
# Check that velocities have the same length as intervals/pitches
if not ref_velocities.shape[0] == ref_pitches.shape[0]:
raise ValueError('Reference velocities must have the same length as '
'pitches and intervals.')
if not est_velocities.shape[0] == est_pitches.shape[0]:
raise ValueError('Estimated velocities must have the same length as '
'pitches and intervals.')
    # Check that the velocities are non-negative
    if ref_velocities.size > 0 and np.min(ref_velocities) < 0:
        raise ValueError('Reference velocities must be non-negative.')
    if est_velocities.size > 0 and np.min(est_velocities) < 0:
        raise ValueError('Estimated velocities must be non-negative.')
def match_notes(
ref_intervals, ref_pitches, ref_velocities, est_intervals, est_pitches,
est_velocities, onset_tolerance=0.05, pitch_tolerance=50.0,
offset_ratio=0.2, offset_min_tolerance=0.05, strict=False,
velocity_tolerance=0.1):
"""Match notes, taking note velocity into consideration.
This function first calls :func:`mir_eval.transcription.match_notes` to
match notes according to the supplied intervals, pitches, onset, offset,
    and pitch tolerances. Reference velocities are then normalized to the
    range [0, 1], and the velocities of the matched notes are used to
    estimate a slope and intercept which rescale the estimated velocities
    so that they are as close as possible (in the L2 sense) to their
    matched reference velocities. An estimated note is then further only
    considered correct if its rescaled velocity is within
    ``velocity_tolerance`` of its matched (according to pitch and timing)
    reference note.
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
ref_velocities : np.ndarray, shape=(n,)
Array of MIDI velocities (i.e. between 0 and 127) of reference notes
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
est_velocities : np.ndarray, shape=(m,)
Array of MIDI velocities (i.e. between 0 and 127) of estimated notes
onset_tolerance : float > 0
The tolerance for an estimated note's onset deviating from the
reference note's onset, in seconds. Default is 0.05 (50 ms).
pitch_tolerance : float > 0
The tolerance for an estimated note's pitch deviating from the
reference note's pitch, in cents. Default is 50.0 (50 cents).
offset_ratio : float > 0 or None
The ratio of the reference note's duration used to define the
offset_tolerance. Default is 0.2 (20%), meaning the
``offset_tolerance`` will equal the ``ref_duration * 0.2``, or 0.05 (50
ms), whichever is greater. If ``offset_ratio`` is set to ``None``,
offsets are ignored in the matching.
offset_min_tolerance : float > 0
The minimum tolerance for offset matching. See offset_ratio description
for an explanation of how the offset tolerance is determined. Note:
this parameter only influences the results if ``offset_ratio`` is not
``None``.
strict : bool
If ``strict=False`` (the default), threshold checks for onset, offset,
and pitch matching are performed using ``<=`` (less than or equal). If
``strict=True``, the threshold checks are performed using ``<`` (less
than).
velocity_tolerance : float > 0
Estimated notes are considered correct if, after rescaling and
normalization to [0, 1], they are within ``velocity_tolerance`` of a
matched reference note.
Returns
-------
matching : list of tuples
A list of matched reference and estimated notes.
``matching[i] == (i, j)`` where reference note ``i`` matches estimated
note ``j``.
"""
# Compute note matching as usual using standard transcription function
matching = transcription.match_notes(
ref_intervals, ref_pitches, est_intervals, est_pitches,
onset_tolerance, pitch_tolerance, offset_ratio, offset_min_tolerance,
strict)
# Rescale reference velocities to the range [0, 1]
min_velocity, max_velocity = np.min(ref_velocities), np.max(ref_velocities)
# Make the smallest possible range 1 to avoid divide by zero
velocity_range = max(1, max_velocity - min_velocity)
ref_velocities = (ref_velocities - min_velocity)/float(velocity_range)
# Convert matching list-of-tuples to array for fancy indexing
matching = np.array(matching)
# When there is no matching, return an empty list
if matching.size == 0:
return []
# Grab velocities for matched notes
ref_matched_velocities = ref_velocities[matching[:, 0]]
est_matched_velocities = est_velocities[matching[:, 1]]
# Find slope and intercept of line which produces best least-squares fit
# between matched est and ref velocities
slope, intercept = np.linalg.lstsq(
np.vstack([est_matched_velocities,
np.ones(len(est_matched_velocities))]).T,
ref_matched_velocities)[0]
# Re-scale est velocities to match ref
est_matched_velocities = slope*est_matched_velocities + intercept
# Compute the absolute error of (rescaled) estimated velocities vs.
# normalized reference velocities. Error will be in [0, 1]
velocity_diff = np.abs(est_matched_velocities - ref_matched_velocities)
# Check whether each error is within the provided tolerance
velocity_within_tolerance = (velocity_diff < velocity_tolerance)
# Only keep matches whose velocity was within the provided tolerance
matching = matching[velocity_within_tolerance]
# Convert back to list-of-tuple format
matching = [tuple(_) for _ in matching]
return matching
def precision_recall_f1_overlap(
ref_intervals, ref_pitches, ref_velocities, est_intervals, est_pitches,
est_velocities, onset_tolerance=0.05, pitch_tolerance=50.0,
offset_ratio=0.2, offset_min_tolerance=0.05, strict=False,
velocity_tolerance=0.1, beta=1.0):
"""Compute the Precision, Recall and F-measure of correct vs incorrectly
transcribed notes, and the Average Overlap Ratio for correctly transcribed
notes (see :func:`mir_eval.transcription.average_overlap_ratio`).
"Correctness" is determined based on note onset, velocity, pitch and
(optionally) offset. An estimated note is considered correct if
1. Its onset is within ``onset_tolerance`` (default +-50ms) of a
reference note
2. Its pitch (F0) is within +/- ``pitch_tolerance`` (default one
quarter tone, 50 cents) of the corresponding reference note
3. Its velocity, after normalizing reference velocities to the range
    [0, 1] and globally rescaling estimated velocities to minimize the L2
    distance to the matched reference velocities, is within
    ``velocity_tolerance`` (default 0.1) of the corresponding reference note
4. If ``offset_ratio`` is ``None``, note offsets are ignored in the
comparison. Otherwise, on top of the above requirements, a correct
returned note is required to have an offset value within
    ``offset_ratio`` (default 20%) of the reference note's duration around
the reference note's offset, or within ``offset_min_tolerance``
(default 50 ms), whichever is larger.
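    Examples
    --------
    A sketch, assuming the six input arrays are as described below (see
    also :func:`mir_eval.transcription.precision_recall_f1_overlap`):
    >>> (precision, recall, f_measure, avg_overlap_ratio) = (
    ...     mir_eval.transcription_velocity.precision_recall_f1_overlap(
    ...         ref_intervals, ref_pitches, ref_velocities,
    ...         est_intervals, est_pitches, est_velocities))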
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
ref_velocities : np.ndarray, shape=(n,)
Array of MIDI velocities (i.e. between 0 and 127) of reference notes
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
    est_velocities : np.ndarray, shape=(m,)
Array of MIDI velocities (i.e. between 0 and 127) of estimated notes
onset_tolerance : float > 0
The tolerance for an estimated note's onset deviating from the
reference note's onset, in seconds. Default is 0.05 (50 ms).
pitch_tolerance : float > 0
The tolerance for an estimated note's pitch deviating from the
reference note's pitch, in cents. Default is 50.0 (50 cents).
offset_ratio : float > 0 or None
The ratio of the reference note's duration used to define the
offset_tolerance. Default is 0.2 (20%), meaning the
``offset_tolerance`` will equal the ``ref_duration * 0.2``, or
``offset_min_tolerance`` (0.05 by default, i.e. 50 ms), whichever is
greater. If ``offset_ratio`` is set to ``None``, offsets are ignored in
the evaluation.
offset_min_tolerance : float > 0
The minimum tolerance for offset matching. See ``offset_ratio``
description for an explanation of how the offset tolerance is
determined. Note: this parameter only influences the results if
``offset_ratio`` is not ``None``.
strict : bool
If ``strict=False`` (the default), threshold checks for onset, offset,
and pitch matching are performed using ``<=`` (less than or equal). If
``strict=True``, the threshold checks are performed using ``<`` (less
than).
velocity_tolerance : float > 0
Estimated notes are considered correct if, after rescaling and
normalization to [0, 1], they are within ``velocity_tolerance`` of a
matched reference note.
beta : float > 0
Weighting factor for f-measure (default value = 1.0).
Returns
-------
precision : float
The computed precision score
recall : float
The computed recall score
f_measure : float
The computed F-measure score
avg_overlap_ratio : float
The computed Average Overlap Ratio score
"""
validate(ref_intervals, ref_pitches, ref_velocities, est_intervals,
est_pitches, est_velocities)
    # When reference or estimated notes are empty,
    # metrics are undefined, return 0's
if len(ref_pitches) == 0 or len(est_pitches) == 0:
return 0., 0., 0., 0.
matching = match_notes(
ref_intervals, ref_pitches, ref_velocities, est_intervals, est_pitches,
est_velocities, onset_tolerance, pitch_tolerance, offset_ratio,
offset_min_tolerance, strict, velocity_tolerance)
precision = float(len(matching))/len(est_pitches)
recall = float(len(matching))/len(ref_pitches)
f_measure = util.f_measure(precision, recall, beta=beta)
avg_overlap_ratio = transcription.average_overlap_ratio(
ref_intervals, est_intervals, matching)
return precision, recall, f_measure, avg_overlap_ratio
def evaluate(ref_intervals, ref_pitches, ref_velocities, est_intervals,
est_pitches, est_velocities, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Parameters
----------
ref_intervals : np.ndarray, shape=(n,2)
Array of reference notes time intervals (onset and offset times)
ref_pitches : np.ndarray, shape=(n,)
Array of reference pitch values in Hertz
ref_velocities : np.ndarray, shape=(n,)
Array of MIDI velocities (i.e. between 0 and 127) of reference notes
est_intervals : np.ndarray, shape=(m,2)
Array of estimated notes time intervals (onset and offset times)
est_pitches : np.ndarray, shape=(m,)
Array of estimated pitch values in Hertz
    est_velocities : np.ndarray, shape=(m,)
Array of MIDI velocities (i.e. between 0 and 127) of estimated notes
kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
# Compute all the metrics
scores = collections.OrderedDict()
# Precision, recall and f-measure taking note offsets into account
kwargs.setdefault('offset_ratio', 0.2)
if kwargs['offset_ratio'] is not None:
(scores['Precision'],
scores['Recall'],
scores['F-measure'],
scores['Average_Overlap_Ratio']) = util.filter_kwargs(
precision_recall_f1_overlap, ref_intervals, ref_pitches,
ref_velocities, est_intervals, est_pitches, est_velocities,
**kwargs)
# Precision, recall and f-measure NOT taking note offsets into account
kwargs['offset_ratio'] = None
(scores['Precision_no_offset'],
scores['Recall_no_offset'],
scores['F-measure_no_offset'],
scores['Average_Overlap_Ratio_no_offset']) = util.filter_kwargs(
precision_recall_f1_overlap, ref_intervals, ref_pitches,
ref_velocities, est_intervals, est_pitches, est_velocities, **kwargs)
return scores
mir_eval-0.7/mir_eval/util.py 0000664 0000000 0000000 00000071260 14203260312 0016302 0 ustar 00root root 0000000 0000000 '''
This submodule collects useful functionality required across the task
submodules, such as preprocessing, validation, and common computations.
'''
import os
import inspect
import six
import numpy as np
def index_labels(labels, case_sensitive=False):
"""Convert a list of string identifiers into numerical indices.
Parameters
----------
labels : list of strings, shape=(n,)
A list of annotations, e.g., segment or chord labels from an
annotation file.
case_sensitive : bool
Set to True to enable case-sensitive label indexing
(Default value = False)
Returns
-------
indices : list, shape=(n,)
Numerical representation of ``labels``
index_to_label : dict
Mapping to convert numerical indices back to labels.
``labels[i] == index_to_label[indices[i]]``
"""
label_to_index = {}
index_to_label = {}
# If we're not case-sensitive,
if not case_sensitive:
labels = [str(s).lower() for s in labels]
# First, build the unique label mapping
for index, s in enumerate(sorted(set(labels))):
label_to_index[s] = index
index_to_label[index] = s
# Remap the labels to indices
indices = [label_to_index[s] for s in labels]
# Return the converted labels, and the inverse mapping
return indices, index_to_label
def generate_labels(items, prefix='__'):
"""Given an array of items (e.g. events, intervals), create a synthetic label
for each event of the form '(label prefix)(item number)'
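    Examples
    --------
    Each item gets a label built from the prefix and its position:
    >>> mir_eval.util.generate_labels([10.0, 20.0, 30.0])
    ['__0', '__1', '__2']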
Parameters
----------
items : list-like
A list or array of events or intervals
prefix : str
This prefix will be prepended to all synthetically generated labels
(Default value = '__')
Returns
-------
labels : list of str
Synthetically generated labels
"""
return ['{}{}'.format(prefix, n) for n in range(len(items))]
def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1,
fill_value=None):
"""Convert an array of labeled time intervals to annotated samples.
Parameters
----------
intervals : np.ndarray, shape=(n, d)
An array of time intervals, as returned by
:func:`mir_eval.io.load_intervals()` or
:func:`mir_eval.io.load_labeled_intervals()`.
The ``i`` th interval spans time ``intervals[i, 0]`` to
``intervals[i, 1]``.
labels : list, shape=(n,)
The annotation for each interval
offset : float > 0
Phase offset of the sampled time grid (in seconds)
(Default value = 0)
sample_size : float > 0
duration of each sample to be generated (in seconds)
(Default value = 0.1)
fill_value : type(labels[0])
        Object to use as the label for out-of-range time points.
(Default value = None)
Returns
-------
sample_times : list
list of sample times
sample_labels : list
array of labels for each generated sample
Notes
-----
Intervals will be rounded down to the nearest multiple
of ``sample_size``.
"""
# Round intervals to the sample size
num_samples = int(np.floor(intervals.max() / sample_size))
sample_indices = np.arange(num_samples, dtype=np.float32)
sample_times = (sample_indices*sample_size + offset).tolist()
sampled_labels = interpolate_intervals(
intervals, labels, sample_times, fill_value)
return sample_times, sampled_labels
def interpolate_intervals(intervals, labels, time_points, fill_value=None):
"""Assign labels to a set of points in time given a set of intervals.
Time points that do not lie within an interval are mapped to `fill_value`.
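    Examples
    --------
    A small sketch with two disjoint intervals and one out-of-range point:
    >>> intervals = np.array([[0.0, 1.0], [1.0, 2.0]])
    >>> mir_eval.util.interpolate_intervals(
    ...     intervals, ['a', 'b'], [0.5, 1.5, 2.5], fill_value='-')
    ['a', 'b', '-']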
Parameters
----------
intervals : np.ndarray, shape=(n, 2)
An array of time intervals, as returned by
:func:`mir_eval.io.load_intervals()`.
The ``i`` th interval spans time ``intervals[i, 0]`` to
``intervals[i, 1]``.
Intervals are assumed to be disjoint.
labels : list, shape=(n,)
The annotation for each interval
time_points : array_like, shape=(m,)
Points in time to assign labels. These must be in
non-decreasing order.
fill_value : type(labels[0])
        Object to use as the label for out-of-range time points.
(Default value = None)
Returns
-------
aligned_labels : list
Labels corresponding to the given time points.
Raises
------
ValueError
If `time_points` is not in non-decreasing order.
"""
# Verify that time_points is sorted
time_points = np.asarray(time_points)
if np.any(time_points[1:] < time_points[:-1]):
raise ValueError('time_points must be in non-decreasing order')
aligned_labels = [fill_value] * len(time_points)
starts = np.searchsorted(time_points, intervals[:, 0], side='left')
ends = np.searchsorted(time_points, intervals[:, 1], side='right')
for (start, end, lab) in zip(starts, ends, labels):
aligned_labels[start:end] = [lab] * (end - start)
return aligned_labels
def sort_labeled_intervals(intervals, labels=None):
'''Sort intervals, and optionally, their corresponding labels
according to start time.
Parameters
----------
intervals : np.ndarray, shape=(n, 2)
The input intervals
labels : list, optional
Labels for each interval
Returns
-------
intervals_sorted or (intervals_sorted, labels_sorted)
Labels are only returned if provided as input
'''
idx = np.argsort(intervals[:, 0])
intervals_sorted = intervals[idx]
if labels is None:
return intervals_sorted
else:
return intervals_sorted, [labels[_] for _ in idx]
def f_measure(precision, recall, beta=1.0):
"""Compute the f-measure from precision and recall scores.
Parameters
----------
precision : float in (0, 1]
Precision
recall : float in (0, 1]
Recall
beta : float > 0
Weighting factor for f-measure
(Default value = 1.0)
Returns
-------
f_measure : float
The weighted f-measure
"""
if precision == 0 and recall == 0:
return 0.0
return (1 + beta**2)*precision*recall/((beta**2)*precision + recall)
def intervals_to_boundaries(intervals, q=5):
"""Convert interval times into boundaries.
Parameters
----------
intervals : np.ndarray, shape=(n_events, 2)
Array of interval start and end-times
q : int
Number of decimals to round to. (Default value = 5)
Returns
-------
boundaries : np.ndarray
Interval boundary times, including the end of the final interval
"""
return np.unique(np.ravel(np.round(intervals, decimals=q)))
def boundaries_to_intervals(boundaries):
"""Convert an array of event times into intervals
Parameters
----------
boundaries : list-like
List-like of event times. These are assumed to be unique
timestamps in ascending order.
Returns
-------
intervals : np.ndarray, shape=(n_intervals, 2)
Start and end time for each interval
"""
if not np.allclose(boundaries, np.unique(boundaries)):
raise ValueError('Boundary times are not unique or not ascending.')
intervals = np.asarray(list(zip(boundaries[:-1], boundaries[1:])))
return intervals
def adjust_intervals(intervals,
labels=None,
t_min=0.0,
t_max=None,
start_label='__T_MIN',
end_label='__T_MAX'):
"""Adjust a list of time intervals to span the range ``[t_min, t_max]``.
Any intervals lying completely outside the specified range will be removed.
Any intervals lying partially outside the specified range will be cropped.
If the specified range exceeds the span of the provided data in either
direction, additional intervals will be appended. If an interval is
appended at the beginning, it will be given the label ``start_label``; if
an interval is appended at the end, it will be given the label
``end_label``.
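    Examples
    --------
    A sketch padding hypothetical intervals out to span ``[0, 4]``; an
    interval labeled ``'__T_MIN'`` is prepended and one labeled
    ``'__T_MAX'`` is appended:
    >>> intervals = np.array([[1.0, 2.0], [2.0, 3.0]])
    >>> new_intervals, new_labels = mir_eval.util.adjust_intervals(
    ...     intervals, ['a', 'b'], t_min=0.0, t_max=4.0)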
Parameters
----------
intervals : np.ndarray, shape=(n_events, 2)
Array of interval start and end-times
labels : list, len=n_events or None
List of labels
(Default value = None)
t_min : float or None
Minimum interval start time.
(Default value = 0.0)
t_max : float or None
Maximum interval end time.
(Default value = None)
start_label : str or float or int
Label to give any intervals appended at the beginning
(Default value = '__T_MIN')
end_label : str or float or int
Label to give any intervals appended at the end
(Default value = '__T_MAX')
Returns
-------
new_intervals : np.ndarray
Intervals spanning ``[t_min, t_max]``
new_labels : list
        List of labels for ``new_intervals``
"""
# When supplied intervals are empty and t_max and t_min are supplied,
# create one interval from t_min to t_max with the label start_label
if t_min is not None and t_max is not None and intervals.size == 0:
return np.array([[t_min, t_max]]), [start_label]
# When intervals are empty and either t_min or t_max are not supplied,
# we can't append new intervals
elif (t_min is None or t_max is None) and intervals.size == 0:
raise ValueError("Supplied intervals are empty, can't append new"
" intervals")
if t_min is not None:
# Find the intervals that end at or after t_min
first_idx = np.argwhere(intervals[:, 1] >= t_min)
if len(first_idx) > 0:
# If we have events below t_min, crop them out
if labels is not None:
labels = labels[int(first_idx[0]):]
# Clip to the range (t_min, +inf)
intervals = intervals[int(first_idx[0]):]
intervals = np.maximum(t_min, intervals)
if intervals.min() > t_min:
# Lowest boundary is higher than t_min:
# add a new boundary and label
intervals = np.vstack(([t_min, intervals.min()], intervals))
if labels is not None:
labels.insert(0, start_label)
if t_max is not None:
# Find the intervals that begin after t_max
last_idx = np.argwhere(intervals[:, 0] > t_max)
if len(last_idx) > 0:
# We have boundaries above t_max.
# Trim to only boundaries <= t_max
if labels is not None:
labels = labels[:int(last_idx[0])]
# Clip to the range (-inf, t_max)
intervals = intervals[:int(last_idx[0])]
intervals = np.minimum(t_max, intervals)
if intervals.max() < t_max:
# Last boundary is below t_max: add a new boundary and label
intervals = np.vstack((intervals, [intervals.max(), t_max]))
if labels is not None:
labels.append(end_label)
return intervals, labels
def adjust_events(events, labels=None, t_min=0.0,
t_max=None, label_prefix='__'):
"""Adjust the given list of event times to span the range
``[t_min, t_max]``.
Any event times outside of the specified range will be removed.
If the times do not span ``[t_min, t_max]``, additional events will be
added with the prefix ``label_prefix``.
Parameters
----------
events : np.ndarray
Array of event times (seconds)
labels : list or None
List of labels
(Default value = None)
t_min : float or None
Minimum valid event time.
(Default value = 0.0)
t_max : float or None
Maximum valid event time.
(Default value = None)
label_prefix : str
Prefix string to use for synthetic labels
(Default value = '__')
Returns
-------
    new_times : np.ndarray
        Event times corrected to the given range.
    new_labels : list
        List of labels for ``new_times``, including any synthetic labels
        added at the boundaries.
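    Examples
    --------
    A minimal sketch with hypothetical event times trimmed and padded to
    ``[0, 3]``:
    >>> events = np.array([0.5, 1.5, 2.5, 3.5])
    >>> labels = ['a', 'b', 'c', 'd']
    >>> new_events, new_labels = mir_eval.util.adjust_events(
    ...     events, labels, t_min=0.0, t_max=3.0)
    >>> print(new_events.tolist())
    [0.0, 0.5, 1.5, 2.5, 3.0]
    >>> print(new_labels)
    ['__T_MIN', 'a', 'b', 'c', '__T_MAX']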
"""
if t_min is not None:
first_idx = np.argwhere(events >= t_min)
if len(first_idx) > 0:
# We have events below t_min
# Crop them out
if labels is not None:
labels = labels[int(first_idx[0]):]
events = events[int(first_idx[0]):]
if events[0] > t_min:
# Lowest boundary is higher than t_min:
# add a new boundary and label
events = np.concatenate(([t_min], events))
if labels is not None:
labels.insert(0, '%sT_MIN' % label_prefix)
if t_max is not None:
last_idx = np.argwhere(events > t_max)
if len(last_idx) > 0:
# We have boundaries above t_max.
# Trim to only boundaries <= t_max
if labels is not None:
labels = labels[:int(last_idx[0])]
events = events[:int(last_idx[0])]
if events[-1] < t_max:
# Last boundary is below t_max: add a new boundary and label
events = np.concatenate((events, [t_max]))
if labels is not None:
labels.append('%sT_MAX' % label_prefix)
return events, labels
def intersect_files(flist1, flist2):
"""Return the intersection of two sets of filepaths, based on the file name
(after the final '/') and ignoring the file extension.
Examples
--------
>>> flist1 = ['/a/b/abc.lab', '/c/d/123.lab', '/e/f/xyz.lab']
>>> flist2 = ['/g/h/xyz.npy', '/i/j/123.txt', '/k/l/456.lab']
>>> sublist1, sublist2 = mir_eval.util.intersect_files(flist1, flist2)
    >>> print(sublist1)
['/e/f/xyz.lab', '/c/d/123.lab']
    >>> print(sublist2)
['/g/h/xyz.npy', '/i/j/123.txt']
Parameters
----------
flist1 : list
first list of filepaths
flist2 : list
second list of filepaths
Returns
-------
sublist1 : list
subset of filepaths with matching stems from ``flist1``
sublist2 : list
corresponding filepaths from ``flist2``
"""
def fname(abs_path):
"""Returns the filename given an absolute path.
Parameters
----------
abs_path :
Returns
-------
"""
return os.path.splitext(os.path.split(abs_path)[-1])[0]
fmap = dict([(fname(f), f) for f in flist1])
pairs = [list(), list()]
for f in flist2:
if fname(f) in fmap:
pairs[0].append(fmap[fname(f)])
pairs[1].append(f)
return pairs
def merge_labeled_intervals(x_intervals, x_labels, y_intervals, y_labels):
r"""Merge the time intervals of two sequences.
Parameters
----------
x_intervals : np.ndarray
Array of interval times (seconds)
x_labels : list or None
List of labels
y_intervals : np.ndarray
Array of interval times (seconds)
y_labels : list or None
List of labels
Returns
-------
new_intervals : np.ndarray
New interval times of the merged sequences.
new_x_labels : list
New labels for the sequence ``x``
new_y_labels : list
New labels for the sequence ``y``
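    Examples
    --------
    A minimal sketch merging two hypothetical segmentations of ``[0, 4]``:
    >>> x_int = np.array([[0.0, 2.0], [2.0, 4.0]])
    >>> y_int = np.array([[0.0, 1.0], [1.0, 4.0]])
    >>> new_int, x_lab, y_lab = mir_eval.util.merge_labeled_intervals(
    ...     x_int, ['A', 'B'], y_int, ['a', 'b'])
    >>> print(new_int.tolist())
    [[0.0, 1.0], [1.0, 2.0], [2.0, 4.0]]
    >>> print(x_lab, y_lab)
    ['A', 'A', 'B'] ['a', 'b', 'b']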
"""
align_check = [x_intervals[0, 0] == y_intervals[0, 0],
x_intervals[-1, 1] == y_intervals[-1, 1]]
if False in align_check:
raise ValueError(
"Time intervals do not align; did you mean to call "
"'adjust_intervals()' first?")
time_boundaries = np.unique(
np.concatenate([x_intervals, y_intervals], axis=0))
output_intervals = np.array(
[time_boundaries[:-1], time_boundaries[1:]]).T
x_labels_out, y_labels_out = [], []
x_label_range = np.arange(len(x_labels))
y_label_range = np.arange(len(y_labels))
for t0, _ in output_intervals:
x_idx = x_label_range[(t0 >= x_intervals[:, 0])]
x_labels_out.append(x_labels[x_idx[-1]])
y_idx = y_label_range[(t0 >= y_intervals[:, 0])]
y_labels_out.append(y_labels[y_idx[-1]])
return output_intervals, x_labels_out, y_labels_out
def _bipartite_match(graph):
"""Find maximum cardinality matching of a bipartite graph (U,V,E).
The input format is a dictionary mapping members of U to a list
of their neighbors in V.
The output is a dict M mapping members of V to their matches in U.
Parameters
----------
graph : dictionary : left-vertex -> list of right vertices
The input bipartite graph. Each edge need only be specified once.
Returns
-------
matching : dictionary : right-vertex -> left vertex
        A maximum-cardinality bipartite matching.
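    Examples
    --------
    A minimal sketch on a tiny hypothetical graph: left vertex 0 neighbors
    right vertices 0 and 1, while left vertex 1 neighbors only right vertex 0:
    >>> graph = {0: [0, 1], 1: [0]}
    >>> print(sorted(mir_eval.util._bipartite_match(graph).items()))
    [(0, 1), (1, 0)]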
"""
# Adapted from:
#
# Hopcroft-Karp bipartite max-cardinality matching and max independent set
# David Eppstein, UC Irvine, 27 Apr 2002
# initialize greedy matching (redundant, but faster than full search)
matching = {}
for u in graph:
for v in graph[u]:
if v not in matching:
matching[v] = u
break
while True:
# structure residual graph into layers
# pred[u] gives the neighbor in the previous layer for u in U
# preds[v] gives a list of neighbors in the previous layer for v in V
# unmatched gives a list of unmatched vertices in final layer of V,
# and is also used as a flag value for pred[u] when u is in the first
# layer
preds = {}
unmatched = []
pred = dict([(u, unmatched) for u in graph])
for v in matching:
del pred[matching[v]]
layer = list(pred)
# repeatedly extend layering structure by another pair of layers
while layer and not unmatched:
new_layer = {}
for u in layer:
for v in graph[u]:
if v not in preds:
new_layer.setdefault(v, []).append(u)
layer = []
for v in new_layer:
preds[v] = new_layer[v]
if v in matching:
layer.append(matching[v])
pred[matching[v]] = v
else:
unmatched.append(v)
# did we finish layering without finding any alternating paths?
if not unmatched:
unlayered = {}
for u in graph:
for v in graph[u]:
if v not in preds:
unlayered[v] = None
return matching
def recurse(v):
"""Recursively search backward through layers to find alternating
paths. recursion returns true if found path, false otherwise
"""
if v in preds:
L = preds[v]
del preds[v]
for u in L:
if u in pred:
pu = pred[u]
del pred[u]
if pu is unmatched or recurse(pu):
matching[v] = u
return True
return False
for v in unmatched:
recurse(v)
def _outer_distance_mod_n(ref, est, modulus=12):
"""Compute the absolute outer distance modulo n.
Using this distance, d(11, 0) = 1 (modulo 12)
Parameters
----------
ref : np.ndarray, shape=(n,)
Array of reference values.
est : np.ndarray, shape=(m,)
Array of estimated values.
modulus : int
The modulus.
12 by default for octave equivalence.
Returns
-------
outer_distance : np.ndarray, shape=(n, m)
The outer circular distance modulo n.
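    Examples
    --------
    A minimal sketch on pitch classes: under circular (mod 12) distance,
    11 is one semitone from 0 and two semitones from 1:
    >>> dist = mir_eval.util._outer_distance_mod_n(np.array([11]),
    ...                                            np.array([0, 1]))
    >>> print(dist.tolist())
    [[1, 2]]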
"""
ref_mod_n = np.mod(ref, modulus)
est_mod_n = np.mod(est, modulus)
abs_diff = np.abs(np.subtract.outer(ref_mod_n, est_mod_n))
return np.minimum(abs_diff, modulus - abs_diff)
def match_events(ref, est, window, distance=None):
"""Compute a maximum matching between reference and estimated event times,
subject to a window constraint.
Given two lists of event times ``ref`` and ``est``, we seek the largest set
of correspondences ``(ref[i], est[j])`` such that
``distance(ref[i], est[j]) <= window``, and each
``ref[i]`` and ``est[j]`` is matched at most once.
This is useful for computing precision/recall metrics in beat tracking,
onset detection, and segmentation.
Parameters
----------
ref : np.ndarray, shape=(n,)
Array of reference values
est : np.ndarray, shape=(m,)
Array of estimated values
window : float > 0
Size of the window.
distance : function
function that computes the outer distance of ref and est.
By default uses ``|ref[i] - est[j]|``
Returns
-------
    matching : list of tuples
        A list of matched reference and estimated event indices.
        Each element is a pair ``(i, j)`` indicating that ``ref[i]``
        matches ``est[j]``.
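    Examples
    --------
    A minimal sketch with hypothetical beat times and a 0.2 second window;
    ``ref[1]`` and ``est[1]`` go unmatched because no counterpart falls
    within their windows:
    >>> ref = np.array([1.0, 2.0, 3.0])
    >>> est = np.array([1.1, 2.5, 3.05])
    >>> matching = mir_eval.util.match_events(ref, est, 0.2)
    >>> print([(int(i), int(j)) for i, j in matching])
    [(0, 0), (2, 2)]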
"""
if distance is not None:
# Compute the indices of feasible pairings
hits = np.where(distance(ref, est) <= window)
else:
hits = _fast_hit_windows(ref, est, window)
# Construct the graph input
G = {}
for ref_i, est_i in zip(*hits):
if est_i not in G:
G[est_i] = []
G[est_i].append(ref_i)
# Compute the maximum matching
matching = sorted(_bipartite_match(G).items())
return matching
def _fast_hit_windows(ref, est, window):
'''Fast calculation of windowed hits for time events.
Given two lists of event times ``ref`` and ``est``, and a
tolerance window, computes a list of pairings
``(i, j)`` where ``|ref[i] - est[j]| <= window``.
This is equivalent to, but more efficient than the following:
>>> hit_ref, hit_est = np.where(np.abs(np.subtract.outer(ref, est))
... <= window)
Parameters
----------
ref : np.ndarray, shape=(n,)
Array of reference values
est : np.ndarray, shape=(m,)
Array of estimated values
window : float >= 0
Size of the tolerance window
Returns
-------
hit_ref : np.ndarray
hit_est : np.ndarray
indices such that ``|hit_ref[i] - hit_est[i]| <= window``
'''
ref = np.asarray(ref)
est = np.asarray(est)
ref_idx = np.argsort(ref)
ref_sorted = ref[ref_idx]
left_idx = np.searchsorted(ref_sorted, est - window, side='left')
right_idx = np.searchsorted(ref_sorted, est + window, side='right')
hit_ref, hit_est = [], []
for j, (start, end) in enumerate(zip(left_idx, right_idx)):
hit_ref.extend(ref_idx[start:end])
hit_est.extend([j] * (end - start))
return hit_ref, hit_est
def validate_intervals(intervals):
"""Checks that an (n, 2) interval ndarray is well-formed, and raises errors
if not.
Parameters
----------
intervals : np.ndarray, shape=(n, 2)
Array of interval start/end locations.
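    Examples
    --------
    A minimal sketch showing a malformed interval being rejected:
    >>> mir_eval.util.validate_intervals(np.array([[0.0, 1.0], [2.0, 1.5]]))
    Traceback (most recent call last):
        ...
    ValueError: All interval durations must be strictly positive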
"""
# Validate interval shape
if intervals.ndim != 2 or intervals.shape[1] != 2:
raise ValueError('Intervals should be n-by-2 numpy ndarray, '
'but shape={}'.format(intervals.shape))
# Make sure no times are negative
if (intervals < 0).any():
raise ValueError('Negative interval times found')
# Make sure all intervals have strictly positive duration
if (intervals[:, 1] <= intervals[:, 0]).any():
raise ValueError('All interval durations must be strictly positive')
def validate_events(events, max_time=30000.):
"""Checks that a 1-d event location ndarray is well-formed, and raises
errors if not.
Parameters
----------
events : np.ndarray, shape=(n,)
Array of event times
max_time : float
If an event is found above this time, a ValueError will be raised.
(Default value = 30000.)
"""
# Make sure no event times are huge
if (events > max_time).any():
raise ValueError('An event at time {} was found which is greater than '
'the maximum allowable time of max_time = {} (did you'
' supply event times in '
'seconds?)'.format(events.max(), max_time))
# Make sure event locations are 1-d np ndarrays
if events.ndim != 1:
raise ValueError('Event times should be 1-d numpy ndarray, '
'but shape={}'.format(events.shape))
# Make sure event times are increasing
if (np.diff(events) < 0).any():
raise ValueError('Events should be in increasing order.')
def validate_frequencies(frequencies, max_freq, min_freq,
allow_negatives=False):
"""Checks that a 1-d frequency ndarray is well-formed, and raises
errors if not.
Parameters
----------
frequencies : np.ndarray, shape=(n,)
Array of frequency values
    max_freq : float
        If a frequency is found above this pitch, a ValueError will be raised.
    min_freq : float
        If a frequency is found below this pitch, a ValueError will be raised.
allow_negatives : bool
Whether or not to allow negative frequency values.
"""
# If flag is true, map frequencies to their absolute value.
if allow_negatives:
frequencies = np.abs(frequencies)
# Make sure no frequency values are huge
if (np.abs(frequencies) > max_freq).any():
raise ValueError('A frequency of {} was found which is greater than '
'the maximum allowable value of max_freq = {} (did '
'you supply frequency values in '
'Hz?)'.format(frequencies.max(), max_freq))
# Make sure no frequency values are tiny
if (np.abs(frequencies) < min_freq).any():
raise ValueError('A frequency of {} was found which is less than the '
'minimum allowable value of min_freq = {} (did you '
'supply frequency values in '
'Hz?)'.format(frequencies.min(), min_freq))
# Make sure frequency values are 1-d np ndarrays
if frequencies.ndim != 1:
raise ValueError('Frequencies should be 1-d numpy ndarray, '
'but shape={}'.format(frequencies.shape))
def has_kwargs(function):
r'''Determine whether a function has \*\*kwargs.
Parameters
----------
function : callable
The function to test
Returns
-------
    has_kwargs : bool
        True if ``function`` accepts arbitrary keyword arguments,
        False otherwise.
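    Examples
    --------
    A minimal sketch with two hypothetical functions:
    >>> def fixed(a, b=1):
    ...     return a + b
    >>> mir_eval.util.has_kwargs(fixed)
    False
    >>> def flexible(a, **kwargs):
    ...     return a
    >>> mir_eval.util.has_kwargs(flexible)
    True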
'''
if six.PY2:
return inspect.getargspec(function).keywords is not None
else:
sig = inspect.signature(function)
for param in sig.parameters.values():
if param.kind == param.VAR_KEYWORD:
return True
return False
def filter_kwargs(_function, *args, **kwargs):
"""Given a function and args and keyword args to pass to it, call the function
but using only the keyword arguments which it accepts. This is equivalent
to redefining the function with an additional \*\*kwargs to accept slop
keyword args.
If the target function already accepts \*\*kwargs parameters, no filtering
is performed.
Parameters
----------
_function : callable
Function to call. Can take in any number of args or kwargs
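    Examples
    --------
    A minimal sketch with a hypothetical metric function; the unsupported
    ``unused`` keyword is silently dropped:
    >>> def metric(a, b=1):
    ...     return a + b
    >>> mir_eval.util.filter_kwargs(metric, 2, b=3, unused=None)
    5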
"""
if has_kwargs(_function):
return _function(*args, **kwargs)
# Get the list of function arguments
func_code = six.get_function_code(_function)
function_args = func_code.co_varnames[:func_code.co_argcount]
# Construct a dict of those kwargs which appear in the function
filtered_kwargs = {}
for kwarg, value in list(kwargs.items()):
if kwarg in function_args:
filtered_kwargs[kwarg] = value
# Call the function with the supplied args and the filtered kwarg dict
return _function(*args, **filtered_kwargs)
def intervals_to_durations(intervals):
"""Converts an array of n intervals to their n durations.
Parameters
----------
intervals : np.ndarray, shape=(n, 2)
An array of time intervals, as returned by
:func:`mir_eval.io.load_intervals()`.
The ``i`` th interval spans time ``intervals[i, 0]`` to
``intervals[i, 1]``.
Returns
-------
durations : np.ndarray, shape=(n,)
Array of the duration of each interval.
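    Examples
    --------
    A minimal sketch with hypothetical intervals:
    >>> intervals = np.array([[0.0, 1.0], [1.0, 3.5]])
    >>> print(mir_eval.util.intervals_to_durations(intervals).tolist())
    [1.0, 2.5]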
"""
validate_intervals(intervals)
return np.abs(np.diff(intervals, axis=-1)).flatten()
def hz_to_midi(freqs):
'''Convert Hz to MIDI numbers
Parameters
----------
freqs : number or ndarray
Frequency/frequencies in Hz
Returns
-------
midi : number or ndarray
MIDI note numbers corresponding to input frequencies.
Note that these may be fractional.
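    Examples
    --------
    A quick sanity check: A440 corresponds to MIDI note 69:
    >>> print(mir_eval.util.hz_to_midi(440.0))
    69.0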
'''
return 12.0 * (np.log2(freqs) - np.log2(440.0)) + 69.0
def midi_to_hz(midi):
'''Convert MIDI numbers to Hz
Parameters
----------
midi : number or ndarray
MIDI notes
Returns
-------
freqs : number or ndarray
Frequency/frequencies in Hz corresponding to `midi`
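    Examples
    --------
    The inverse of :func:`hz_to_midi`: MIDI note 69 is A440:
    >>> print(mir_eval.util.midi_to_hz(69.0))
    440.0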
'''
return 440.0 * (2.0 ** ((midi - 69.0)/12.0))
mir_eval-0.7/setup.py 0000664 0000000 0000000 00000002100 14203260312 0014652 0 ustar 00root root 0000000 0000000 from setuptools import setup
with open('README.rst') as file:
long_description = file.read()
setup(
name='mir_eval',
version='0.7',
description='Common metrics for common audio/music processing tasks.',
author='Colin Raffel',
author_email='craffel@gmail.com',
url='https://github.com/craffel/mir_eval',
packages=['mir_eval'],
long_description=long_description,
classifiers=[
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
'Development Status :: 5 - Production/Stable',
"Intended Audience :: Developers",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
],
keywords='audio music mir dsp',
license='MIT',
install_requires=[
'numpy >= 1.7.0',
'scipy >= 1.0.0',
'future',
'six'
],
extras_require={
'display': ['matplotlib>=1.5.0',
'scipy>=1.0.0'],
'testing': ['matplotlib>=2.1.0,<3']
}
)
mir_eval-0.7/tests/ 0000775 0000000 0000000 00000000000 14203260312 0014311 5 ustar 00root root 0000000 0000000 mir_eval-0.7/tests/baseline_images/ 0000775 0000000 0000000 00000000000 14203260312 0017420 5 ustar 00root root 0000000 0000000 mir_eval-0.7/tests/baseline_images/test_display/ 0000775 0000000 0000000 00000000000 14203260312 0022124 5 ustar 00root root 0000000 0000000 mir_eval-0.7/tests/baseline_images/test_display/events.png 0000664 0000000 0000000 00000022707 14203260312 0024146 0 ustar 00root root 0000000 0000000 [binary PNG data omitted: matplotlib baseline image for the test_display events plot]